Skip to content

Commit 2c98d4b

Browse files
authored
Convert TimeStep block variables to lists (#197)
* Maintain order of blocks in TimeStep with an OrderedSet This introduces a new set implementation, OrderedSet. * Use dictionary as backing for OrderedSet * Use Python 3.10 compatible forward reference
1 parent 49430e8 commit 2c98d4b

File tree

2 files changed

+139
-2
lines changed

2 files changed

+139
-2
lines changed

pyadjoint/ordered_set.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from collections.abc import Hashable, Iterable, Iterator, MutableSet, Set
2+
3+
4+
class OrderedSet(MutableSet):
5+
def __init__(self, iterable: Iterable[Hashable] = ()):
6+
"""
7+
An OrderedSet is a set that maintains the insertion order
8+
of its elements.
9+
10+
Args:
11+
iterable: An iterable with which to initialise the set elements.
12+
"""
13+
14+
self._elements: dict[Hashable, None] = {}
15+
16+
for element in iterable:
17+
self._elements[element] = None
18+
19+
def __contains__(self, obj: Hashable) -> bool:
20+
return obj in self._elements
21+
22+
def __len__(self) -> int:
23+
return len(self._elements)
24+
25+
def copy(self) -> "OrderedSet":
26+
"""Return a shallow copy of this set"""
27+
28+
ret = type(self)()
29+
ret._elements = self._elements.copy()
30+
31+
return ret
32+
33+
def add(self, obj: Hashable) -> None:
34+
"""Add obj to this set."""
35+
36+
self._elements[obj] = None
37+
38+
def remove(self, obj: Hashable) -> None:
39+
"""Remove obj from this set, it must be a member."""
40+
41+
if obj not in self:
42+
raise KeyError(obj)
43+
44+
del self._elements[obj]
45+
46+
def discard(self, obj: Hashable) -> None:
47+
"""Remove obj from this set if it is present."""
48+
49+
if obj not in self:
50+
return
51+
52+
self.remove(obj)
53+
54+
def __iter__(self) -> Iterator:
55+
return iter(self._elements.keys())
56+
57+
def union(self, *others: Iterable[Hashable]) -> "OrderedSet":
58+
"""Return a new set with elements from this set and others."""
59+
60+
ret = self.copy()
61+
62+
for other in others:
63+
for element in other:
64+
ret.add(element)
65+
66+
return ret
67+
68+
def __or__(self, other: Set) -> "OrderedSet":
69+
if not isinstance(other, Set):
70+
raise TypeError
71+
72+
return self.union(other)
73+
74+
def difference(self, *others: Iterable[Hashable]) -> "OrderedSet":
75+
"""Return a new set with elements from this set
76+
that are not in others.
77+
"""
78+
79+
ret = self.copy()
80+
81+
for other in others:
82+
for element in other:
83+
ret.discard(element)
84+
85+
return ret
86+
87+
def __sub__(self, other: Set) -> "OrderedSet":
88+
if not isinstance(other, Set):
89+
raise TypeError
90+
91+
return self.difference(other)
92+
93+
def intersection(self, *others: Iterable[Hashable]) -> "OrderedSet":
94+
"""Return a new set with elements common to this set
95+
and all others.
96+
"""
97+
98+
ret = type(self)()
99+
100+
for element in self:
101+
for other in others:
102+
if element not in other:
103+
break
104+
else:
105+
ret.add(element)
106+
107+
return ret
108+
109+
def __and__(self, other: Set) -> "OrderedSet":
110+
if not isinstance(other, Set):
111+
raise TypeError
112+
113+
return self.intersection(other)
114+
115+
def symmetric_difference(self, other: Iterable[Hashable]) -> "OrderedSet":
116+
"""Return a new set with elements either in this set or other,
117+
but not both.
118+
"""
119+
120+
ret = type(self)()
121+
122+
for element in self:
123+
if element not in other:
124+
ret.add(element)
125+
126+
for element in other:
127+
if element not in self:
128+
ret.add(element)
129+
130+
return ret
131+
132+
def __xor__(self, other: Set) -> "OrderedSet":
133+
if not isinstance(other, Set):
134+
raise TypeError
135+
136+
return self.symmetric_difference(other)

pyadjoint/tape.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Iterable
88
from abc import ABC, abstractmethod
99
from .checkpointing import CheckpointManager, CheckpointError, StorageType
10+
from .ordered_set import OrderedSet
1011

1112
_working_tape = None
1213
_annotation_enabled = False
@@ -788,8 +789,8 @@ def __init__(self, blocks=()):
788789
super().__init__(blocks)
789790
# The set of block variables which are needed to restart from the start
790791
# of this timestep.
791-
self.checkpointable_state = set()
792-
self.adjoint_dependencies = set()
792+
self.checkpointable_state = OrderedSet()
793+
self.adjoint_dependencies = OrderedSet()
793794
# A dictionary mapping the block variables in the checkpointable state
794795
# to their checkpoint values.
795796
self._checkpoint = {}

0 commit comments

Comments
 (0)