Skip to content

Commit 555485f

Browse files
authored
Override linop instead of _evaluate_linop (#818)
* Override `linop` instead of `_evaluate_linop` In arithmetic fallback covariance functions, override `linop` instead of `_evaluate_linop`, because the latter has the problem that the input preprocessing destroys attributes of np.array subclasses. Keeping the option of np.array subclassing can be very useful for certain linop implementations, e.g. tensor product covariance functions.
1 parent e7f332e commit 555485f

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/probnum/randprocs/covfuncs/_arithmetic_fallbacks.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from probnum import linops, utils
12-
from probnum.typing import NotImplementedType, ScalarLike
12+
from probnum.typing import ArrayLike, NotImplementedType, ScalarLike
1313

1414
from ._covariance_function import BinaryOperandType, CovarianceFunction
1515

@@ -55,8 +55,8 @@ def __init__(self, covfunc: CovarianceFunction, scalar: ScalarLike):
5555
def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray:
5656
return self._scalar * self._covfunc(x0, x1)
5757

58-
def _evaluate_linop(
59-
self, x0: np.ndarray, x1: Optional[np.ndarray]
58+
def linop(
59+
self, x0: ArrayLike, x1: Optional[ArrayLike] = None
6060
) -> linops.LinearOperator:
6161
return self._scalar * self._covfunc.linop(x0, x1)
6262

@@ -82,7 +82,6 @@ class SumCovarianceFunction(CovarianceFunction):
8282
"""
8383

8484
def __init__(self, *summands: CovarianceFunction):
85-
8685
if not all(
8786
(summand.input_shape == summands[0].input_shape)
8887
and (summand.output_shape_0 == summands[0].output_shape_0)
@@ -104,8 +103,8 @@ def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray:
104103
operator.add, (summand(x0, x1) for summand in self._summands)
105104
)
106105

107-
def _evaluate_linop(
108-
self, x0: np.ndarray, x1: Optional[np.ndarray]
106+
def linop(
107+
self, x0: ArrayLike, x1: Optional[ArrayLike] = None
109108
) -> linops.LinearOperator:
110109
return functools.reduce(
111110
operator.add, (summand.linop(x0, x1) for summand in self._summands)
@@ -151,7 +150,6 @@ class ProductCovarianceFunction(CovarianceFunction):
151150
"""
152151

153152
def __init__(self, *factors: CovarianceFunction):
154-
155153
if not all(
156154
(factor.input_shape == factors[0].input_shape)
157155
and (factor.output_shape_0 == factors[0].output_shape_0)

0 commit comments

Comments
 (0)