|
| 1 | +from functools import wraps |
1 | 2 | from .block import Block
|
2 | 3 | from .overloaded_type import OverloadedType, register_overloaded_type, create_overloaded_object
|
3 | 4 | from .tape import get_working_tape, annotate_tape, stop_annotating
|
| 5 | +import math |
4 | 6 |
|
5 | 7 |
|
6 | 8 | def annotate_operator(operator):
|
@@ -129,6 +131,98 @@ def _ad_str(self):
|
129 | 131 | return str(self.block_variable.saved_output)
|
130 | 132 |
|
131 | 133 |
|
| 134 | +_exp = math.exp |
| 135 | +_log = math.log |
| 136 | + |
| 137 | + |
| 138 | +@wraps(_exp) |
| 139 | +def exp(a, **kwargs): |
| 140 | + annotate = annotate_tape(kwargs) |
| 141 | + if annotate: |
| 142 | + a = create_overloaded_object(a) |
| 143 | + |
| 144 | + block = ExpBlock(a) |
| 145 | + tape = get_working_tape() |
| 146 | + tape.add_block(block) |
| 147 | + |
| 148 | + with stop_annotating(): |
| 149 | + out = _exp(a) |
| 150 | + out = AdjFloat(out) |
| 151 | + |
| 152 | + if annotate: |
| 153 | + block.add_output(out.block_variable) |
| 154 | + return out |
| 155 | + |
| 156 | + |
| 157 | +def log(a, **kwargs): |
| 158 | + """Return the natural logarithm of a.""" |
| 159 | + annotate = annotate_tape(kwargs) |
| 160 | + if annotate: |
| 161 | + a = create_overloaded_object(a) |
| 162 | + |
| 163 | + block = LogBlock(a) |
| 164 | + tape = get_working_tape() |
| 165 | + tape.add_block(block) |
| 166 | + |
| 167 | + with stop_annotating(): |
| 168 | + out = _log(a) |
| 169 | + out = AdjFloat(out) |
| 170 | + |
| 171 | + if annotate: |
| 172 | + block.add_output(out.block_variable) |
| 173 | + return out |
| 174 | + |
| 175 | + |
| 176 | +class ExpBlock(Block): |
| 177 | + def __init__(self, a): |
| 178 | + super().__init__() |
| 179 | + self.add_dependency(a) |
| 180 | + |
| 181 | + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): |
| 182 | + adj_input = adj_inputs[0] |
| 183 | + input0 = inputs[0] |
| 184 | + return _exp(input0) * adj_input |
| 185 | + |
| 186 | + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): |
| 187 | + tlm_input = tlm_inputs[0] |
| 188 | + input0 = inputs[0] |
| 189 | + return _exp(input0) * tlm_input |
| 190 | + |
| 191 | + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, |
| 192 | + relevant_dependencies, prepared=None): |
| 193 | + input0 = inputs[0] |
| 194 | + hessian = hessian_inputs[0] |
| 195 | + return _exp(input0) * hessian |
| 196 | + |
| 197 | + def recompute_component(self, inputs, block_variable, idx, prepared): |
| 198 | + return _exp(inputs[0]) |
| 199 | + |
| 200 | + |
| 201 | +class LogBlock(Block): |
| 202 | + def __init__(self, a): |
| 203 | + super().__init__() |
| 204 | + self.add_dependency(a) |
| 205 | + |
| 206 | + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): |
| 207 | + adj_input = adj_inputs[0] |
| 208 | + input0 = inputs[0] |
| 209 | + return adj_input / input0 |
| 210 | + |
| 211 | + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): |
| 212 | + tlm_input = tlm_inputs[0] |
| 213 | + input0 = inputs[0] |
| 214 | + return tlm_input / input0 |
| 215 | + |
| 216 | + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, |
| 217 | + relevant_dependencies, prepared=None): |
| 218 | + input0 = inputs[0] |
| 219 | + hessian = hessian_inputs[0] |
| 220 | + return -hessian / input0 / input0 |
| 221 | + |
| 222 | + def recompute_component(self, inputs, block_variable, idx, prepared): |
| 223 | + return _log(inputs[0]) |
| 224 | + |
| 225 | + |
132 | 226 | _min = min
|
133 | 227 | _max = max
|
134 | 228 |
|
|
0 commit comments