# include/net/tensor_contract/tensor_contract_tools.hpp
# Namespaces
Name |
---|
net |
# Classes
Kind | Name |
---|---|
struct | net::lift_contract |
struct | net::lift_absorb |
struct | net::keyset |
struct | net::contract_info |
struct | net::contract_info2 |
# Source code
#ifndef NET_TENSOR_CONTRACT_TOOLS_HPP
#define NET_TENSOR_CONTRACT_TOOLS_HPP
#include "../network.hpp"
#include "../rational.hpp"
#include "../tensor_tools.hpp"
#include "../traits.hpp"
#include "../group.hpp"
#include <TAT/TAT.hpp>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <variant>
#include <vector>
namespace net {
// Adapts a binary contraction on bare payloads into one on tuple-shaped node values.
// Element 0 of each node value is the payload handed to `contract_fun`; elements 1 and 2
// of the FIRST operand are carried over unchanged (the second operand's are dropped).
template <typename contract_type>
struct lift_contract {
   contract_type contract_fun; // the wrapped payload-level contraction

   lift_contract(contract_type cf) : contract_fun(cf) {}

   // Contract ten1 and ten2 over `inds`; returns (contract_fun(ten1[0], ten2[0], inds), ten1[1], ten1[2]).
   template <typename NodeVal, typename NoUse>
   NodeVal operator()(const NodeVal & ten1, const NodeVal & ten2, const NoUse & inds) const {
      const auto & keep1 = std::get<1>(ten1);
      const auto & keep2 = std::get<2>(ten1);
      return std::make_tuple(contract_fun(std::get<0>(ten1), std::get<0>(ten2), inds), keep1, keep2);
   }
};
// Adapts an absorb function on bare payloads into one on tuple-shaped node values.
// Element 0 of the node value is passed to `absorb_fun` together with the edge value and
// index; elements 1 and 2 are carried over unchanged.
template <typename absorb_type>
struct lift_absorb {
   absorb_type absorb_fun; // the wrapped payload-level absorb

   lift_absorb(absorb_type af) : absorb_fun(af) {}

   // Returns (absorb_fun(ten1[0], eg, ind), ten1[1], ten1[2]).
   template <typename NodeVal, typename EdgeVal, typename NoUse>
   NodeVal operator()(const NodeVal & ten1, const EdgeVal & eg, const NoUse & ind) const {
      const auto & keep1 = std::get<1>(ten1);
      const auto & keep2 = std::get<2>(ten1);
      return std::make_tuple(absorb_fun(std::get<0>(ten1), eg, ind), keep1, keep2);
   }
};
// Depth-first flood fill starting from node `n` (whose key is `p`).
// Adds `p` to both `treated` and `component`, then recurses into every neighbour that
//   - belongs to `part`,
//   - has not been treated yet, and
//   - carries the same label in element 1 of its value (`std::get<1>(val)`) as `n`.
// On return `component` holds the keys reachable from `p` under those rules.
template <typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
void get_component(
      node<NodeVal, int, NodeKey, EdgeKey, Trait> & n,
      const NodeKey & p,
      const std::set<NodeKey, typename Trait::nodekey_less> & part,
      std::set<NodeKey, typename Trait::nodekey_less> & treated,
      std::set<NodeKey, typename Trait::nodekey_less> & component) {
   treated.insert(p);
   component.insert(p);
   for (auto & edge : n.edges) {
      const auto & nb_key = edge.second.nbkey;
      if (part.count(nb_key) == 0)
         continue; // neighbour lies outside the region being split
      if (treated.count(nb_key) != 0)
         continue; // already visited
      auto & nb = edge.second.nbitr->second;
      if (!(std::get<1>(n.val) == std::get<1>(nb.val)))
         continue; // neighbour belongs to a different subpart label
      get_component(nb, nb_key, part, treated, component);
   }
}
// Split `part` into connected components, never joining nodes from different subparts.
//
// Every node listed in subparts[i] gets label i written into element 1 of its value
// (`std::get<1>(lat[key].val)`). `part` is then partitioned with `get_component`: two nodes
// share a component when they are linked by a chain of edges inside `part` whose endpoints
// carry equal labels. Returns the components in the order their first key appears in `part`.
//
// NOTE(review): nodes in `part` that appear in no subpart keep whatever label they had
// before this call, so stale labels can merge them into a component — confirm callers always
// pass subparts that cover `part`.
template <typename NodeVal, typename NodeKey, typename EdgeKey, typename Trait>
std::vector<std::set<NodeKey, typename Trait::nodekey_less>> disconnect(
      network<NodeVal, int, NodeKey, EdgeKey, Trait> & lat,
      const std::set<NodeKey, typename Trait::nodekey_less> & part,
      const std::vector<std::set<NodeKey, typename Trait::nodekey_less>> & subparts) {
   using key_set = std::set<NodeKey, typename Trait::nodekey_less>;

   // Label each node with the index of the subpart it belongs to (int counter: no sign-compare).
   int label = 0;
   for (const auto & sub : subparts) {
      for (const auto & s : sub)
         std::get<1>(lat[s].val) = label;
      ++label;
   }

   // A single pass suffices: every key of `part` is either already treated or becomes the seed
   // of a new component, and get_component only ever marks keys that are in `part`.
   key_set treated;
   std::vector<key_set> newsubparts;
   for (const auto & p : part)
      if (treated.count(p) == 0) {
         newsubparts.emplace_back();
         get_component(lat[p], p, part, treated, newsubparts.back());
      }
   return newsubparts;
}
// Minimal contraction record that only tracks which node keys have been merged together.
// It provides the same static absorb/contract interface as contract_info, but carries no cost data.
template <typename KeySetType>
struct keyset {
   KeySetType node_set; // keys of all nodes merged into this record

   keyset() = default;

   // Seed the record with a single key; the node value `n` is not inspected.
   template <typename NodeType>
   keyset(const typename KeySetType::key_type & k, const NodeType & n) {
      node_set.insert(k);
   }

   // Absorbing an edge does not change a keyset; returns a copy of `a` (`b` is ignored).
   static keyset<KeySetType> absorb(const keyset<KeySetType> & a, const int & b) {
      return a;
   }

   // Union of the two key sets.
   static keyset<KeySetType> contract(const keyset<KeySetType> & a, const keyset<KeySetType> & b) {
      keyset<KeySetType> merged;
      merged.node_set = a.node_set;
      for (const auto & key : b.node_set)
         merged.node_set.insert(key);
      return merged;
   }

   // Nothing to report for a plain keyset.
   std::string show() const {
      return std::string();
   }

   // Fresh record containing only `k`.
   keyset<KeySetType> forget_history(const typename KeySetType::key_type & k) {
      keyset<KeySetType> fresh;
      fresh.node_set.insert(k);
      return fresh;
   }
};
// Contraction-cost bookkeeping in plain integer arithmetic (see contract_info2 for the
// log-space variant). Each record describes the tensor produced by merging `node_set`.
template <typename KeySetType>
struct contract_info {
   KeySetType node_set;                // keys of every node merged into this record
   long long int this_weight = 1;      // record the size of the tensor
   long long int hist_max_weight = 1;  // maximal space cost over the whole contraction history
   long long int contraction_cost = 1; // total time cost
   long long int legs = 1;             // record the total dimension of common legs

   // Absorb an edge value: returns a copy of `c` with `legs` multiplied by `d`.
   static contract_info<KeySetType> absorb(contract_info<KeySetType> & c, const int & d) {
      contract_info<KeySetType> r = c;
      r.legs *= d;
      return r;
   }

   // Merge the records `c` and `d`; c.legs * d.legs is used as the contracted dimension.
   //   this_weight      = c.w / c.legs / d.legs * d.w / c.legs / d.legs   (integer division, left to right)
   //   contraction_cost = c.cost + d.cost + c.w / c.legs / d.legs * d.w
   //   hist_max_weight  = max of both operands' recorded peaks, their weights, and the new weight
   // Side effect: resets c.legs and d.legs to 1 (presumably so absorbed legs are not counted
   // again in a later contraction — TODO confirm against the caller).
   // NOTE: the intermediate products can overflow long long for large networks.
   static contract_info<KeySetType> contract(contract_info<KeySetType> & c, contract_info<KeySetType> & d) {
      contract_info<KeySetType> r;
      r.legs = 1;
      r.node_set = c.node_set;
      r.node_set.insert(d.node_set.begin(), d.node_set.end());
      r.this_weight = c.this_weight / c.legs / d.legs * d.this_weight / c.legs / d.legs;
      r.contraction_cost = c.contraction_cost + d.contraction_cost + c.this_weight / c.legs / d.legs * d.this_weight;
      // Carry both operands' peaks forward; otherwise the maximum reached in earlier
      // contractions would be lost at every step.
      r.hist_max_weight =
            std::max({c.hist_max_weight, d.hist_max_weight, c.this_weight, d.this_weight, r.this_weight});
      c.legs = 1;
      d.legs = 1;
      return r;
   }

   // Human-readable summary: cost (C), weight (W) and historical peak weight (H).
   std::string show() const {
      std::ostringstream os;
      os.precision(4);
      os << std::fixed << "C " << contraction_cost << "\nW " << this_weight << "\nH " << hist_max_weight;
      return os.str();
   }

   contract_info() = default;

   // Leaf record for node `k`; its weight is net::tensor::get_size(n)
   // (presumably the element count of the tensor — TODO confirm).
   template <typename NodeType>
   contract_info(const typename KeySetType::key_type & k, const NodeType & n) {
      node_set.insert(k);
      this_weight = net::tensor::get_size(n);
      hist_max_weight = this_weight;
   }

   // Copy of this record whose node set is replaced by the single key `k`;
   // all cost/weight fields (including the historical peak) are kept.
   contract_info<KeySetType> forget_history(const typename KeySetType::key_type & k) {
      contract_info<KeySetType> ci;
      ci.node_set.insert(k);
      ci.this_weight = this_weight;
      ci.hist_max_weight = hist_max_weight;
      ci.contraction_cost = contraction_cost;
      ci.legs = legs;
      return ci;
   }
};
// Base-10 log-sum-exp: returns log10(10^a + 10^b).
//
// Computed as max(a,b) + log1p(10^(min-max)) / ln(10). The exponent is <= 0, so 10^(...)
// never overflows, and log1p keeps full precision when the two terms differ by many
// orders of magnitude.
//
// The previous version had three problems. Its large-gap branches returned
// max + x - x^2/2 + x^3/3, which is the series for ln(1+x) and is missing the /ln(10)
// factor. Its comment claimed the natural log. It was declared `constexpr` even though
// every path calls non-constexpr <cmath> functions (ill-formed NDR in C++17). `inline`
// keeps the single-definition guarantee for this header.
inline double exp_sum_log(const double & a, const double & b) {
   const double hi = std::max(a, b);
   const double lo = std::min(a, b);
   return hi + std::log1p(std::pow(10., lo - hi)) / std::log(10.);
}
// A logged version of contract_info: every weight/cost field stores log10 of the quantity
// tracked by contract_info, so huge networks do not overflow.
template <typename KeySetType>
struct contract_info2 {
   KeySetType node_set;           // keys of every node merged into this record
   double this_weight = 0.;       // log10 of the tensor size
   double hist_max_weight = 0.;   // log10 of the maximal space cost over the whole history
   double contraction_cost = 0.;  // log10 of the total time cost
   double legs = 0.;              // log10 of the total dimension of common legs

   // Absorb an edge value: returns a copy of `c` with log10(d) added to `legs`.
   static contract_info2<KeySetType> absorb(contract_info2<KeySetType> & c, const int & d) {
      contract_info2<KeySetType> r = c;
      r.legs += std::log10(double(d));
      return r;
   }

   // Merge the records `c` and `d` (all quantities in log10; L = c.legs + d.legs):
   //   this_weight      = c.w + d.w - 2 * L
   //   contraction_cost = log10(10^c.cost + 10^d.cost + 10^(c.w + d.w - L))
   //   hist_max_weight  = max of both operands' recorded peaks, their weights, and the new weight
   // Side effect: resets c.legs and d.legs to 0 (presumably so absorbed legs are not counted
   // again in a later contraction — TODO confirm against the caller).
   static contract_info2<KeySetType> contract(contract_info2<KeySetType> & c, contract_info2<KeySetType> & d) {
      contract_info2<KeySetType> r;
      r.legs = 0.;
      r.node_set = c.node_set;
      r.node_set.insert(d.node_set.begin(), d.node_set.end());
      r.this_weight = c.this_weight + d.this_weight - 2 * c.legs - 2 * d.legs;
      r.contraction_cost = exp_sum_log(
            exp_sum_log(c.contraction_cost, d.contraction_cost), c.this_weight - c.legs - d.legs + d.this_weight);
      // Carry both operands' peaks forward; otherwise the maximum reached in earlier
      // contractions would be lost at every step.
      r.hist_max_weight =
            std::max({c.hist_max_weight, d.hist_max_weight, c.this_weight, d.this_weight, r.this_weight});
      c.legs = 0.;
      d.legs = 0.;
      return r;
   }

   // Human-readable summary (log10 values): cost (C), weight (W), historical peak weight (H).
   std::string show() const {
      std::ostringstream os;
      os.precision(4);
      os << std::fixed << "C " << contraction_cost << "\nW " << this_weight << "\nH " << hist_max_weight;
      return os.str();
   }

   contract_info2() = default;

   // Leaf record for node `k`; its weight is log10(net::tensor::get_size(n))
   // (get_size presumably returns the element count — TODO confirm).
   template <typename NodeType>
   contract_info2(const typename KeySetType::key_type & k, const NodeType & n) {
      node_set.insert(k);
      this_weight = std::log10(double(net::tensor::get_size(n)));
      hist_max_weight = this_weight;
   }

   // Copy of this record whose node set is replaced by the single key `k`;
   // all cost/weight fields (including the historical peak) are kept.
   contract_info2<KeySetType> forget_history(const typename KeySetType::key_type & k) {
      contract_info2<KeySetType> ci;
      ci.node_set.insert(k);
      ci.this_weight = this_weight;
      ci.hist_max_weight = hist_max_weight;
      ci.contraction_cost = contraction_cost;
      ci.legs = legs;
      return ci;
   }
};
} // namespace net
#endif
Updated on 15 June 2022 at 16:04:19 CST