Skip to content

Commit 65f5604

Browse files
authored
Linop diagonal (#820)
* Add `LinearOperator.diagonal()` ... which computes the diagonal of the linear operator. Default implementation multiplies with unit vectors, and subclasses may use more efficient implementations, e.g. for Kronecker linear operators. * Add tests for `LinearOperator.diagonal()`
1 parent 555485f commit 65f5604

File tree

6 files changed

+72
-19
lines changed

6 files changed

+72
-19
lines changed

src/probnum/linops/_arithmetic_fallbacks.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self, linop: LinearOperator, scalar: ScalarLike):
4040
transpose=lambda: self._scalar * self._linop.T,
4141
inverse=self._inv,
4242
trace=lambda: self._scalar * self._linop.trace(),
43+
diagonal=lambda: self._scalar * self._linop.diagonal(),
4344
)
4445

4546
# Matrix properties
@@ -89,7 +90,6 @@ class SumLinearOperator(LambdaLinearOperator):
8990
"""Sum of linear operators."""
9091

9192
def __init__(self, *summands: LinearOperator):
92-
9393
if not all(summand.shape == summands[0].shape for summand in summands):
9494
raise ValueError("All summands must have the same shape.")
9595

@@ -113,6 +113,9 @@ def __init__(self, *summands: LinearOperator):
113113
trace=lambda: functools.reduce(
114114
operator.add, (summand.trace() for summand in self._summands)
115115
),
116+
diagonal=lambda: functools.reduce(
117+
operator.add, (summand.diagonal() for summand in self._summands)
118+
),
116119
)
117120

118121
# Matrix properties
@@ -176,7 +179,6 @@ class ProductLinearOperator(LambdaLinearOperator):
176179
"""(Operator) Product of linear operators."""
177180

178181
def __init__(self, *factors: LinearOperator):
179-
180182
if not all(
181183
lfactor.shape[1] == rfactor.shape[0]
182184
for lfactor, rfactor in zip(factors[:-1], factors[1:])

src/probnum/linops/_block.py

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ def _rank(self) -> np.intp:
132132
return np.sum([block.rank() for block in self.blocks])
133133
return super()._rank()
134134

135+
def _diagonal(self) -> np.ndarray:
136+
if self._all_blocks_square:
137+
return np.concatenate([block.diagonal() for block in self.blocks])
138+
return super()._diagonal()
139+
135140
def _cholesky(self, lower: bool) -> BlockDiagonalMatrix:
136141
if self._all_blocks_square:
137142
return BlockDiagonalMatrix(

src/probnum/linops/_kronecker.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ def _trace(self) -> np.number:
172172

173173
return super()._trace()
174174

175+
def _diagonal(self) -> np.ndarray:
176+
if self.B.is_square:
177+
return np.kron(self.A.diagonal(), self.B.diagonal())
178+
return super()._diagonal()
179+
175180
def _astype(
176181
self, dtype: DTypeLike, order: str, casting: str, copy: bool
177182
) -> "Kronecker":
@@ -200,7 +205,6 @@ def _matmul_kronecker(self, other: "Kronecker") -> "Kronecker":
200205
def _add_kronecker(
201206
self, other: "Kronecker"
202207
) -> Union[NotImplementedType, "Kronecker"]:
203-
204208
if self.A is other.A or self.A == other.A:
205209
return Kronecker(A=self.A, B=self.B + other.B)
206210

@@ -212,7 +216,6 @@ def _add_kronecker(
212216
def _sub_kronecker(
213217
self, other: "Kronecker"
214218
) -> Union[NotImplementedType, "Kronecker"]:
215-
216219
if self.A is other.A or self.A == other.A:
217220
return Kronecker(A=self.A, B=self.B - other.B)
218221

@@ -537,7 +540,6 @@ def _matmul_idkronecker(self, other: "IdentityKronecker") -> "IdentityKronecker"
537540
def _add_idkronecker(
538541
self, other: "IdentityKronecker"
539542
) -> Union[NotImplementedType, "IdentityKronecker"]:
540-
541543
if self.A.shape == other.A.shape:
542544
return IdentityKronecker(num_blocks=self._num_blocks, B=self.B + other.B)
543545

@@ -546,7 +548,6 @@ def _add_idkronecker(
546548
def _sub_idkronecker(
547549
self, other: "IdentityKronecker"
548550
) -> Union[NotImplementedType, "IdentityKronecker"]:
549-
550551
if self.A.shape == other.A.shape:
551552
return IdentityKronecker(num_blocks=self._num_blocks, B=self.B - other.B)
552553

src/probnum/linops/_linear_operator.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
self._det_cache = None
110110
self._logabsdet_cache = None
111111
self._trace_cache = None
112+
self._diagonal_cache = None
112113

113114
self._lu_cache = None
114115
self._cholesky_cache = None
@@ -737,19 +738,7 @@ def _trace(self) -> np.number:
737738
trace : float
738739
Trace of the linear operator.
739740
"""
740-
741-
vec = np.zeros(self.shape[1], dtype=self.dtype)
742-
743-
vec[0] = 1
744-
trace = (self @ vec)[0]
745-
vec[0] = 0
746-
747-
for i in range(1, self.shape[0]):
748-
vec[i] = 1
749-
trace += (self @ vec)[i]
750-
vec[i] = 0
751-
752-
return trace
741+
return np.sum(self.diagonal())
753742

754743
def trace(self) -> np.number:
755744
r"""Trace of the linear operator.
@@ -777,6 +766,31 @@ def trace(self) -> np.number:
777766

778767
return self._trace_cache
779768

769+
def _diagonal(self) -> np.ndarray:
770+
"""Diagonal of the linear operator.
771+
772+
You may implement this method in a subclass.
773+
"""
774+
D = np.min(self.shape)
775+
diag = np.zeros(D, dtype=self.dtype)
776+
vec = np.zeros(self.shape[1], dtype=self.dtype)
777+
778+
for i in range(D):
779+
vec[i] = 1
780+
diag[i] = (self @ vec)[i]
781+
vec[i] = 0
782+
783+
return diag
784+
785+
def diagonal(self) -> np.ndarray:
786+
"""Diagonal of the linear operator."""
787+
if self._diagonal_cache is None:
788+
self._diagonal_cache = self._diagonal()
789+
790+
self._diagonal_cache.setflags(write=False)
791+
792+
return self._diagonal_cache
793+
780794
####################################################################################
781795
# Matrix Decompositions
782796
####################################################################################
@@ -1337,6 +1351,7 @@ def __init__(
13371351
det: Optional[Callable[[], np.inexact]] = None,
13381352
logabsdet: Optional[Callable[[], np.floating]] = None,
13391353
trace: Optional[Callable[[], np.number]] = None,
1354+
diagonal: Optional[Callable[[], np.ndarray]] = None,
13401355
):
13411356
super().__init__(shape, dtype)
13421357

@@ -1357,6 +1372,7 @@ def __init__(
13571372
self._det_fn = det
13581373
self._logabsdet_fn = logabsdet
13591374
self._trace_fn = trace
1375+
self._diagonal_fn = diagonal
13601376

13611377
def _matmul(self, x: np.ndarray) -> np.ndarray:
13621378
return self._matmul_fn(x)
@@ -1429,6 +1445,12 @@ def _trace(self) -> np.number:
14291445

14301446
return self._trace_fn()
14311447

1448+
def _diagonal(self) -> np.ndarray:
1449+
if self._diagonal_fn is None:
1450+
return super()._diagonal()
1451+
1452+
return self._diagonal_fn()
1453+
14321454

14331455
class TransposedLinearOperator(LambdaLinearOperator):
14341456
"""Transposition of a linear operator."""
@@ -1457,6 +1479,7 @@ def __init__(
14571479
det=self._linop.det,
14581480
logabsdet=self._linop.logabsdet,
14591481
trace=self._linop.trace,
1482+
diagonal=self._linop.diagonal,
14601483
)
14611484

14621485
def _astype(
@@ -1561,6 +1584,7 @@ def __init__(
15611584
det=lambda: self._linop.det().astype(self._inexact_dtype),
15621585
logabsdet=lambda: self._linop.logabsdet().astype(self._inexact_dtype),
15631586
trace=lambda: self._linop.trace().astype(dtype),
1587+
diagonal=lambda: self._linop.diagonal().astype(dtype),
15641588
)
15651589

15661590
def _astype(
@@ -1591,20 +1615,23 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]):
15911615
matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x)
15921616
todense = self.A.toarray
15931617
trace = lambda: self.A.diagonal().sum()
1618+
diagonal = self.A.diagonal
15941619
else:
15951620
self.A = np.asarray(A)
15961621
self.A.setflags(write=False)
15971622

15981623
matmul = lambda x: self.A @ x
15991624
todense = lambda: self.A
16001625
trace = lambda: np.trace(self.A)
1626+
diagonal = lambda: np.diagonal(self.A)
16011627

16021628
super().__init__(
16031629
self.A.shape,
16041630
self.A.dtype,
16051631
matmul=matmul,
16061632
todense=todense,
16071633
trace=trace,
1634+
diagonal=diagonal,
16081635
)
16091636

16101637
def _transpose(self) -> "Matrix":
@@ -1691,6 +1718,7 @@ def __init__(
16911718
trace=lambda: probnum.utils.as_numpy_scalar(
16921719
self.shape[0], dtype=self.dtype
16931720
),
1721+
diagonal=lambda: np.ones(shape[0], dtype=self.dtype),
16941722
)
16951723

16961724
# Matrix properties

src/probnum/linops/_scaling.py

+5
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ def _cond_isotropic(self, p: Union[None, int, float, str]) -> np.floating:
312312

313313
return np.linalg.cond(self.todense(cache=False), p=p)
314314

315+
def _diagonal(self) -> np.ndarray:
316+
return self.factors
317+
315318
def _cholesky(self, lower: bool = True) -> Scaling:
316319
if self._scalar is not None:
317320
if self._scalar <= 0:
@@ -347,6 +350,7 @@ def __init__(self, shape, dtype=np.float64):
347350
det = lambda: np.zeros(shape=(), dtype=dtype)
348351

349352
trace = lambda: np.zeros(shape=(), dtype=dtype)
353+
diagonal = lambda: np.zeros(shape=(np.min(shape),), dtype=dtype)
350354

351355
def matmul(x: np.ndarray) -> np.ndarray:
352356
target_shape = list(x.shape)
@@ -363,6 +367,7 @@ def matmul(x: np.ndarray) -> np.ndarray:
363367
eigvals=eigvals,
364368
det=det,
365369
trace=trace,
370+
diagonal=diagonal,
366371
)
367372

368373
# Matrix properties

tests/test_linops/test_linops.py

+12
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,18 @@ def test_trace(linop: pn.linops.LinearOperator, matrix: np.ndarray):
381381
linop.trace()
382382

383383

384+
@pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules)
385+
def test_diagonal(linop: pn.linops.LinearOperator, matrix: np.ndarray):
386+
linop_diagonal = linop.diagonal()
387+
matrix_diagonal = np.diagonal(matrix)
388+
389+
assert isinstance(linop_diagonal, np.ndarray)
390+
assert linop_diagonal.shape == matrix_diagonal.shape
391+
assert linop_diagonal.dtype == matrix_diagonal.dtype
392+
393+
np.testing.assert_allclose(linop_diagonal, matrix_diagonal)
394+
395+
384396
@pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules)
385397
def test_transpose(linop: pn.linops.LinearOperator, matrix: np.ndarray):
386398
matrix_transpose = matrix.transpose()

0 commit comments

Comments
 (0)