1
+ import weakref
1
2
from .block_variable import BlockVariable
2
3
from .tape import get_working_tape
3
4
@@ -64,6 +65,39 @@ def register_overloaded_type(overloaded_type, classes=None):
64
65
return overloaded_type
65
66
66
67
68
+ class Weakref :
69
+ """Weakref which is picklable if the referenced object is picklable or
70
+ None.
71
+
72
+ Args:
73
+ obj (:obj:`object`): The object to hold a weak reference to. None
74
+ indicates a reference to no object.
75
+ """
76
+
77
+ def __init__ (self , obj = None ):
78
+ self ._init (obj )
79
+
80
+ def _init (self , obj ):
81
+ if obj is None :
82
+ self ._obj = lambda : None
83
+ else :
84
+ self ._obj = weakref .ref (obj )
85
+
86
+ def __call__ (self ):
87
+ return self ._obj ()
88
+
89
+ def __getstate__ (self ):
90
+ state = self .__dict__ .copy ()
91
+ state ["_obj" ] = self ()
92
+ return state
93
+
94
+ def __setstate__ (self , state ):
95
+ state = state .copy ()
96
+ obj = state .pop ("_obj" )
97
+ self .__dict__ .update (state )
98
+ self ._init (obj )
99
+
100
+
67
101
class OverloadedType (object ):
68
102
"""Base class for OverloadedType types.
69
103
@@ -74,8 +108,7 @@ class OverloadedType(object):
74
108
"""
75
109
76
110
def __init__ (self , * args , ** kwargs ):
77
- self .block_variable = None
78
- self .create_block_variable ()
111
+ self .clear_block_variable ()
79
112
80
113
@classmethod
81
114
def _ad_init_object (cls , obj ):
@@ -93,9 +126,21 @@ def _ad_init_object(cls, obj):
93
126
"""
94
127
return cls (obj )
95
128
129
+ @property
130
+ def block_variable (self ):
131
+ block_variable = self ._block_variable ()
132
+ return self .create_block_variable () if block_variable is None else block_variable
133
+
134
+ @block_variable .setter
135
+ def block_variable (self , value ):
136
+ self ._block_variable = Weakref (value )
137
+
138
+ def clear_block_variable (self ):
139
+ self ._block_variable = Weakref ()
140
+
96
141
def create_block_variable (self ):
97
- self .block_variable = BlockVariable (self )
98
- return self . block_variable
142
+ self .block_variable = block_variable = BlockVariable (self )
143
+ return block_variable
99
144
100
145
def _ad_convert_type (self , value , options = {}):
101
146
"""This method must be overridden.
0 commit comments