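"""
Subdomain connections (module ``ttfemesh.domain.subdomain_connection``).

Defines how subdomains of a domain are joined: through a shared vertex
(``VertexConnection2D``) or through a shared curve (``CurveConnection2D``).
"""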

from abc import ABC, abstractmethod
from itertools import combinations
from typing import List, Literal, Tuple

import numpy as np

from ttfemesh.domain.curve import Curve
from ttfemesh.domain.subdomain import Subdomain2D

# Which endpoint of a curve is meant. VertexConnection2D selects the curve's
# ``get_start()`` for "start" and ``get_end()`` for "end".
CurvePosition = Literal["start", "end"]


class SubdomainConnection(ABC):
    """
    Abstract interface for a connection between subdomains.

    Concrete subclasses must report their spatial dimension and the number of
    subdomains they join, and must be able to check themselves against the
    subdomains they reference.
    """

    @abstractmethod
    def validate(self, *args, **kwargs):  # pragma: no cover
        """
        Check this connection against the subdomains it refers to.

        Implementations are expected to raise if the connection is inconsistent.
        """

    @property
    @abstractmethod
    def dimension(self) -> int:  # pragma: no cover
        """Spatial dimension in which this connection lives."""

    @property
    @abstractmethod
    def num_connected_subdomains(self) -> int:  # pragma: no cover
        """How many subdomains this connection joins."""


class SubdomainConnection2D(SubdomainConnection):
    """Base class for connections between subdomains of a two-dimensional domain."""

    @property
    def dimension(self) -> int:
        """Spatial dimension of the connection; fixed at 2 for this class."""
        return 2


class VertexConnection2D(SubdomainConnection2D):
    """
    A vertex shared by two or more subdomains.

    The subdomain indexes reference into the list of subdomains passed to the Domain constructor.

    Args:
        connection (List[Tuple[int, int, CurvePosition]]): List of subdomains sharing this vertex.
            Each connection is a tuple of (subdomain index, curve index, position).
            Curve position is either "start" or "end".

    Example:
        >>> from ttfemesh.domain import RectangleFactory, CurveConnection2D, VertexConnection2D
        >>> from ttfemesh.domain import Domain2D
        >>> lower_left = (0, 0)
        >>> upper_right = (2, 1)
        >>> rectangle1 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (2, 0)
        >>> upper_right = (3, 1)
        >>> rectangle2 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (-2, 1)
        >>> upper_right = (0, 2)
        >>> rectangle3 = RectangleFactory.create(lower_left, upper_right)
        >>> domain_idxs = [0, 1]
        >>> curve_idxs = [1, 3]
        >>> edge = CurveConnection2D(domain_idxs, curve_idxs)
        >>> vertex_idxs = [(0, 3, "start"), (2, 0, "end")]
        >>> vertex = VertexConnection2D(vertex_idxs)
        >>> domain = Domain2D([rectangle1, rectangle2, rectangle3], [edge, vertex])
        >>> domain.plot()
    """

    def __init__(self, connection: List[Tuple[int, int, CurvePosition]]):
        """Store the (subdomain index, curve index, position) entries as given."""
        self.connection = connection

    @property
    def num_connected_subdomains(self) -> int:
        """
        Number of connected subdomains.

        This counts entries in ``connection``, so a subdomain listed twice
        (with different curves) is counted twice.
        """
        return len(self.connection)

    def validate(self, subdomains: List[Subdomain2D], tol: float = 1e-6):
        """
        Validate that all specified subdomains, curves, and positions share the given vertex.

        Args:
            subdomains (List[Subdomain2D]): List of subdomains in the domain.
            tol (float): Absolute tolerance for the point-wise comparison.

        Raises:
            ValueError: If an index is out of bounds, a position is not "start"/"end",
                fewer than two entries are given, or the referenced curve endpoints do
                not coincide within ``tol``.
        """
        self._validate_idxs(subdomains)
        if self.num_connected_subdomains < 2:
            raise ValueError("Vertex connection must have at least two connected subdomains.")

        # The first entry defines the reference vertex; every other entry must match it.
        vertex = self._endpoint(subdomains, *self.connection[0])
        for subdomain_idx, curve_idx, position in self.connection[1:]:
            point = self._endpoint(subdomains, subdomain_idx, curve_idx, position)
            if not np.allclose(point, vertex, atol=tol):
                raise ValueError(
                    f"Subdomain {subdomain_idx}, curve {curve_idx}, {position} point {point} "
                    f"does not match the vertex {vertex}."
                )

    def get_connection_pairs(
        self,
    ) -> List[Tuple[Tuple[int, int], Tuple[int, int], Tuple[CurvePosition, CurvePosition]]]:
        """
        Get all unique pairs of connected subdomains with their curve indices and positions.

        Pairs are produced in the order of ``connection`` (entry i before entry j, i < j).

        Returns:
            List[Tuple[Tuple[int, int], Tuple[int, int], Tuple[CurvePosition, CurvePosition]]]:
                List of connected subdomain pairs with the indexing
                [(subdomain0, subdomain1), (curve0, curve1), (position0, position1)].
        """
        return [
            ((subd1, subd2), (curve_idx1, curve_idx2), (curve_pos1, curve_pos2))
            for (subd1, curve_idx1, curve_pos1), (subd2, curve_idx2, curve_pos2) in combinations(
                self.connection, 2
            )
        ]

    def get_shared_vertex(self, subdomains: List[Subdomain2D]) -> np.ndarray:
        """
        Get the shared vertex between the connected subdomains.

        The vertex is taken from the first entry of ``connection``; call ``validate``
        to ensure all entries agree on it.

        Args:
            subdomains (List[Subdomain2D]): List of subdomains in the domain.

        Returns:
            np.ndarray: Shared vertex coordinates.

        Raises:
            ValueError: If an index/position is invalid or the connection is empty.
        """
        self._validate_idxs(subdomains)
        if not self.connection:
            raise ValueError("Vertex connection has no connected subdomains.")
        return self._endpoint(subdomains, *self.connection[0])

    @staticmethod
    def _endpoint(
        subdomains: List[Subdomain2D], subdomain_idx: int, curve_idx: int, position: CurvePosition
    ) -> np.ndarray:
        """Return the start or end point of the referenced curve."""
        curve = subdomains[subdomain_idx].curves[curve_idx]
        return curve.get_start() if position == "start" else curve.get_end()

    def _validate_idxs(self, subdomains: List[Subdomain2D]):
        """Validate that indices are within bounds and positions are "start" or "end"."""
        for subdomain_idx, curve_idx, position in self.connection:
            # Negative indices would silently wrap around to a different subdomain/curve.
            if not 0 <= subdomain_idx < len(subdomains):
                raise ValueError(f"Subdomain index {subdomain_idx} is out of bounds.")
            if not 0 <= curve_idx < len(subdomains[subdomain_idx].curves):
                raise ValueError(f"Curve index {curve_idx} is out of bounds.")
            # Without this check any non-"start" value would be treated as "end".
            if position not in ("start", "end"):
                raise ValueError(
                    f"Curve position {position!r} is invalid; expected 'start' or 'end'."
                )

    def __repr__(self):
        return f"VertexConnection2D({self.connection})"


class CurveConnection2D(SubdomainConnection2D):
    """
    A curve shared by exactly two subdomains.

    Only two subdomains can be connected by a curve.

    Args:
        subdomains_indices (Tuple[int, int]): A tuple of two subdomain indices that share a curve.
        curve_indices (Tuple[int, int]): A tuple of two curve indices in the respective subdomains.

    Example:
        >>> from ttfemesh.domain import RectangleFactory, CurveConnection2D, VertexConnection2D
        >>> from ttfemesh.domain import Domain2D
        >>> lower_left = (0, 0)
        >>> upper_right = (2, 1)
        >>> rectangle1 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (2, 0)
        >>> upper_right = (3, 1)
        >>> rectangle2 = RectangleFactory.create(lower_left, upper_right)
        >>> lower_left = (-2, 1)
        >>> upper_right = (0, 2)
        >>> rectangle3 = RectangleFactory.create(lower_left, upper_right)
        >>> domain_idxs = [0, 1]
        >>> curve_idxs = [1, 3]
        >>> edge = CurveConnection2D(domain_idxs, curve_idxs)
        >>> vertex_idxs = [(0, 3, "start"), (2, 0, "end")]
        >>> vertex = VertexConnection2D(vertex_idxs)
        >>> domain = Domain2D([rectangle1, rectangle2, rectangle3], [edge, vertex])
        >>> domain.plot()
    """

    def __init__(self, subdomains_indices: Tuple[int, int], curve_indices: Tuple[int, int]):
        """Store the two subdomain indices and the matching curve indices as given."""
        self.subdomains_indices = subdomains_indices
        self.curve_indices = curve_indices

    @property
    def num_connected_subdomains(self) -> int:
        """Number of connected subdomains (always 2 for a curve connection)."""
        return 2

    def validate(self, subdomains: List[Subdomain2D], num_points: int = 100, tol: float = 1e-6):
        """
        Validate that the curves are approximately equal.

        Args:
            subdomains (List[Subdomain2D]): List of subdomains in the domain.
            num_points (int): Number of points to sample along the curve.
            tol (float): Tolerance for point-wise comparison.

        Raises:
            ValueError: If the indices are malformed or out of bounds, or if
                ``Curve.equals`` reports the two referenced curves as different.
        """
        self._validate_idxs(subdomains)
        sub1_idx, sub2_idx = self.subdomains_indices
        curve1_idx, curve2_idx = self.curve_indices
        curve1 = subdomains[sub1_idx].curves[curve1_idx]
        curve2 = subdomains[sub2_idx].curves[curve2_idx]
        if not curve1.equals(curve2, num_points=num_points, tol=tol):
            raise ValueError(
                f"Curves {curve1_idx} of subdomain {sub1_idx}"
                f" and {curve2_idx} of subdomain {sub2_idx} are not equal."
            )

    def get_shared_curve(self, subdomains: List[Subdomain2D]) -> Curve:
        """
        Get the shared curve between the connected subdomains.

        The curve object is taken from the first subdomain; call ``validate`` to
        ensure the second subdomain's curve matches it.

        Args:
            subdomains (List[Subdomain2D]): List of subdomains in the domain.

        Returns:
            Curve: Shared curve.

        Raises:
            ValueError: If the indices are malformed or out of bounds.
        """
        self._validate_idxs(subdomains)
        sub1_idx = self.subdomains_indices[0]
        curve1_idx = self.curve_indices[0]
        return subdomains[sub1_idx].curves[curve1_idx]

    def _validate_idxs(self, subdomains: List[Subdomain2D]):
        """Validate that exactly two subdomain/curve indices are given and are within bounds."""
        if len(self.subdomains_indices) != 2:
            raise ValueError(
                f"Curve connection needs exactly two subdomain indices, "
                f"got {self.subdomains_indices}."
            )
        if len(self.curve_indices) != 2:
            raise ValueError(
                f"Curve connection needs exactly two curve indices, got {self.curve_indices}."
            )

        # Negative indices would silently wrap around to a different subdomain/curve.
        for sub_idx in self.subdomains_indices:
            if not 0 <= sub_idx < len(subdomains):
                raise ValueError(f"Subdomain index {sub_idx} is out of bounds.")
        for sub_idx, curve_idx in zip(self.subdomains_indices, self.curve_indices):
            if not 0 <= curve_idx < len(subdomains[sub_idx].curves):
                raise ValueError(f"Curve index {curve_idx} is out of bounds.")

    def __repr__(self):
        return f"CurveConnection2D({self.subdomains_indices}, {self.curve_indices})"