Skip to content

capturegraph.scheduling.distance.batch #

Batch Distance Computation#

Provides a protocol and utilities for batch (vectorized) distance computation. This enables O(n²) pairwise distance matrices to be computed efficiently using NumPy broadcasting instead of Python loops.

Each distance function can optionally provide batch capabilities by subclassing the BatchedDistanceFunction abstract base class.

BatchedDistanceFunction #

Bases: ABC

Protocol for distance functions that support batch computation.

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
class BatchedDistanceFunction(ABC):
    """Abstract base for distance functions that can be evaluated in batch.

    Subclasses supply the single-pair distance (``__call__``), a feature
    extractor (``extract``) and a vectorized kernel (``pairwise``).
    ``matrix`` combines the latter two into a full distance matrix.
    """

    @abstractmethod
    def __call__(self, a: Any, b: Any) -> float:
        """Return the distance between a single pair of items."""

    @abstractmethod
    def extract(self, sessions: cg.List[cg.Dict[Any]]) -> np.ndarray:
        """Turn sessions into a 2D array of numeric features."""

    @abstractmethod
    def pairwise(self, features_a: np.ndarray, features_b: np.ndarray) -> np.ndarray:
        """Return pairwise distances computed from pre-extracted features."""

    def matrix(
        self,
        vals_a: cg.List[Any],
        vals_b: cg.List[Any] | None = None,
    ) -> np.ndarray:
        """Build the ``len(vals_a) x len(vals_b)`` distance matrix.

        Args:
            vals_a: Items that index the rows of the result.
            vals_b: Items that index the columns; when ``None`` the rows
                are compared against themselves.

        Returns:
            A float64 array of shape ``(len(vals_a), len(vals_b))``.
            Entries whose row or column item is missing (according to
            ``cg.is_missing``) are left at 0.

        Raises:
            ValueError: If every item on either side is missing.
        """
        columns = vals_a if vals_b is None else vals_b

        n_rows, n_cols = len(vals_a), len(columns)
        shape = (n_rows, n_cols)
        if not n_rows or not n_cols:
            # Nothing to compare on at least one axis: empty matrix.
            return np.zeros(shape, dtype=np.float64)

        # True where the item is present (i.e. not missing).
        present_a = np.array([not cg.is_missing(item) for item in vals_a])
        present_b = np.array([not cg.is_missing(item) for item in columns])

        if not (present_a.any() and present_b.any()):
            raise ValueError("All values are missing")

        # Only present items are handed to the feature extractor.
        kept_a = cg.List([item for item, keep in zip(vals_a, present_a) if keep])
        kept_b = cg.List([item for item, keep in zip(columns, present_b) if keep])

        feats_a = self.extract(kept_a)
        feats_b = self.extract(kept_b)
        block = self.pairwise(feats_a, feats_b)

        # Scatter the dense block back into the full-size matrix; rows and
        # columns that belong to missing items keep their initial 0.
        result = np.zeros(shape, dtype=np.float64)
        row_idx = np.flatnonzero(present_a)
        col_idx = np.flatnonzero(present_b)
        result[np.ix_(row_idx, col_idx)] = block
        return result

__call__(a, b) abstractmethod #

Compute distance between two items (single pair).

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
@abstractmethod
def __call__(self, a: Any, b: Any) -> float:
    """Compute distance between two items (single pair).

    Abstract method: concrete subclasses must override it.

    Args:
        a: First item of the pair.
        b: Second item of the pair.

    Returns:
        The distance between ``a`` and ``b``.
    """
    ...

extract(sessions) abstractmethod #

Extract numeric features from sessions as a 2D array.

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
@abstractmethod
def extract(self, sessions: cg.List[cg.Dict[Any]]) -> np.ndarray:
    """Extract numeric features from sessions as a 2D array.

    Abstract method: concrete subclasses must override it.

    Args:
        sessions: The sessions to convert into numeric features.

    Returns:
        A 2D NumPy array of features (presumably one row per session;
        TODO confirm the layout against concrete implementations).
    """
    ...

pairwise(features_a, features_b) abstractmethod #

Compute pairwise distances from pre-extracted features.

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
@abstractmethod
def pairwise(self, features_a: np.ndarray, features_b: np.ndarray) -> np.ndarray:
    """Compute pairwise distances from pre-extracted features.

    Abstract method: concrete subclasses must override it.

    Args:
        features_a: Pre-extracted features for the first group of items.
        features_b: Pre-extracted features for the second group of items.

    Returns:
        An array of distances between the two groups (presumably indexed
        as ``[item in a, item in b]``; TODO confirm against implementations).
    """
    ...

matrix(vals_a, vals_b=None) #

Compute pairwise distance matrix.

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
def matrix(
    self,
    vals_a: cg.List[Any],
    vals_b: cg.List[Any] | None = None,
) -> np.ndarray:
    """Build the ``len(vals_a) x len(vals_b)`` distance matrix.

    Args:
        vals_a: Items that index the rows of the result.
        vals_b: Items that index the columns; when ``None`` the rows are
            compared against themselves.

    Returns:
        A float64 array of shape ``(len(vals_a), len(vals_b))``. Entries
        whose row or column item is missing (according to
        ``cg.is_missing``) are left at 0.

    Raises:
        ValueError: If every item on either side is missing.
    """
    columns = vals_a if vals_b is None else vals_b

    n_rows, n_cols = len(vals_a), len(columns)
    shape = (n_rows, n_cols)
    if not n_rows or not n_cols:
        # Nothing to compare on at least one axis: empty matrix.
        return np.zeros(shape, dtype=np.float64)

    # True where the item is present (i.e. not missing).
    present_a = np.array([not cg.is_missing(item) for item in vals_a])
    present_b = np.array([not cg.is_missing(item) for item in columns])

    if not (present_a.any() and present_b.any()):
        raise ValueError("All values are missing")

    # Only present items are handed to the feature extractor.
    kept_a = cg.List([item for item, keep in zip(vals_a, present_a) if keep])
    kept_b = cg.List([item for item, keep in zip(columns, present_b) if keep])

    feats_a = self.extract(kept_a)
    feats_b = self.extract(kept_b)
    block = self.pairwise(feats_a, feats_b)

    # Scatter the dense block back into the full-size matrix; rows and
    # columns that belong to missing items keep their initial 0.
    result = np.zeros(shape, dtype=np.float64)
    row_idx = np.flatnonzero(present_a)
    col_idx = np.flatnonzero(present_b)
    result[np.ix_(row_idx, col_idx)] = block
    return result

batch_matrix(fn, vals_a, vals_b=None) #

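Compute a pairwise distance matrix with fn, using its batched matrix method when it provides one and falling back to a per-pair Python loop otherwise.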

Source code in capturegraph-lib/capturegraph/scheduling/distance/batch.py
def batch_matrix(
    fn: Callable,
    vals_a: cg.List[Any],
    vals_b: cg.List[Any] | None = None,
) -> np.ndarray:
    """Compute a pairwise distance matrix with ``fn``.

    If ``fn`` exposes a callable ``matrix`` attribute, the work is
    delegated to it. Otherwise ``fn(a, b)`` is evaluated pair by pair in a
    Python loop.

    Args:
        fn: Distance function. Either an object with a ``matrix`` method
            or a plain callable taking two items.
        vals_a: Items that index the rows of the result.
        vals_b: Items that index the columns; when ``None`` the rows are
            compared against themselves.

    Returns:
        A float64 array of shape ``(len(vals_a), len(vals_b))``. In the
        per-pair fallback, pairs involving a missing item (according to
        ``cg.is_missing``) and pairs for which ``fn`` raised are left at 0.
    """
    import logging

    if vals_b is None:
        vals_b = vals_a

    # Handle empty lists - return empty matrix
    n = len(vals_a)
    m = len(vals_b)
    if n == 0 or m == 0:
        return np.zeros((n, m), dtype=np.float64)

    batched = getattr(fn, "matrix", None)
    if callable(batched):
        return batched(vals_a, vals_b)

    full_distances = np.zeros((n, m), dtype=np.float64)

    # Missingness of a column item does not depend on the row, so resolve
    # it once instead of once per (row, column) pair.
    present_b = [(j, vb) for j, vb in enumerate(vals_b) if not cg.is_missing(vb)]

    failures = 0
    first_error = None
    for i, va in enumerate(vals_a):
        if cg.is_missing(va):
            continue
        for j, vb in present_b:
            try:
                full_distances[i, j] = fn(va, vb)
            except Exception as exc:
                # Best effort: the entry stays 0, but the failure is
                # remembered so it can be reported instead of vanishing.
                failures += 1
                if first_error is None:
                    first_error = exc

    if failures:
        # One summary record instead of up to n*m individual ones.
        logging.getLogger(__name__).warning(
            "batch_matrix: %d pair(s) raised in %r and were left at 0; "
            "first error shown",
            failures,
            fn,
            exc_info=first_error,
        )
    return full_distances