# include/net/tensor_tools.hpp

# Namespaces

| Name |
| --- |
| net |
| net::tensor |

# Source code

#ifndef NET_TENSOR_TOOLS_HPP
#define NET_TENSOR_TOOLS_HPP

// #include "network.hpp"
#include "traits.hpp"
#include <TAT/TAT.hpp>
#include <cstddef>
#include <functional>
#include <iostream>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

namespace net {
   namespace tensor {
      /// Shorthand for a TAT tensor without symmetry.
      /// @tparam Scalar  scalar (element) type of the tensor
      /// @tparam Key     type used for edge names; defaults to stdEdgeKey
      template <typename Scalar, typename Key = stdEdgeKey>
      using Tensor = TAT::Tensor<Scalar, TAT::NoSymmetry, Key>;

      /// Returns the dimension of the edge at position `s` of a symmetry-free tensor.
      ///
      /// @param ten  tensor to inspect (borrowed; no copy of the tensor is made)
      /// @param s    zero-based edge position; no bounds check is performed here
      /// @return     dimension of edge `s`, read via get_dimension_from_symmetry(TAT::NoSymmetry())
      template <typename TnType>
      int get_dim(const TnType & ten, int s) {
         return ten.core->edges[s].get_dimension_from_symmetry(TAT::NoSymmetry());
      }
      /// Returns the number of scalar elements held in `ten.storage()`.
      ///
      /// @param ten  tensor to inspect (borrowed; no copy of the tensor is made)
      /// @return     storage size, narrowed to int
      template <typename TnType>
      int get_size(const TnType & ten) {
         // storage().size() is unsigned; narrow explicitly to keep the int return type.
         return static_cast<int>(ten.storage().size());
      }
      /// Returns the number of edges of `ten`, counted as `ten.names.size()`.
      ///
      /// @param ten  tensor to inspect (borrowed; no copy of the tensor is made)
      /// @return     edge count, narrowed to int
      template <typename TnType>
      int get_rank(const TnType & ten) {
         return static_cast<int>(ten.names.size());
      }

      template <typename T>
      std::set<std::string> inds_start_with(const TAT::Tensor<T> & ten, const std::string & head) {
         std::set<std::string> res;
         // std::cout<<"test inds_start_with\n";
         // std::cout<<head<<'\n';
         for (auto & s : ten.names)
            if (s.compare(0, head.size(), head) == 0) {
               // std::cout<<s<<'\n';
               res.insert(s);
            }
         return res;
      }

      /// Returns the dimension of the first edge of `ten` whose name compares equal to `s`.
      ///
      /// @param ten  tensor to inspect (borrowed; no copy of the tensor is made)
      /// @param s    edge name to look up
      /// @return     that edge's dimension, or 0 if no edge matches (note: indistinguishable
      ///             from a genuine zero-dimension edge)
      template <typename TnType, typename EdgeKey>
      int get_dim(const TnType & ten, const EdgeKey & s) {
         const auto rank = ten.get_rank();
         for (std::size_t i = 0; i < rank; ++i) {
            if (s == ten.names[i]) {
               return ten.core->edges[i].get_dimension_from_symmetry(TAT::NoSymmetry());
            }
         }
         return 0;
      }
      /// Writes "<name> <dim> " for every edge of `ten` to `os`, then a newline.
      ///
      /// @param ten  tensor to describe (borrowed; no copy of the tensor is made)
      /// @param os   destination stream
      template <typename TnType>
      void diminfo(const TnType & ten, std::ostream & os) {
         const auto rank = ten.get_rank();
         for (std::size_t i = 0; i < rank; ++i) {
            os << ten.names[i] << ' ' << ten.core->edges[i].get_dimension_from_symmetry(TAT::NoSymmetry()) << ' ';
         }
         os << '\n';
      }

      /// Extracts the (generalized) diagonal of `ten` over pairs of edges.
      ///
      /// `ten` is first transposed to the edge order [p.first..., p.second...]. The result has
      /// one edge per pair, named `p.first` with that edge's dimension. Its element at linear
      /// index i is the transposed element at i * (newsize + 1), i.e. the entry whose "first"
      /// and "second" multi-indices coincide.
      ///
      /// @param ten    source tensor
      /// @param names  pairs (kept edge, partner edge). Presumably these must name every edge of
      ///               `ten` exactly once, since they are passed to transpose() as the full edge
      ///               order — TODO confirm against TAT::Tensor::transpose
      /// @return       tensor holding the diagonal elements
      /// @throws std::invalid_argument if an edge and its partner have different dimensions
      ///         (the linear-index formula is only valid when they match)
      template <typename T, typename EdgeKey = stdEdgeKey>
      Tensor<T, EdgeKey> get_diag(const Tensor<T, EdgeKey> & ten, const std::vector<std::pair<EdgeKey, EdgeKey>> & names) {
         const int npairs = static_cast<int>(names.size());

         std::vector<EdgeKey> oldname;
         oldname.reserve(2 * names.size());
         for (const auto & n : names)
            oldname.push_back(n.first);
         for (const auto & n : names)
            oldname.push_back(n.second);
         auto ten2 = ten.transpose(oldname);

         std::vector<EdgeKey> newname(oldname.begin(), oldname.begin() + npairs);
         std::vector<TAT::Edge<TAT::NoSymmetry>> newdims;
         newdims.reserve(names.size());
         // size_t: i * (newsize + 1) below overflows int once newsize exceeds ~46340.
         std::size_t newsize = 1;
         for (int i = 0; i < npairs; ++i) {
            // An int index selects the positional get_dim overload, not the name-lookup one.
            const int dim = get_dim(ten2, i);
            if (dim != get_dim(ten2, i + npairs)) {
               throw std::invalid_argument("get_diag: paired edges must have equal dimensions");
            }
            newdims.push_back(dim);
            newsize *= static_cast<std::size_t>(dim);
         }

         Tensor<T, EdgeKey> result(newname, newdims);
         auto && src = ten2.storage();
         auto && dst = result.storage();
         for (std::size_t i = 0; i < newsize; ++i) {
            dst[i] = src[i * (newsize + 1)];
         }
         return result;
      }

   } // namespace tensor
} // namespace net

#endif

Updated on 15 June 2022 at 16:04:19 CST