# include/net/tensor_contract/tensor_contract_divide_naive.hpp
# Namespaces
Name |
---|
net |
# Classes
Name | |
---|---|
struct | net::kset_contract |
# Source code
#ifndef NET_TENSOR_CONTRACT_DIVIDE_NAIVE_HPP
#define NET_TENSOR_CONTRACT_DIVIDE_NAIVE_HPP
#include "../network.hpp"
#include "../rational.hpp"
#include "../tensor_tools.hpp"
#include "../traits.hpp"
#include "../group.hpp"
#include "tensor_contract_engine.hpp"
#include <TAT/TAT.hpp>
#include <cstdlib>
#include <functional>
#include <random>
#include <variant>
#include <memory>
#include <vector>
#include <iostream>
namespace net {
struct kset_contract {
template <typename NodeVal, typename EdgeKey, typename Comp>
NodeVal operator()(const NodeVal & g1, const NodeVal & g2, const std::set<std::pair<EdgeKey, EdgeKey>, Comp> & inds) const {
NodeVal res = g1;
res.insert(g2.begin(), g2.end());
return res;
}
};
template <typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
std::vector<std::set<NodeKey, typename Trait::nodekey_less>>
Engine::divide(network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat, const std::set<NodeKey, typename Trait::nodekey_less> & part) {
if (contract_trace_mode)
std::cout << "in_divide \n";
using KeySet = std::set<NodeKey, typename Trait::nodekey_less>;
KeySet coarse_part = part;
network<KeySet, int, NodeKey, EdgeKey> fakelat;
fakelat = lat.template fmap<decltype(fakelat)>([](const NodeVal & tp) { return KeySet(); }, [](const int & m) { return m; });
for (auto & n : fakelat)
n.second.val.insert(n.first);
//粗粒化
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************1\n"; std::cout<<"finish **************1\n";
if (contract_trace_mode)
std::cout << "coarse_grain \n";
combine_edges(fakelat, part);
while (coarse_part.size() > coarse_grain_to) {
std::set<std::tuple<int, NodeKey, NodeKey>, std::greater<std::tuple<int, NodeKey, NodeKey>>> ordered_bond; // weight, from, to
// we may replace std::greater
// add bonds to ordered_bond
for (auto & p : coarse_part) {
auto & this_site = fakelat[p];
for (auto & e : this_site.edges) {
if (coarse_part.count(e.second.nbkey) > 0) {
ordered_bond.insert({e.second.val, p, e.second.nbkey});
}
}
}
// do coarse grain
KeySet treated_sites;
for (auto & b : ordered_bond) {
if (treated_sites.count(std::get<1>(b)) == 0 && treated_sites.count(std::get<2>(b)) == 0) {
treated_sites.insert(std::get<1>(b));
treated_sites.insert(std::get<2>(b));
fakelat.absorb(std::get<1>(b), std::get<2>(b), no_absorb(), kset_contract());
coarse_part.erase(std::get<2>(b));
combine_edges(fakelat, {std::get<1>(b)});
}
}
}
// lat.draw(true);
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************2\n"; std::cout<<"finish **************2\n";
if (contract_trace_mode)
std::cout << "initial \n";
//初始分割,分成cut_part份
KeySet treated_sites = {};
KeySet final_sites = {};
int size_limit;
int treated_size = 0;
bool decide2exit;
for (int i = 0; i < cut_part; ++i) {
size_limit = (part.size() - treated_size) / (cut_part - i);
if (contract_test_mode)
fakelat.draw("lat before initial_cut no. " + std::to_string(i), {part, final_sites}, true);
// find site with max weight (within part-treated)
int max_weight = 0;
int this_weight;
typename network<KeySet, int, NodeKey, EdgeKey>::IterNode max_weight_site_itr;
typename network<KeySet, int, NodeKey, EdgeKey>::IterNode this_site_itr;
for (auto p : coarse_part) {
if (treated_sites.count(p) == 0) {
this_site_itr = fakelat.find(p);
this_weight = calc_weight(this_site_itr, coarse_part, treated_sites);
if (contract_test_mode)
std::cout << "test.divide.initial_cut.max_weight " << p << ' ' << this_weight << '\n';
if (this_weight > max_weight) {
max_weight = this_weight;
max_weight_site_itr = this_site_itr;
}
}
}
if (contract_test_mode)
std::cout << "test.divide.initial_cut.max_weight final " << max_weight_site_itr->first << ' ' << max_weight << '\n';
// construct subpart
treated_sites.insert(max_weight_site_itr->first);
final_sites.insert(max_weight_site_itr->first);
while (true) {
// std::cout<<"a \n";
// fakelat.consistency();
combine_edges(fakelat, {max_weight_site_itr->first});
// std::cout<<"b \n";
// fakelat.consistency();
std::set<std::pair<int, NodeKey>, std::greater<std::pair<int, NodeKey>>> ordered_nb;
for (auto & e : max_weight_site_itr->second.edges) {
if (coarse_part.count(e.second.nbkey) > 0 && treated_sites.count(e.second.nbkey) == 0) {
ordered_nb.insert({e.second.val, e.second.nbkey});
}
}
if (ordered_nb.size() == 0) // exit1: no neighbors
break;
decide2exit = false;
for (auto & nb : ordered_nb) {
if (max_weight_site_itr->second.val.size() + fakelat[nb.second].val.size() <= size_limit) {
fakelat.absorb(max_weight_site_itr->first, nb.second, no_absorb(), kset_contract());
treated_sites.insert(nb.second);
} else {
decide2exit = true; // exit2: reach limit
break;
}
}
if (decide2exit)
break;
}
treated_size += max_weight_site_itr->second.val.size();
// std::cout<<"c \n";
// fakelat.consistency();
combine_edges(fakelat, {max_weight_site_itr->first});
// std::cout<<"d \n";
// fakelat.consistency();
}
// if(contract_test_mode)
// fakelat.draw("lat after initial_cut",{coarse_part,final_sites},true);
if (contract_trace_mode)
std::cout << "after_initial \n";
//可能出现这种情况:由于连通性的原因,所有subpart搞完后,还剩下一些sites
// int n=0;
while (treated_sites.size() < coarse_part.size()) {
// for(auto & e:fakelat.nodes.at("ten3_4").edges)
// std::cout<<"hedddddre2 "<<e.second.nbkey<<'\n';
// std::cout<<"diff"<<treated_sites.size()<<' '<<coarse_part.size()<<'\n';
for (auto & s : final_sites) {
// std::cout<<"here "<<s<<'\n';
// auto & test_node=fakelat.nodes.at(s);
// std::cout<<"there \n";
// std::cout<<test_node.edges.size()<<"\n";
// std::cout<<"there \n";
for (auto & e : fakelat.at(s).edges) {
// std::cout<<"here2
// "<<e.second.nbkey<<coarse_part.count(e.second.nbkey)<<final_sites.count(e.second.nbkey)<<'\n';
if (coarse_part.count(e.second.nbkey) > 0 && treated_sites.count(e.second.nbkey) == 0) {
treated_sites.insert(e.second.nbkey);
fakelat.absorb(s, e.second.nbkey, no_absorb(), kset_contract());
// std::cout<<treated_sites.size()<<"success\n";
// std::cout<<treated_sites.size()<<"success\n";
break;
}
}
}
// n++;
// if(n==10) std::exit(EXIT_FAILURE);
}
if (contract_test_mode)
fakelat.draw("lat after adjustment", {part, final_sites}, true);
if (contract_trace_mode)
std::cout << "release \n";
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************3\n"; std::cout<<"finish **************3\n";
//释放
std::vector<KeySet> subparts;
for (auto & s : final_sites)
subparts.push_back(fakelat[s].val);
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************4\n"; std::cout<<"finish **************4\n";
if (contract_trace_mode)
std::cout << "out_divide \n";
// adjust(lat,part,subparts,0.1);
// refine(lat,part,subparts);
adjust(lat, part, subparts, 0.3);
// refine(lat,part,subparts);
// adjust(lat,part,subparts,1);
// refine(lat,part,subparts);
if (contract_test_mode)
lat.draw("lat after adjust", subparts, true);
auto subparts2 = disconnect(lat, part, subparts);
if (contract_test_mode)
lat.draw("lat after disconnect", subparts2, true);
return subparts2;
}
template <typename NodeType, typename SetType>
bool calc_gain(int cut_part, net::rational & max_gain, NodeType & n, int & this_part, int & nb_part, const SetType & part) {
std::vector<int> weights(cut_part, 1);
for (auto & eg : n.edges)
if (part.count(eg.second.nbkey) > 0)
weights[std::get<1>(eg.second.nbitr->second.val)] *= eg.second.val;
this_part = std::get<1>(n.val);
nb_part = this_part;
int this_weight = weights[this_part];
int max_weight = this_weight;
for (int i = 0; i < cut_part; i++) {
if (weights[i] > max_weight) {
max_weight = weights[i];
nb_part = i;
}
}
max_gain = net::rational(max_weight, this_weight);
return (max_weight > this_weight);
}
template <typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
double Engine::refine(
network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat,
const std::set<NodeKey, typename Trait::nodekey_less> & part,
std::vector<std::set<NodeKey, typename Trait::nodekey_less>> & subparts) {
double uneven=0.;
// if(contract_trace_mode) std::cout<<"in_refine \n";
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************5\n"; std::cout<<"finish **************5\n";
// if(contract_test_mode) lat.draw("lat before refinement",subparts,true);
// set part label and calculate part size
std::vector<int> part_size(subparts.size(), 0);
for (int i = 0; i < subparts.size(); ++i) {
for (auto & s : subparts[i])
std::get<1>(lat[s].val) = i;
part_size[i] = subparts[i].size();
}
// calc gain sorted set
std::map<std::pair<net::rational, NodeKey>, std::pair<int, int>, std::greater<std::pair<net::rational, NodeKey>>> gain_rec; // gain, from, to
net::rational gain;
double tot_gain = 1.;
int this_part, nb_part;
for (auto & p : part) {
auto & n = lat[p];
if (calc_gain(cut_part, gain, n, this_part, nb_part, part)) {
gain_rec[{gain, p}] = {this_part, nb_part};
std::get<2>(n.val) = gain;
}
}
if (contract_test_mode) {
std::cout << "test.refine.build_gain start\n";
for (auto & i : gain_rec)
std::cout << "test.refine.build_gain " << i.first.second << ' ' << i.first.first << ' ' << i.second.first << " -> " << i.second.second
<< "\n";
std::cout << "test.refine.build_gain finish\n";
}
// for(int i=0;i<subparts.size();++i){
// std::cout<<"test.refine.run_gain subpart"<<i<<'\n';
// for(auto & s:subparts[i]) std::cout<<" "<<s<<'\n';
// for(auto & s:subparts[i]) std::cout<<" \n";
// }
int min_size = part.size() / cut_part * (1 - uneven);
int max_size = part.size() / cut_part * (1 + uneven);
for (int i = 0; i < refine_sweep; i++) {
if (gain_rec.size() == 0)
break;
for (auto g_rec = gain_rec.begin(); g_rec != gain_rec.end(); ++g_rec) {
if (part_size[g_rec->second.first] > min_size && part_size[g_rec->second.second] < max_size) {
// std::cout<<"test.refine.run_gain "<<g_rec->first.second<<' '<<
// g_rec->first.first<<' '<<g_rec->second.first<<" ->
// "<<g_rec->second.second<<"\n";
part_size[g_rec->second.first]--;
part_size[g_rec->second.second]++;
auto this_name = g_rec->first.second;
auto & this_node = lat[this_name];
std::get<1>(this_node.val) = g_rec->second.second;
subparts[g_rec->second.first].erase(g_rec->first.second);
subparts[g_rec->second.second].insert(g_rec->first.second);
tot_gain *= g_rec->first.first.to_double();
gain_rec.erase(g_rec);
if (calc_gain(cut_part, gain, this_node, this_part, nb_part, part)) {
gain_rec[{gain, this_name}] = {this_part, nb_part};
std::get<2>(this_node.val) = gain;
}
for (auto & eg : this_node.edges) {
if (part.count(eg.second.nbkey) > 0) {
auto lookup_nb = gain_rec.find({std::get<2>(eg.second.nbitr->second.val), eg.second.nbkey});
if (lookup_nb != gain_rec.end()) {
gain_rec.erase(lookup_nb);
}
if (calc_gain(cut_part, gain, eg.second.nbitr->second, this_part, nb_part, part)) {
gain_rec[{gain, eg.second.nbkey}] = {this_part, nb_part};
std::get<2>(eg.second.nbitr->second.val) = gain;
}
}
}
break;
}
}
}
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************6\n"; std::cout<<"finish **************6\n";
// if(contract_test_mode) lat.draw("lat after refinement",subparts,true);
// if(contract_trace_mode) std::cout<<"out_refine \n";
return tot_gain;
}
// calc weight of a node by part no.
template <typename NodeType, typename SetType>
void calc_weight(int cut_part, const NodeType & n, std::vector<int> & weight, const SetType & part) {
weight = std::vector<int>(cut_part, 1);
for (auto & eg : n.edges)
if (part.count(eg.second.nbkey) > 0)
weight[std::get<1>(eg.second.nbitr->second.val)] *= eg.second.val;
}
template <typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
void Engine::adjust(
network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat,
const std::set<NodeKey, typename Trait::nodekey_less> & part,
std::vector<std::set<NodeKey, typename Trait::nodekey_less>> & subparts,
double alpha) {
double uneven=0.;
if (contract_trace_mode)
std::cout << "in_adjust \n";
std::vector<std::set<NodeKey, typename Trait::nodekey_less>> best_subparts = subparts, temp_subparts;
std::vector<NodeKey> part_vec = {part.begin(), part.end()};
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************5\n"; std::cout<<"finish **************5\n";
if (contract_test_mode)
lat.draw("lat before adjustment", subparts, true);
// set part label and calculate part size
std::vector<int> part_size(subparts.size(), 0);
for (int i = 0; i < subparts.size(); ++i) {
for (auto & s : subparts[i])
std::get<1>(lat[s].val) = i;
part_size[i] = subparts[i].size();
}
// calc gain sorted set
int min_size = part.size() / cut_part * (1 - 0.5 * uneven);
int max_size = part.size() / cut_part * (1 + 0.5 * uneven);
std::uniform_int_distribution<> site_dist(0, part.size() - 1);
std::uniform_int_distribution<> part_dist(0, cut_part - 2);
std::uniform_real_distribution<> monte_carlo;
std::map<NodeKey, std::vector<int>, typename Trait::nodekey_less> weights;
for (auto & p : part)
calc_weight(cut_part, lat[p], weights[p], part);
double cumu_gain = 1, max_gain = 1, further_gain;
int i = 0;
int j = 0;
while (i < refine_sweep && j < 100 * refine_sweep) {
j++;
auto this_site = part_vec[site_dist(rand)];
auto & this_node = lat[this_site];
auto & this_weight = weights[this_site];
int this_part = std::get<1>(this_node.val);
int next_part = part_dist(rand);
if (next_part >= this_part)
next_part++;
// std::cout<<std::pow(double(this_weight[next_part])/double(this_weight[this_part]),alpha)<<'\n';
if (part_size[this_part] > min_size && part_size[next_part] < max_size &&
monte_carlo(rand) < std::pow(double(this_weight[next_part]) / double(this_weight[this_part]), alpha)) {
// std::cout<<"success";
cumu_gain *= double(this_weight[next_part]) / this_weight[this_part];
part_size[this_part]--;
part_size[next_part]++;
std::get<1>(this_node.val) = next_part;
subparts[this_part].erase(this_site);
subparts[next_part].insert(this_site);
calc_weight(cut_part, this_node, this_weight, part);
for (auto & eg : this_node.edges)
calc_weight(cut_part, lat[eg.second.nbkey], weights[eg.second.nbkey], part);
temp_subparts = subparts;
further_gain = refine(lat, part, temp_subparts);
if (cumu_gain * further_gain > max_gain) {
max_gain = cumu_gain * further_gain;
best_subparts = temp_subparts;
}
++i;
}
}
// for(auto & i: std::get<0>(lat["ten0_1"])->val ) std::cout<<i<<"
// **************6\n"; std::cout<<"finish **************6\n";
if (contract_test_mode)
lat.draw("lat after adjustment", best_subparts, true);
if (contract_trace_mode)
std::cout << "out_adjust \n";
subparts = best_subparts;
}
} // namespace net
#endif
Updated on 15 June 2022 at 16:04:19 CST