from itertools import chain
from typing import List, Sequence, Dict, Any
import numpy as np
from print_tree import print_tree
from renormalizer import Op
from renormalizer.model.basis import BasisSet, BasisDummy
from renormalizer.tn.node import NodeUnion, TreeNodeBasis, copy_connection
class Tree:
    """
    Generic rooted tree made of linked nodes.

    The nodes are enumerated once at construction time in pre-order and cached
    in ``node_list``. ``node_idx`` maps every node back to its position in that
    list, so the nodes must be hashable.

    Parameters
    ----------
    root: :class:`NodeUnion`
        The root of the tree. It must not have a parent.
    """

    def __init__(self, root: NodeUnion):
        assert root.parent is None
        self.root = root
        # pre-order enumeration of all nodes, computed once
        self.node_list = self.preorder_list()
        # reverse lookup: node -> index in ``node_list``
        self.node_idx = {node: i for i, node in enumerate(self.node_list)}

    def preorder_list(self, func=None) -> List[NodeUnion]:
        """
        Traverse the tree in pre-order (parent first, then children from first to last).

        Parameters
        ----------
        func: callable, optional
            If provided, it is called once on every node in traversal order
            and its return values are collected instead of the nodes.

        Returns
        -------
        The list of nodes, or the list of ``func(node)`` if ``func`` is provided.
        """
        ret = []
        # An explicit stack is used instead of recursion so that deep trees
        # (such as long linear chains) do not hit the interpreter recursion limit.
        stack = [self.root]
        while stack:
            node = stack.pop()
            ret.append(node if func is None else func(node))
            if node.children:
                # push in reverse so that the first child is popped first
                stack.extend(reversed(node.children))
        return ret

    def postorder_list(self) -> List[NodeUnion]:
        """
        Traverse the tree in post-order (children from first to last, then the parent).

        Returns
        -------
        The list of nodes in post-order. The root is the last element.
        """
        # Visit "parent, last child, ..., first child" with an explicit stack.
        # Reversing that sequence yields exactly the post-order.
        ret = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            ret.append(node)
            if node.children:
                stack.extend(node.children)
        ret.reverse()
        return ret

    @staticmethod
    def find_path(node1: NodeUnion, node2: NodeUnion) -> List[NodeUnion]:
        """
        Find the path from node1 to node2. Not most efficient but simple to implement.

        The path climbs from ``node1`` to the first entry of ``node1.ancestors``
        that also appears in ``node2.ancestors`` and then descends to ``node2``.
        NOTE(review): this presumes ``ancestors`` lists the node itself first
        and the root last -- confirm against the node implementation.

        Parameters
        ----------
        node1:
            The start of the path.
        node2:
            The end of the path. Must differ from ``node1``.

        Returns
        -------
        The list of nodes along the path, both ends included.
        """
        assert node1 != node2
        ancestors1 = node1.ancestors
        ancestors2 = node2.ancestors
        ancestors2_set = set(ancestors2)
        common_ancestors = [ancestor for ancestor in ancestors1 if ancestor in ancestors2_set]
        # the first shared entry is the turning point of the path
        common_ancestor = common_ancestors[0]
        # upward part, turning point included
        path1 = ancestors1[: ancestors1.index(common_ancestor) + 1]
        # downward part, turning point excluded; reversed below to go top-down
        path2 = ancestors2[: ancestors2.index(common_ancestor)]
        return path1 + path2[::-1]

    @property
    def size(self):
        """Number of nodes in the tree."""
        return len(self.node_list)

    def __len__(self):
        return self.size

    def __iter__(self):
        # iterate over the nodes in pre-order
        return iter(self.node_list)

    def __repr__(self):
        return f"{self.__class__} with {len(self)} nodes"
class BasisTree(Tree):
    """
    Tree of basis sets. The tree nodes are :class:`TreeNodeBasis`.

    Parameters
    ----------
    root: :class:`TreeNodeBasis`
        The root of the tree
    """

    @classmethod
    def linear(cls, basis_list: List[BasisSet]):
        """
        Generate a linear tree, i.e, MPS.

        Parameters
        ----------
        basis_list: list of ``BasisSet``
            The basis set list.

        Returns
        -------
        The constructed basis tree.
        """
        node_list = [TreeNodeBasis([basis]) for basis in basis_list]
        # chain the nodes: each node is the only child of the previous one
        for parent, child in zip(node_list, node_list[1:]):
            parent.add_child(child)
        return cls(node_list[0])

    @classmethod
    def binary(cls, basis_list: List[BasisSet]):
        """
        Generate a binary tree.

        Parameters
        ----------
        basis_list: list of ``BasisSet``
            The basis set list.

        Returns
        -------
        The constructed basis tree.
        """
        node_list = [TreeNodeBasis([basis]) for basis in basis_list]

        def binary_recursion(node: TreeNodeBasis, offspring: List[TreeNodeBasis]):
            # The first two entries of ``offspring`` become the children of ``node``.
            # The remaining entries are split in halves and distributed to the two children.
            if not offspring:
                return
            node.add_child(offspring[0])
            if len(offspring) == 1:
                return
            node.add_child(offspring[1])
            new_offspring = offspring[2:]
            mid_idx = len(new_offspring) // 2
            binary_recursion(offspring[0], new_offspring[:mid_idx])
            binary_recursion(offspring[1], new_offspring[mid_idx:])

        binary_recursion(node_list[0], node_list[1:])
        return cls(node_list[0])

    @classmethod
    def general_mctdh(
        cls,
        basis_list: List[BasisSet],
        tree_order: int,
        contract_primitive: bool = False,
        contract_label: Sequence[bool] = None,
        dummy_label="MCTDH virtual",
    ):
        r"""
        MCTDH tree with the specified tree order.

        The feature of this type of tree is that all physical degrees of freedom are attached to the leaf nodes.
        Also, each leaf node typically has more than one physical degrees of freedom.

        Parameters
        ----------
        basis_list: list of :class:`~renormalizer.model.basis.BasisSet`
            The list of basis sets for the system.
        tree_order: int
            Tree order. For example, 2 means binary tree and 3 means ternary tree.
            Must be at least 2.
        contract_primitive: bool
            Whether contract the primitive basis. Defaults to False.
            If set to True, each primitive basis in ``basis_list`` will be contracted before attached
            to the tree. The following is a schematic view, where ``o`` represents a node
            and ``d`` means the physical bond.

            .. code-block::

                # contract primitive
                       o
                      / \
                     o   o
                d -> |   | <- d

            If set to False, the following type of tree will be constructed.

            .. code-block::

                # not contract primitive
                       o
                d ->  / \  <- d
                d means physical bond

        contract_label: list of bool
            If ``contract_primitive`` is set to True,
            this list determines which primitive basis should be contracted.
        dummy_label:
            The label for the virtual nodes in MCTDH.

        Returns
        -------
        The constructed basis tree.

        Raises
        ------
        ValueError
            If ``tree_order`` is smaller than 2.

        See Also
        --------
        binary_mctdh: construct binary MCTDH tree (tree order is set to 2).
        ternary_mctdh: construct ternary MCTDH tree (tree order is set to 3).
        """
        #       o
        # d -> / \ <- d
        # d means physical bond
        # `contract_label` decides whether we do contraction for a particular basis
        if tree_order < 2:
            # Smaller orders can never terminate: the grouping loop or the
            # recursive construction below would not make progress.
            raise ValueError(f"tree_order must be at least 2, got {tree_order}")
        assert len(basis_list) > 1

        # prepare elementary nodes
        elementary_nodes: List[TreeNodeBasis] = []
        if not contract_primitive:
            assert contract_label is None, "providing label makes sense only when primitives are contracted"
            # every ``tree_order`` consecutive basis sets share one elementary node
            while tree_order < len(basis_list):
                node = TreeNodeBasis(basis_list[:tree_order])
                elementary_nodes.append(node)
                basis_list = basis_list[tree_order:]
            # the remainder (between 1 and ``tree_order`` basis sets) forms the last node
            elementary_nodes.append(TreeNodeBasis(basis_list))
        else:
            if contract_label is None:
                # every basis set gets its own elementary node
                for basis in basis_list:
                    node1 = TreeNodeBasis([basis])
                    elementary_nodes.append(node1)
            else:
                assert len(contract_label) == len(basis_list)
                i = 0
                while i != len(basis_list):
                    if contract_label[i]:
                        # labelled basis set: a node of its own
                        elementary_nodes.append(TreeNodeBasis([basis_list[i]]))
                        i += 1
                    else:
                        # Group consecutive unlabelled basis sets into one node.
                        # ``j`` ends as the group size: the scan stops at the end of
                        # the list, at the next labelled basis set, or at ``tree_order``.
                        for j in range(1, tree_order + 1):
                            if i + j == len(contract_label) or contract_label[i + j]:
                                break
                        elementary_nodes.append(TreeNodeBasis(basis_list[i : i + j]))
                        i += j

        # running index that makes the dof of every virtual node unique
        dummy_i = 0

        # recursive tree construction
        def recursion(elementary_nodes_: List[TreeNodeBasis]) -> TreeNodeBasis:
            nonlocal dummy_i
            node = TreeNodeBasis([BasisDummy((dummy_label, dummy_i))])
            dummy_i += 1
            if len(elementary_nodes_) <= tree_order:
                # few enough nodes: attach them directly to the virtual node
                node.add_child(elementary_nodes_)
                return node
            # NOTE(review): ``approximate_partition`` may return empty trailing groups
            # (e.g. 4 nodes split into 3 groups), which yields a virtual node without
            # children here -- confirm whether this is intended.
            for group in approximate_partition(elementary_nodes_, tree_order):
                node.add_child(recursion(group))
            return node

        root = recursion(elementary_nodes)
        return cls(root)

    @classmethod
    def binary_mctdh(
        cls, basis_list: List[BasisSet], contract_primitive=False, contract_label=None, dummy_label="MCTDH virtual"
    ):
        """
        Construct binary MCTDH tree.

        See Also
        --------
        general_mctdh: construct MCTDH tree with any order.
        """
        return cls.general_mctdh(basis_list, 2, contract_primitive, contract_label, dummy_label)

    @classmethod
    def ternary_mctdh(
        cls, basis_list: List[BasisSet], contract_primitive=False, contract_label=None, dummy_label="MCTDH virtual"
    ):
        """
        Construct ternary MCTDH tree.

        See Also
        --------
        general_mctdh: construct MCTDH tree with any order.
        """
        return cls.general_mctdh(basis_list, 3, contract_primitive, contract_label, dummy_label)

    @classmethod
    def t3ns(cls, basis_list: List[BasisSet], t3ns_label="T3NS virtual"):
        """
        Construct a T3NS-type tree, in which virtual branching nodes alternate with physical nodes.

        The root is a virtual node and ``basis_list`` is split into three groups attached to it.
        For each group, the first basis set forms a physical node. If more than one basis set
        is left in the group, a virtual node is attached below that physical node and
        the remaining basis sets are split into two groups that are attached to
        the virtual node in the same manner. A group of two basis sets becomes
        a chain of two physical nodes without any virtual node.

        Parameters
        ----------
        basis_list: list of ``BasisSet``
            The basis set list.
        t3ns_label:
            The label for the virtual nodes.

        Returns
        -------
        The constructed basis tree.
        """
        # running index that makes the dof of every virtual node unique
        dummy_i = 0

        def recursion(parent, basis_list_: List[BasisSet]):
            nonlocal dummy_i
            if len(basis_list_) == 0:
                # ``approximate_partition`` may produce empty groups
                return
            if len(basis_list_) == 1:
                parent.add_child(TreeNodeBasis(basis_list_))
                return
            if len(basis_list_) == 2:
                # two basis sets: a short chain, no branching required
                node1 = TreeNodeBasis(basis_list_[:1])
                parent.add_child(node1)
                node2 = TreeNodeBasis(basis_list_[1:])
                node1.add_child(node2)
                return
            # physical node followed by a virtual branching node
            node1 = TreeNodeBasis(basis_list_[:1])
            parent.add_child(node1)
            node2 = TreeNodeBasis([BasisDummy((t3ns_label, dummy_i))])
            dummy_i += 1
            node1.add_child(node2)
            for partition_ in approximate_partition(basis_list_[1:], 2):
                recursion(node2, partition_)

        root = TreeNodeBasis([BasisDummy((t3ns_label, dummy_i))])
        dummy_i += 1
        for partition in approximate_partition(basis_list, 3):
            recursion(root, partition)
        return cls(root)

    def __init__(self, root: TreeNodeBasis):
        """
        Build the tree from ``root`` and set up the lookup tables for basis sets and dofs.

        Raises
        ------
        ValueError
            If the nodes do not share the same quantum number size.
        """
        super().__init__(root)
        for node in self.node_list:
            assert isinstance(node, TreeNodeBasis)
        qn_size_list = [n.qn_size for n in self.node_list]
        if len(set(qn_size_list)) != 1:
            raise ValueError(f"Inconsistent quantum number size: {set(qn_size_list)}")
        self.qn_size: int = qn_size_list[0]
        # map basis to node index
        self.basis2idx: Dict[BasisSet, int] = {}
        # map dof to node index
        self.dof2idx: Dict[Any, int] = {}
        # map dof to basis
        self.dof2basis: Dict[Any, BasisSet] = {}
        for i, node in enumerate(self.node_list):
            for b in node.basis_sets:
                self.basis2idx[b] = i
                for d in b.dofs:
                    self.dof2idx[d] = i
                    self.dof2basis[d] = b
        # identity operator
        self.identity_op: Op = Op("I", self.root.dofs[0][0])
        # identity ttno
        self.identity_ttno = None
        # dummy ttno. Same tree topology but only has dummy basis
        # used as a dummy operator for calculating norm, etc
        self.dummy_ttno = None

    def print(self, print_function=None):
        """
        Print the tree structure. Each node is shown as the list of dofs of its basis sets.

        Parameters
        ----------
        print_function: callable, optional
            If provided, it is called with every row of the rendered tree.
            NOTE(review): constructing the ``print_tree`` subclass presumably writes
            the tree to standard output already -- confirm against the ``print_tree`` package.
        """

        class print_tn_basis(print_tree):
            def get_children(self, node):
                return node.children

            def get_node_str(self, node):
                return str([b.dofs for b in node.basis_sets])

        tree = print_tn_basis(self.root)
        if print_function is not None:
            for row in tree.rows:
                print_function(row)

    @property
    def basis_list(self) -> List[BasisSet]:
        """All basis sets in the tree, following the pre-order of the nodes."""
        return list(chain(*[n.basis_sets for n in self.node_list]))

    @property
    def dof_list(self) -> List[Any]:
        """All dofs in the tree, following the order of ``basis_list``."""
        return list(chain(*[b.dofs for b in self.basis_list]))

    @property
    def basis_list_postorder(self) -> List[BasisSet]:
        """All basis sets in the tree, following the post-order of the nodes."""
        return list(chain(*[n.basis_sets for n in self.postorder_list()]))

    @property
    def pbond_dims(self) -> List[List[int]]:
        """The ``pbond_dims`` of every node, following the pre-order of the nodes."""
        return [node.pbond_dims for node in self.node_list]

    def add_auxiliary_space(self, auxiliary_label="Q") -> "BasisTree":
        """
        Make a new basis tree in which every non-dummy basis set is followed by an auxiliary copy.

        The new tree has the same topology as this one. The original basis set objects
        (the P space) are reused in the new nodes. The auxiliary copies (the Q space)
        use ``(auxiliary_label, basis.dofs)`` as their dof and have zero ``sigmaqn``.

        Parameters
        ----------
        auxiliary_label:
            The label for the dofs of the auxiliary basis sets.

        Returns
        -------
        The new basis tree with auxiliary basis sets.
        """
        # make a new basis tree with auxiliary basis
        node2_list = []
        for node in self:
            basis_set2_list = []
            for basis in node.basis_sets:
                # the P space
                basis_set2_list.append(basis)
                if not isinstance(basis, BasisDummy):
                    # the Q space
                    basis_q: BasisSet = basis.copy((auxiliary_label, basis.dofs))
                    # set to zero for now. could change to more complicated case in the future
                    basis_q.sigmaqn = np.zeros_like(basis.sigmaqn)
                    basis_set2_list.append(basis_q)
            node2_list.append(TreeNodeBasis(basis_set2_list))
        # reproduce the parent/child links of this tree on the new nodes
        copy_connection(self.node_list, node2_list)
        # ``node2_list`` follows the pre-order, so its first element corresponds to the root
        basis_tree2 = BasisTree(node2_list[0])
        return basis_tree2
def approximate_partition(sequence, ngroups):
    """
    Split ``sequence`` into ``ngroups`` consecutive slices of roughly equal length.

    Every slice has the same length, ``len(sequence)`` divided by ``ngroups``
    and rounded up, except for the trailing ones which take whatever is left.
    Trailing slices can therefore be shorter or even empty,
    e.g. 4 elements in 3 groups gives slices of length 2, 2 and 0.

    Parameters
    ----------
    sequence:
        The sequence to split. It must support ``len`` and slicing.
    ngroups: int
        The number of slices to produce.

    Returns
    -------
    A list of ``ngroups`` slices of ``sequence``.
    """
    total = len(sequence)
    # length of a full slice: ``total / ngroups`` rounded up
    size = (total - 1) // ngroups + 1
    return [sequence[i * size : min((i + 1) * size, total)] for i in range(ngroups)]