
Commit 43b4032

Merge pull request #152 from dolfin-adjoint/explog
Log and exp for adjfloats
2 parents 34511f9 + 2570662
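
This merge adds tape-aware exp and log for AdjFloat: each call records a block on the working tape, so adjoint (derivative), tangent-linear, and Hessian sweeps can propagate through it. A minimal usage sketch (the values here are illustrative, not part of the commit):

    from pyadjoint import AdjFloat, Control, ReducedFunctional, exp, log

    a = AdjFloat(2.0)             # overloaded float; operations on it are taped
    J = exp(log(a) * 3.0)         # J = a**3, recorded on the working tape
    rf = ReducedFunctional(J, Control(a))
    print(rf(AdjFloat(2.0)))      # 8.0
    print(rf.derivative())        # dJ/da = 3 * a**2 = 12.0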

File tree

3 files changed: +125, -3 lines changed

pyadjoint/__init__.py (+1, -1)

@@ -11,7 +11,7 @@
 from .tape import (Tape,
                    set_working_tape, get_working_tape, no_annotations,
                    annotate_tape, stop_annotating, pause_annotation, continue_annotation)
-from .adjfloat import AdjFloat
+from .adjfloat import AdjFloat, exp, log
 from .reduced_functional import ReducedFunctional
 from .drivers import compute_gradient, compute_hessian, solve_adjoint
 from .verification import taylor_test, taylor_to_dict
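
With exp and log re-exported from the package root, downstream code can import them alongside AdjFloat, and the star import in the test file below picks them up automatically:

    from pyadjoint import AdjFloat, exp, log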

pyadjoint/adjfloat.py (+94)

@@ -1,6 +1,8 @@
+from functools import wraps
 from .block import Block
 from .overloaded_type import OverloadedType, register_overloaded_type, create_overloaded_object
 from .tape import get_working_tape, annotate_tape, stop_annotating
+import math


 def annotate_operator(operator):
@@ -129,6 +131,98 @@ def _ad_str(self):
         return str(self.block_variable.saved_output)


+_exp = math.exp
+_log = math.log
+
+
+@wraps(_exp)
+def exp(a, **kwargs):
+    annotate = annotate_tape(kwargs)
+    if annotate:
+        a = create_overloaded_object(a)
+
+        block = ExpBlock(a)
+        tape = get_working_tape()
+        tape.add_block(block)
+
+    with stop_annotating():
+        out = _exp(a)
+    out = AdjFloat(out)
+
+    if annotate:
+        block.add_output(out.block_variable)
+    return out
+
+
+def log(a, **kwargs):
+    """Return the natural logarithm of a."""
+    annotate = annotate_tape(kwargs)
+    if annotate:
+        a = create_overloaded_object(a)
+
+        block = LogBlock(a)
+        tape = get_working_tape()
+        tape.add_block(block)
+
+    with stop_annotating():
+        out = _log(a)
+    out = AdjFloat(out)
+
+    if annotate:
+        block.add_output(out.block_variable)
+    return out
+
+
+class ExpBlock(Block):
+    def __init__(self, a):
+        super().__init__()
+        self.add_dependency(a)
+
+    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
+        adj_input = adj_inputs[0]
+        input0 = inputs[0]
+        return _exp(input0) * adj_input
+
+    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
+        tlm_input = tlm_inputs[0]
+        input0 = inputs[0]
+        return _exp(input0) * tlm_input
+
+    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
+                                   relevant_dependencies, prepared=None):
+        input0 = inputs[0]
+        hessian = hessian_inputs[0]
+        return _exp(input0) * hessian
+
+    def recompute_component(self, inputs, block_variable, idx, prepared):
+        return _exp(inputs[0])
+
+
+class LogBlock(Block):
+    def __init__(self, a):
+        super().__init__()
+        self.add_dependency(a)
+
+    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
+        adj_input = adj_inputs[0]
+        input0 = inputs[0]
+        return adj_input / input0
+
+    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
+        tlm_input = tlm_inputs[0]
+        input0 = inputs[0]
+        return tlm_input / input0
+
+    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
+                                   relevant_dependencies, prepared=None):
+        input0 = inputs[0]
+        hessian = hessian_inputs[0]
+        return -hessian / input0 / input0
+
+    def recompute_component(self, inputs, block_variable, idx, prepared):
+        return _log(inputs[0])
+
+
 _min = min
 _max = max
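
Both blocks encode the chain rule directly. For y = exp(x), the first and second derivatives are both exp(x), which is why the adjoint, tangent-linear, and Hessian actions above all multiply by _exp(input0); for y = log(x), the first derivative is 1/x (the adj_input / input0 and tlm_input / input0 terms) and the second derivative is -1/x**2 (the -hessian / input0 / input0 term). A quick sanity check of the taping, sketched with the Tape API exported in __init__.py above:

    from pyadjoint import AdjFloat, Tape, set_working_tape, get_working_tape, exp, log

    set_working_tape(Tape())      # start from a clean tape
    b = exp(AdjFloat(1.0))
    c = log(b)
    blocks = get_working_tape().get_blocks()
    print([type(blk).__name__ for blk in blocks])   # ['ExpBlock', 'LogBlock']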

tests/pyadjoint/test_floats.py (+30, -2)

@@ -1,5 +1,5 @@
 import pytest
-from math import log
+import math
 from numpy.testing import assert_approx_equal
 from numpy.random import rand
 from pyadjoint import *
@@ -155,6 +155,34 @@ def test_float_neg():
     assert rf2.derivative() == - 2.0


+def test_float_logexp():
+    a = AdjFloat(3.0)
+    b = exp(a)
+    c = log(b)
+    assert_approx_equal(c, 3.0)
+
+    b = log(a)
+    c = exp(b)
+    assert_approx_equal(c, 3.0)
+
+    rf = ReducedFunctional(c, Control(a))
+    assert_approx_equal(rf(a), 3.0)
+    assert_approx_equal(rf(AdjFloat(1.0)), 1.0)
+    assert_approx_equal(rf(AdjFloat(9.0)), 9.0)
+
+    assert_approx_equal(rf.derivative(), 1.0)
+
+    a = AdjFloat(3.0)
+    b = exp(a)
+    rf = ReducedFunctional(b, Control(a))
+    assert_approx_equal(rf.derivative(), math.exp(3.0))
+
+    a = AdjFloat(2.0)
+    b = log(a)
+    rf = ReducedFunctional(b, Control(a))
+    assert_approx_equal(rf.derivative(), 1./2.)
+
+
 def test_float_exponentiation():
     a = AdjFloat(3.0)
     b = AdjFloat(2.0)
@@ -172,7 +200,7 @@ def test_float_exponentiation():
     assert rf(AdjFloat(1.0)) == 1.0
     assert rf(AdjFloat(2.0)) == 4.0
     # d(a**a)/da = dexp(a log(a))/da = a**a * (log(a) + 1)
-    assert_approx_equal(rf.derivative(), 4.0 * (log(2.0)+1.0))
+    assert_approx_equal(rf.derivative(), 4.0 * (math.log(2.0)+1.0))

     # TODO: __rpow__ is not yet implemented
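
The derivative checks here follow from the formulas in the blocks: exp(log(a)) is the identity, so its derivative is 1.0; d(exp(a))/da = exp(a); and d(log(a))/da = 1/a. Beyond spot-checking values, the Taylor remainder test exported in __init__.py above can verify the taped gradient; a sketch, assuming taylor_test(rf, m, h) returns the observed convergence rate (approximately 2.0 when the first derivative is consistent):

    from pyadjoint import AdjFloat, Control, ReducedFunctional, taylor_test, log

    a = AdjFloat(2.0)
    rf = ReducedFunctional(log(a), Control(a))
    rate = taylor_test(rf, AdjFloat(2.0), AdjFloat(0.01))   # h is the perturbation direction
    assert rate > 1.9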
