Skip to content

Commit da1f189

Browse files
authored
Break OvertypedType -> BlockVariable -> OverloadedType reference cycle (#194)
* Break OvertypedType -> BlockVariable -> OverloadedType reference cycle * Add OverloadedType.clear_block_variable * OverloadedType pickle fix, required by Firedrake EnsembleReducedFunctional
1 parent 2c98d4b commit da1f189

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

pyadjoint/overloaded_type.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import weakref
12
from .block_variable import BlockVariable
23
from .tape import get_working_tape
34

@@ -64,6 +65,39 @@ def register_overloaded_type(overloaded_type, classes=None):
6465
return overloaded_type
6566

6667

68+
class Weakref:
69+
"""Weakref which is picklable if the referenced object is picklable or
70+
None.
71+
72+
Args:
73+
obj (:obj:`object`): The object to hold a weak reference to. None
74+
indicates a reference to no object.
75+
"""
76+
77+
def __init__(self, obj=None):
78+
self._init(obj)
79+
80+
def _init(self, obj):
81+
if obj is None:
82+
self._obj = lambda: None
83+
else:
84+
self._obj = weakref.ref(obj)
85+
86+
def __call__(self):
87+
return self._obj()
88+
89+
def __getstate__(self):
90+
state = self.__dict__.copy()
91+
state["_obj"] = self()
92+
return state
93+
94+
def __setstate__(self, state):
95+
state = state.copy()
96+
obj = state.pop("_obj")
97+
self.__dict__.update(state)
98+
self._init(obj)
99+
100+
67101
class OverloadedType(object):
68102
"""Base class for OverloadedType types.
69103
@@ -74,8 +108,7 @@ class OverloadedType(object):
74108
"""
75109

76110
def __init__(self, *args, **kwargs):
77-
self.block_variable = None
78-
self.create_block_variable()
111+
self.clear_block_variable()
79112

80113
@classmethod
81114
def _ad_init_object(cls, obj):
@@ -93,9 +126,21 @@ def _ad_init_object(cls, obj):
93126
"""
94127
return cls(obj)
95128

129+
@property
130+
def block_variable(self):
131+
block_variable = self._block_variable()
132+
return self.create_block_variable() if block_variable is None else block_variable
133+
134+
@block_variable.setter
135+
def block_variable(self, value):
136+
self._block_variable = Weakref(value)
137+
138+
def clear_block_variable(self):
139+
self._block_variable = Weakref()
140+
96141
def create_block_variable(self):
97-
self.block_variable = BlockVariable(self)
98-
return self.block_variable
142+
self.block_variable = block_variable = BlockVariable(self)
143+
return block_variable
99144

100145
def _ad_convert_type(self, value, options={}):
101146
"""This method must be overridden.

pyadjoint/tape.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def __exit__(self, *args):
122122
_annotation_enabled = self._orig_annotation_enabled.pop()
123123
if self.modifies is not None:
124124
try:
125-
self.modifies.create_block_variable()
125+
self.modifies.clear_block_variable()
126126
except AttributeError:
127127
for var in self.modifies:
128-
var.create_block_variable()
128+
var.clear_block_variable()
129129

130130

131131
no_annotations = stop_annotating()

0 commit comments

Comments
 (0)