# Source code for ttfemesh.mesh.domain_mesh

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple

import numpy as np

from ttfemesh.basis import BilinearBasis, TensorProductBasis
from ttfemesh.domain import Domain, Quad
from ttfemesh.domain.subdomain_connection import CurveConnection2D, VertexConnection2D
from ttfemesh.mesh.mesh_utils import side_concatenation_tt, vertex_concatenation_tt
from ttfemesh.mesh.subdomain_mesh import QuadMesh, SubdomainMesh, SubdomainMesh2D
from ttfemesh.quadrature.quadrature import QuadratureRule
from ttfemesh.tt_tools.tensor_cross import TTCrossConfig
from ttfemesh.types import BoundarySide2D, BoundaryVertex2D, TensorTrain


class DomainMesh(ABC):
    """
    DomainMesh base class that ties together the domain, subdomain meshes, and basis functions.

    It provides an interface for the element jacobians of the subdomains,
    the element to global index maps, the boundary masks and the concatenation maps.
    """

    def __init__(
        self,
        domain: Domain,
        quadrature_rule: QuadratureRule,
        mesh_size_exponent: int,
        basis: TensorProductBasis,
        tt_cross_config: Optional[TTCrossConfig] = None,
    ):
        """
        Initialize a DomainMesh.

        Args:
            domain (Domain): The domain containing subdomains and their connections.
            quadrature_rule (QuadratureRule): Quadrature rule for integration.
            mesh_size_exponent (int): Discretization size exponent.
            basis (TensorProductBasis): The basis functions for the domain.
            tt_cross_config (Optional[TTCrossConfig]): Optional configuration for
                tensor cross approximation. If None, the default configuration is used.
        """
        self.quadrature_rule = quadrature_rule
        self.mesh_size_exponent = mesh_size_exponent
        self.domain = domain
        self.basis = basis
        self._tt_cross_config = tt_cross_config
        # Built last: the concrete subclass hook reads all of the attributes above.
        self.subdomain_meshes = self._create_subdomain_meshes()

    @abstractmethod
    def _create_subdomain_meshes(self):  # pragma: no cover
        """
        Create one subdomain mesh per subdomain of ``self.domain``.

        Subclasses must override this hook. The body below builds generic
        ``SubdomainMesh`` objects and is only reachable through ``super()``.

        Returns:
            list: The subdomain meshes, ordered by subdomain index.
        """
        num_subdomains = self.domain.num_subdomains
        subdomain_meshes = []
        for i in range(num_subdomains):
            subdomain = self.domain.get_subdomain(i)
            mesh_size_exponent = self.mesh_size_exponent
            quadrature_rule = self.quadrature_rule
            tt_cross_config = self._tt_cross_config
            subdomain_mesh = SubdomainMesh(
                subdomain, quadrature_rule, mesh_size_exponent, tt_cross_config
            )
            subdomain_meshes.append(subdomain_mesh)
        return subdomain_meshes

    def get_subdomain_mesh(self, subdomain_index: int) -> SubdomainMesh:  # noqa
        """
        Get the SubdomainMesh for a subdomain.

        Args:
            subdomain_index (int): The index of the subdomain.

        Returns:
            SubdomainMesh: The SubdomainMesh for the subdomain.

        Raises:
            ValueError: If the subdomain index is invalid.
        """
        self._validate_subdomain_index(subdomain_index)
        return self.subdomain_meshes[subdomain_index]

    def get_element2global_index_map(self) -> np.ndarray:  # noqa
        """
        Get the element-to-global index maps in TT format.

        Each map sends an element index to the global basis function index of
        one reference basis function on a reference element. The same maps are
        used for every subdomain. They depend only on the type of basis
        functions and on ``self.mesh_size_exponent``.

        Returns:
            np.ndarray: A matrix of TT-representations, i.e., each element of the
                matrix is a TT-vector. Indexing of the matrix corresponds to the
                reference basis function indexing. For example, for a bilinear
                basis in 2D, the indexing is (i, j), where i and j are the indices
                of the basis functions in the x and y direction. Specifically,
                (0, 0), (1, 0), (1, 1) and (0, 1) represent the four basis
                functions at the lower left, lower right, upper right and upper
                left vertices of the reference element, respectively.
                See also the documentation for the chosen Basis class.
        """
        return self.basis.get_all_element2global_ttmaps(self.mesh_size_exponent)

    def get_dirichlet_masks(self) -> Optional[Dict[int, TensorTrain]]:  # noqa
        """
        Get the Dirichlet boundary masks.

        Returns:
            Optional[Dict[int, TensorTrain]]: A dictionary where the keys are
                subdomain indices and the values are TT-representations of the
                boundary masks, or None if the domain has no boundary condition.
        """
        boundary_condition = self.domain.boundary_condition
        if boundary_condition is None:
            print("No boundary condition specified.")
            return None

        # NOTE(review): BoundarySide2D is used even though this is the
        # dimension-agnostic base class; confirm this is intended.
        grouped = boundary_condition.group_by_subdomain()
        boundary_masks = {}
        for subdomain_index, curve_indices in grouped.items():
            sides = [BoundarySide2D(i) for i in curve_indices]
            boundary_masks[subdomain_index] = self.basis.get_dirichlet_mask(
                self.mesh_size_exponent, *sides
            )
        return boundary_masks

    @abstractmethod  # noqa # pragma: no cover
    def get_concatenation_maps(self) -> Dict[Tuple[int, int], TensorTrain]:
        """
        Get the TT-representations of the concatenation maps for all pairs of
        connected subdomains.

        Returns:
            Dict[Tuple[int, int], TensorTrain]: A dictionary where the keys are
                pairs of subdomain indices, and the values are TT-representations
                of the concatenation maps. Subclasses may return a tuple of
                TT-representations per pair instead.
        """
        pass

    def _validate_subdomain_index(self, subdomain_index: int):
        """
        Check that ``subdomain_index`` lies in ``[0, self.domain.num_subdomains)``.

        Raises:
            ValueError: If the index is out of range.
        """
        if subdomain_index < 0 or subdomain_index >= self.domain.num_subdomains:
            raise ValueError(
                f"Invalid subdomain index: {subdomain_index}. "
                f"Valid indices are in the range [0, {self.domain.num_subdomains})."
            )

    def __repr__(self) -> str:
        # Use the runtime class name so subclasses get a correct repr without
        # having to override this method.
        return (
            f"{type(self).__name__}(domain={self.domain}, "
            f"mesh_size_exponent={self.mesh_size_exponent}, "
            f"quadrature_rule={self.quadrature_rule}, "
            f"basis={self.basis})"
        )
class DomainMesh2D(DomainMesh):
    """
    Mesh for 2D domains.

    Acts as a factory that builds one 2D subdomain mesh per subdomain:
    ``QuadMesh`` for ``Quad`` subdomains and ``SubdomainMesh2D`` otherwise.
    """

    def _create_subdomain_meshes(self):
        """
        Build the 2D subdomain meshes, ordered by subdomain index.

        Returns:
            list: One QuadMesh or SubdomainMesh2D per subdomain.
        """
        shared_kwargs = {
            "quadrature_rule": self.quadrature_rule,
            "mesh_size_exponent": self.mesh_size_exponent,
            "tt_cross_config": self._tt_cross_config,
        }

        def _mesh_for(subdomain):
            # Quads get the specialized mesh; everything else uses the generic 2D mesh.
            if isinstance(subdomain, Quad):
                return QuadMesh(quad=subdomain, **shared_kwargs)
            return SubdomainMesh2D(subdomain=subdomain, **shared_kwargs)

        return [
            _mesh_for(self.domain.get_subdomain(index))
            for index in range(self.domain.num_subdomains)
        ]

    def __repr__(self) -> str:
        parts = (
            f"domain={self.domain}",
            f"mesh_size_exponent={self.mesh_size_exponent}",
            f"quadrature_rule={self.quadrature_rule}",
            f"basis={self.basis}",
        )
        return "DomainMesh2D(" + ", ".join(parts) + ")"
class DomainBilinearMesh2D(DomainMesh2D):
    """
    Mesh for 2D domains with bilinear basis functions.

    This implementation of the concatenation maps works only for bilinear basis functions.

    Example:
        >>> from ttfemesh.domain import RectangleFactory, CurveConnection2D, VertexConnection2D
        >>> from ttfemesh.domain import DirichletBoundary2D, Domain2D
        >>> from ttfemesh.quadrature import GaussLegendre2D
        >>> from ttfemesh.mesh import DomainBilinearMesh2D
        >>> lower_left = (0, 0)
        >>> upper_right = (2, 1)
        >>> rectangle1 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (2, 0)
        >>> upper_right = (3, 1)
        >>> rectangle2 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (-2, 1)
        >>> upper_right = (0, 2)
        >>> rectangle3 = RectangleFactory.create(lower_left, upper_right)
        >>> domain_idxs = [0, 1]
        >>> curve_idxs = [1, 3]
        >>> edge = CurveConnection2D(domain_idxs, curve_idxs)
        >>> vertex_idxs = [(0, 3, "start"), (2, 0, "end")]
        >>> vertex = VertexConnection2D(vertex_idxs)
        >>> bc = DirichletBoundary2D([(1, 1), (2, 3)])
        >>> domain = Domain2D([rectangle1, rectangle2, rectangle3], [edge, vertex], bc)
        >>> quadrature_rule = GaussLegendre2D()
        >>> mesh_size_exponent = 3
        >>> domain_mesh = DomainBilinearMesh2D(domain, quadrature_rule, mesh_size_exponent)
        >>> print(domain_mesh)  # doctest: +ELLIPSIS
        DomainBilinearMesh2D(domain=...)
    """

    def __init__(
        self,
        domain: Domain,
        quadrature_rule: QuadratureRule,
        mesh_size_exponent: int,
        tt_cross_config: Optional[TTCrossConfig] = None,
    ):
        """
        Initialize a DomainBilinearMesh2D, which always uses a BilinearBasis.

        Args:
            domain (Domain): The domain containing subdomains and their connections.
            quadrature_rule (QuadratureRule): Quadrature rule for integration.
            mesh_size_exponent (int): Discretization size exponent.
            tt_cross_config (Optional[TTCrossConfig]): Optional configuration for
                tensor cross approximation, forwarded to the base class.
        """
        basis = BilinearBasis()
        super().__init__(domain, quadrature_rule, mesh_size_exponent, basis, tt_cross_config)

    def get_concatenation_maps(  # noqa
        self,
    ) -> Dict[Tuple[int, int], Tuple[TensorTrain, TensorTrain, TensorTrain]]:
        """
        Get the TT-representations of the concatenation maps for all pairs of
        connected subdomains.

        See Section 5 of arXiv:1802.02839 for details.
        Pmp describes which nodes in the left domain are connected to which
        nodes in the right domain, Pmm describes which nodes in the left domain
        are to be connected, and Ppp describes which nodes in the right domain
        are to be connected.

        Returns:
            Dict[Tuple[int, int], Tuple[TensorTrain, TensorTrain, TensorTrain]]:
                A dictionary where the keys are pairs of subdomain indices, and
                the values are tuples of TT-representations of the connectivity maps.

        Raises:
            ValueError: If the connection type is not supported.
        """
        concatenation_maps = {}
        # NOTE(review): maps are keyed only by the subdomain pair, so a later
        # connection between the same two subdomains overwrites an earlier one.
        # Confirm the domain model never produces such duplicates.
        for connection in self.domain.get_connections():
            if isinstance(connection, VertexConnection2D):
                for (
                    (subdidx1, subdidx2),
                    (curveidx1, curveidx2),
                    (pos1, pos2),
                ) in connection.get_connection_pairs():
                    vertex1 = BoundaryVertex2D(self._curve_vertex_index(curveidx1, pos1))
                    vertex2 = BoundaryVertex2D(self._curve_vertex_index(curveidx2, pos2))
                    concatenation_maps[(subdidx1, subdidx2)] = vertex_concatenation_tt(
                        vertex1, vertex2, self.mesh_size_exponent
                    )
            elif isinstance(connection, CurveConnection2D):
                subdidx1, subdidx2 = connection.subdomains_indices
                curveidx1, curveidx2 = connection.curve_indices
                side1 = BoundarySide2D(curveidx1)
                side2 = BoundarySide2D(curveidx2)
                concatenation_maps[(subdidx1, subdidx2)] = side_concatenation_tt(
                    side1, side2, self.mesh_size_exponent
                )
            else:
                raise ValueError(f"Unsupported connection type: {type(connection)}.")
        return concatenation_maps

    @staticmethod
    def _curve_vertex_index(curve_index: int, position: str) -> int:
        """Map a curve index and its "start"/"end" position to a vertex index in 0..3."""
        # Curve c starts at vertex c and ends at vertex (c + 1) mod 4. Any value
        # other than "start" is treated as the end point.
        offset = 0 if position == "start" else 1
        return (curve_index + offset) % 4

    def __repr__(self) -> str:
        return (
            f"DomainBilinearMesh2D(domain={self.domain}, "
            f"mesh_size_exponent={self.mesh_size_exponent}, "
            f"quadrature_rule={self.quadrature_rule}, "
            f"basis={self.basis})"
        )