@@ -109,6 +109,7 @@ def __init__(
109
109
self ._det_cache = None
110
110
self ._logabsdet_cache = None
111
111
self ._trace_cache = None
112
+ self ._diagonal_cache = None
112
113
113
114
self ._lu_cache = None
114
115
self ._cholesky_cache = None
@@ -737,19 +738,7 @@ def _trace(self) -> np.number:
737
738
trace : float
738
739
Trace of the linear operator.
739
740
"""
740
-
741
- vec = np .zeros (self .shape [1 ], dtype = self .dtype )
742
-
743
- vec [0 ] = 1
744
- trace = (self @ vec )[0 ]
745
- vec [0 ] = 0
746
-
747
- for i in range (1 , self .shape [0 ]):
748
- vec [i ] = 1
749
- trace += (self @ vec )[i ]
750
- vec [i ] = 0
751
-
752
- return trace
741
+ return np .sum (self .diagonal ())
753
742
754
743
def trace (self ) -> np .number :
755
744
r"""Trace of the linear operator.
@@ -777,6 +766,31 @@ def trace(self) -> np.number:
777
766
778
767
return self ._trace_cache
779
768
769
+ def _diagonal (self ) -> np .ndarray :
770
+ """Diagonal of the linear operator.
771
+
772
+ You may implement this method in a subclass.
773
+ """
774
+ D = np .min (self .shape )
775
+ diag = np .zeros (D , dtype = self .dtype )
776
+ vec = np .zeros (self .shape [1 ], dtype = self .dtype )
777
+
778
+ for i in range (D ):
779
+ vec [i ] = 1
780
+ diag [i ] = (self @ vec )[i ]
781
+ vec [i ] = 0
782
+
783
+ return diag
784
+
785
+ def diagonal (self ) -> np .ndarray :
786
+ """Diagonal of the linear operator."""
787
+ if self ._diagonal_cache is None :
788
+ self ._diagonal_cache = self ._diagonal ()
789
+
790
+ self ._diagonal_cache .setflags (write = False )
791
+
792
+ return self ._diagonal_cache
793
+
780
794
####################################################################################
781
795
# Matrix Decompositions
782
796
####################################################################################
@@ -1337,6 +1351,7 @@ def __init__(
1337
1351
det : Optional [Callable [[], np .inexact ]] = None ,
1338
1352
logabsdet : Optional [Callable [[], np .floating ]] = None ,
1339
1353
trace : Optional [Callable [[], np .number ]] = None ,
1354
+ diagonal : Optional [Callable [[], np .ndarray ]] = None ,
1340
1355
):
1341
1356
super ().__init__ (shape , dtype )
1342
1357
@@ -1357,6 +1372,7 @@ def __init__(
1357
1372
self ._det_fn = det
1358
1373
self ._logabsdet_fn = logabsdet
1359
1374
self ._trace_fn = trace
1375
+ self ._diagonal_fn = diagonal
1360
1376
1361
1377
def _matmul (self , x : np .ndarray ) -> np .ndarray :
1362
1378
return self ._matmul_fn (x )
@@ -1429,6 +1445,12 @@ def _trace(self) -> np.number:
1429
1445
1430
1446
return self ._trace_fn ()
1431
1447
1448
+ def _diagonal (self ) -> np .ndarray :
1449
+ if self ._diagonal_fn is None :
1450
+ return super ()._diagonal ()
1451
+
1452
+ return self ._diagonal_fn ()
1453
+
1432
1454
1433
1455
class TransposedLinearOperator (LambdaLinearOperator ):
1434
1456
"""Transposition of a linear operator."""
@@ -1457,6 +1479,7 @@ def __init__(
1457
1479
det = self ._linop .det ,
1458
1480
logabsdet = self ._linop .logabsdet ,
1459
1481
trace = self ._linop .trace ,
1482
+ diagonal = self ._linop .diagonal ,
1460
1483
)
1461
1484
1462
1485
def _astype (
@@ -1561,6 +1584,7 @@ def __init__(
1561
1584
det = lambda : self ._linop .det ().astype (self ._inexact_dtype ),
1562
1585
logabsdet = lambda : self ._linop .logabsdet ().astype (self ._inexact_dtype ),
1563
1586
trace = lambda : self ._linop .trace ().astype (dtype ),
1587
+ diagonal = lambda : self ._linop .diagonal ().astype (dtype ),
1564
1588
)
1565
1589
1566
1590
def _astype (
@@ -1591,20 +1615,23 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]):
1591
1615
matmul = LinearOperator .broadcast_matmat (lambda x : self .A @ x )
1592
1616
todense = self .A .toarray
1593
1617
trace = lambda : self .A .diagonal ().sum ()
1618
+ diagonal = self .A .diagonal
1594
1619
else :
1595
1620
self .A = np .asarray (A )
1596
1621
self .A .setflags (write = False )
1597
1622
1598
1623
matmul = lambda x : self .A @ x
1599
1624
todense = lambda : self .A
1600
1625
trace = lambda : np .trace (self .A )
1626
+ diagonal = lambda : np .diagonal (self .A )
1601
1627
1602
1628
super ().__init__ (
1603
1629
self .A .shape ,
1604
1630
self .A .dtype ,
1605
1631
matmul = matmul ,
1606
1632
todense = todense ,
1607
1633
trace = trace ,
1634
+ diagonal = diagonal ,
1608
1635
)
1609
1636
1610
1637
def _transpose (self ) -> "Matrix" :
@@ -1691,6 +1718,7 @@ def __init__(
1691
1718
trace = lambda : probnum .utils .as_numpy_scalar (
1692
1719
self .shape [0 ], dtype = self .dtype
1693
1720
),
1721
+ diagonal = lambda : np .ones (shape [0 ], dtype = self .dtype ),
1694
1722
)
1695
1723
1696
1724
# Matrix properties
0 commit comments