# include/net/tensor_contract/tensor_contract_exact.hpp

# Namespaces

| Name |
| ---- |
| net |

# Classes

| Name |
| ---- |
| struct net::all_combination |

# Source code

#ifndef NET_TENSOR_CONTRACT_EXACT_HPP
#define NET_TENSOR_CONTRACT_EXACT_HPP
#include "../network.hpp"
#include "../rational.hpp"
#include "../tensor_tools.hpp"
#include "../traits.hpp"
#include "../group.hpp"
#include "tensor_contract_engine.hpp"
#include <TAT/TAT.hpp>
#include <cstdlib>
#include <functional>
#include <random>
#include <variant>
#include <memory>
#include <vector>
#include <iostream>

namespace net {

   /// Enumerates, one at a time, every way of choosing `s` of the `size`
   /// positions, and splits a container of `size` elements accordingly.
   ///
   /// Usage:
   ///    all_combination ac(n);
   ///    for(ac.begin(s); !ac.finish; ac.next()){
   ///       auto [chosen, rest] = ac.generate(values); // chosen.size() == s
   ///    }
   ///
   /// The current combination is kept as a bitmask in `elems`. When `s` is more
   /// than half of `size`, the complement (size-s positions) is enumerated
   /// instead and `reversed` is set, so that `generate` swaps the halves back.
   struct all_combination{
      std::vector<bool> elems;   ///< current bitmask, one entry per position
      all_combination(unsigned int s):elems(std::vector<bool>(s,false)),size(s){};
      unsigned int size;         ///< number of positions, fixed at construction
      bool finish=true;          ///< true when there is no current combination (also before begin())
      bool reversed=false;       ///< true when `elems` marks the complement of the chosen subset

      /// Start enumerating the subsets of exactly `s` positions.
      /// If `s > size` no such subset exists and `finish` is set immediately.
      void begin(unsigned int s){
         if(s>size){
            finish=true;
            return;
         }
         finish=false;
         // Enumerate whichever of "the subset" / "its complement" is smaller.
         reversed=(s>size/2);
         const unsigned int marked = reversed ? size-s : s;
         for(unsigned int i=0;i<size;++i)
            elems[i]=(i<marked);
      }

      /// Advance to the next combination with the same number of marked
      /// positions; sets `finish` once the last one has been visited.
      void next(){
         // Find the first marked position i whose successor is unmarked,
         // counting the marked positions that precede it.
         unsigned int true_num=0;
         unsigned int i=0;
         while(i+1<size && !(elems[i] && !elems[i+1])){
            if(elems[i]) ++true_num;
            ++i;
         }
         // Written as i+1>=size (not i==size-1) so size==0 cannot underflow
         // and no element is read out of bounds.
         if(i+1>=size){
            finish=true;
            return;
         }
         // Pack the marked positions seen so far to the front ...
         if(true_num<i)
            for(unsigned int j=0;j<i;++j)
               elems[j]=(j<true_num);
         // ... and move the found mark one step to the right.
         elems[i]=false;
         elems[i+1]=true;
      }

      /// Split `V` by the current combination. Returns {chosen, rest}, where
      /// `chosen` has the `s` elements selected by the last begin(s)/next()
      /// (in V's iteration order) and `rest` has the others.
      /// Precondition: V.size() <= size (positions are indexed by iteration order).
      template<typename ValSet>
      std::pair<ValSet,ValSet> generate(const ValSet & V) const{
         std::pair<ValSet,ValSet> result;
         unsigned int i=0;
         for(const auto & v:V){
            // When reversed, an unmarked position belongs to the chosen subset.
            const bool selected = (elems[i++]!=reversed);
            (selected ? result.first : result.second).insert(v);
         }
         return result;
      }
   };

   /// Contract the nodes listed in `part` into a single node of `lat`, choosing
   /// the contraction order by exhaustive dynamic programming over subsets,
   /// processed by increasing size ("breadth first").
   ///
   /// `solution[P]` holds the cheapest known contraction tree for the node set P.
   /// Single nodes are seeded first. Then, for every subset P (|P| = 2..|part|),
   /// every split of P into disjoint non-empty L and R (both already solved) is
   /// contracted with `net::contract`. The result with the lowest
   /// `contraction_cost` is kept. The tree found for `part` is applied to `lat`
   /// with `absorb_tree`. The number of splits examined grows roughly like
   /// 3^|part|, so this is only practical for small `part`.
   ///
   /// @param lat          network containing the nodes; modified in place.
   /// @param part         keys of the nodes to contract. NOTE(review): an empty
   ///                     `part` has no solution; it now throws std::out_of_range
   ///                     from `solution.at` instead of dereferencing a null tree.
   /// @param absorb_fun   forwarded to the group / network absorb routines.
   /// @param contract_fun forwarded to the group / network absorb routines.
   /// @return the key returned by `lat.absorb_tree`; for exactly two nodes, the
   ///         first key of `part` (presumably the node that absorbed the other).
   template <typename contract_type, typename absorb_type, typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
   NodeKey Engine::contract_breadth_first(
         network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat,
         const std::set<NodeKey, typename Trait::nodekey_less> & part,
         const absorb_type & absorb_fun,
         const contract_type & contract_fun) {
      using keyset_t = std::set<NodeKey, typename Trait::nodekey_less>;
      using group_t = group<NodeVal, int, NodeKey, EdgeKey, Trait>;

      // Two nodes: there is only one possible order, contract directly.
      if(part.size()==2){
         auto it=part.begin();
         auto it2=part.begin();
         ++it2;
         lat.absorb(*it,*it2,absorb_fun,contract_fun);
         return *it;
      }

      // solution[P] = cheapest known contraction of the node subset P.
      std::map<keyset_t,group_t> solution;

      // Seed: one single-node group per element of `part`.
      for(auto & p:part){
         group_t y(lat);
         y.absorb(p,absorb_fun,contract_fun);
         using tree_ptr_t = std::tuple_element_t<0,decltype(y.val)>;
         using tree_t = typename tree_ptr_t::element_type;
         // Replace the seed's tree with a fresh one built from forget_history(p).
         std::get<0>(y.val)=std::make_shared<tree_t>(tree<typename tree_t::DataType>(std::get<0>(y.val)->val.forget_history(p)));
         solution.insert({keyset_t({p}),y});
      }

      all_combination ac(part.size());
      for(unsigned int s=2; s<=part.size();++s){
         // Every subset P of `part` with |P| == s (its complement is not needed).
         for(ac.begin(s);!ac.finish;ac.next()){
            auto P = ac.generate(part).first;
            all_combination ac2(P.size());
            // Every split P = L u R with 1 <= |L| <= |R|.
            for(unsigned int s2=1;s2<=P.size()/2;++s2){
               for(ac2.begin(s2);!ac2.finish;ac2.next()){
                  auto [L,R] = ac2.generate(P);
                  // For an even split both (L,R) and (R,L) are enumerated; keep only L <= R.
                  if(P.size()%2==0 && s2==P.size()/2 && L>R)
                     continue;
                  // L and R are strictly smaller than P, so an earlier pass already solved them.
                  group_t G=net::contract<NodeVal, int, NodeKey, EdgeKey, Trait>(solution.at(L),solution.at(R),absorb_fun,contract_fun);
                  auto it=solution.find(P);
                  if(it==solution.end()){
                     solution.insert({P,G});
                  }else if(std::get<0>(G.val)->val.contraction_cost < std::get<0>(it->second.val)->val.contraction_cost){
                     it->second=G;
                  }
               }
            }
         }
      }
      NodeKey res=lat.absorb_tree(std::get<0>(solution.at(part).val),absorb_fun,contract_fun);
      return res;
   }


   /// Manual smoke test for all_combination. For every s in [0, 8], prints each
   /// split of {0,...,7} into an s-element part and its complement, numbered
   /// from 0 within each s.
   ///
   /// Declared `inline` because it is defined in a header. Without it, including
   /// this header from more than one translation unit breaks the one-definition rule.
   inline void test_all_combination(){
      constexpr int n=8;
      net::all_combination ac(n);
      std::set<int> values;
      for(int i=0;i<n;++i) values.insert(i);
      for(int s=0; s<=n;++s){
         std::cout<<"break in to "<<s<<" and "<<(n-s)<<'\n';
         int j=0;
         for(ac.begin(s);!ac.finish;ac.next()){
            auto [P,N] = ac.generate(values);
            std::cout<<j++<<"\n [ ";
            for(auto & p: P) std::cout<<p<<' ';
            std::cout<<"]\n [ ";
            for(auto & m: N) std::cout<<m<<' ';
            std::cout<<"]\n";
         }
      }
   }

} // namespace net
#endif

Updated on 15 June 2022 at 16:04:19 CST