# include/net/tensor_contract/tensor_contract_qbb.hpp
# Namespaces
Name |
---|
net |
# Source code
#ifndef NET_TENSOR_CONTRACT_QUICKBB_HPP
#define NET_TENSOR_CONTRACT_QUICKBB_HPP
#include "../network.hpp"
#include "../rational.hpp"
#include "../tensor_tools.hpp"
#include "../traits.hpp"
#include "../group.hpp"
#include "tensor_contract_engine.hpp"
#include <TAT/TAT.hpp>
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
namespace net {
// Find the neighbor of node `it1` with the least contraction count, restricted
// to neighbors whose key is in `part`.
//
// EdgeVal = int. For each in-part neighbor `nb`, a cost pair is built from the
// cached node weights std::get<1>(...val) and the values of all edges joining
// `it1` and `nb` (multiple edges to the same neighbor are folded together):
//   first  = w(it1) * w(nb) / (v1 * v2 * ...)
//   second = w(it1) * w(nb) / (v1 * v2 * ...)^2
// The neighbor with the smallest `first` is chosen. On ties, the neighbor whose
// key sorts first under `part`'s key ordering wins.
// NOTE(review): the weight product is not widened; if the weights are int it
// can overflow for large tensors.
//
// Template parameters EdgeKey and Trait are kept for call-site compatibility.
// The candidates are keyed by the neighbor's *node* key, using `part`'s own
// key type and comparator.
//
// Returns {cost pair, iterator to the chosen neighbor}.
// Throws std::runtime_error if `it1` has no neighbor inside `part`.
template <typename IterNode, typename NodeSet, typename EdgeKey, typename Trait>
std::pair<std::pair<int,int>, IterNode> search_quick(IterNode & it1, const NodeSet & part) {
   using NbKey = typename NodeSet::key_type;
   using Candidate = std::pair<std::pair<int,int>, IterNode>;
   std::map<NbKey, Candidate, typename NodeSet::key_compare> legs;

   for (auto & e : it1->second.edges) {
      const auto & edge = e.second;
      if (part.count(edge.nbkey) == 0)
         continue; // neighbor lies outside the part being contracted

      auto found = legs.find(edge.nbkey);
      if (found == legs.end()) {
         // First edge to this neighbor: start from the full weight product.
         auto prod = std::get<1>(it1->second.val) * std::get<1>(edge.nbitr->second.val);
         legs.try_emplace(
               edge.nbkey,
               std::pair<int,int>(prod / edge.val, prod / edge.val / edge.val),
               edge.nbitr);
      } else {
         // Additional shared edge: it is summed over too, so divide it out.
         found->second.first.first /= edge.val;
         found->second.first.second /= (edge.val * edge.val);
      }
   }

   if (legs.empty())
      throw std::runtime_error("search_quick: node has no neighbor inside the given part");

   // Generic lambda binds to the map's real value_type, so no per-comparison copies.
   auto best = std::min_element(legs.begin(), legs.end(), [](const auto & a, const auto & b) {
      return a.second.first < b.second.first;
   });
   return best->second;
}
/// @brief Contract every node of `part` into a single node of `lat` using a
///        greedy ("quick") order, and return the key of the resulting node.
///
/// Procedure, as implemented below:
///  1. Store calc_weight(...) of every node of `part` in std::get<1>(node.val).
///     search_quick ranks candidate pairs using these cached weights.
///  2. For each node of `part`, search_quick yields its cheapest neighbor
///     inside `part`. The overall cheapest pair is absorbed first; on ties,
///     the earliest node in `part` iteration order wins.
///  3. The surviving node (root) repeatedly absorbs its cheapest in-part
///     neighbor until part.size() - 1 absorbs have been performed.
///
/// @param lat          network modified in place through lat.absorb.
/// @param part         keys of the nodes to contract. NOTE(review): assumed to
///                     be connected inside `lat`; otherwise some step has no
///                     in-part neighbor for search_quick to return.
/// @param absorb_fun   forwarded unchanged to lat.absorb.
/// @param contract_fun forwarded unchanged to lat.absorb.
/// @return the only key of `part` when it has one element, otherwise the key
///         of the root node that the other nodes were absorbed into.
/// @throws std::invalid_argument if `part` is empty.
template <typename contract_type, typename absorb_type, typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
NodeKey Engine::contract_quickbb(
      network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat,
      const std::set<NodeKey, typename Trait::nodekey_less> & part,
      const absorb_type & absorb_fun,
      const contract_type & contract_fun) {
   using NodeItrType = typename network<NodeVal, int, NodeKey, EdgeKey, Trait>::IterNode;
   using KeySet = std::set<NodeKey, typename Trait::nodekey_less>;

   if (contract_trace_mode)
      std::cout << "in_quickbb \n";
   if (part.empty())
      throw std::invalid_argument("contract_quickbb: part must not be empty");
   if (part.size() == 1)
      return *(part.begin());

   // Cache every node's weight; search_quick reads std::get<1>(val).
   for (auto & p : part) {
      auto site_it = lat.find(p);
      std::get<1>(site_it->second.val) = calc_weight(site_it, lat, KeySet());
   }

   // Choose the globally cheapest in-part pair as the starting contraction.
   std::pair<int,int> count;
   std::pair<int,int> least_count;
   bool have_pair = false;
   NodeItrType least_contract1, least_contract2, nb_itr;
   for (auto & p : part) {
      auto site_it = lat.find(p);
      std::tie(count, nb_itr) = search_quick<NodeItrType, KeySet, EdgeKey, Trait>(site_it, part);
      if (!have_pair || count < least_count) {
         have_pair = true;
         least_count = count;
         least_contract1 = site_it;
         least_contract2 = nb_itr;
      }
   }

   // Grow a single root node. The absorbed iterators are never dereferenced
   // after lat.absorb: it presumably erases the absorbed node, which would
   // invalidate its iterator (confirm in network.hpp).
   // NOTE(review): later search_quick calls read std::get<1>(root_itr->second.val);
   // this assumes absorb_fun keeps that weight up to date — confirm.
   auto root_itr = least_contract1;
   lat.absorb(root_itr, least_contract2, absorb_fun, contract_fun);
   for (std::size_t contract_size = 2; contract_size < part.size(); ++contract_size) {
      std::tie(count, nb_itr) = search_quick<NodeItrType, KeySet, EdgeKey, Trait>(root_itr, part);
      lat.absorb(root_itr, nb_itr, absorb_fun, contract_fun);
   }

   if (contract_trace_mode)
      std::cout << "out_quickbb \n";
   return root_itr->first;
}
} // namespace net
#endif
Updated on 15 June 2022 at 16:04:19 CST