Skip to content

Commit 5f46e16

Browse files
authored
Merge pull request #177 from dolfin-adjoint/JHopeCollins/set_working_tape_decorator
Allow using `set_working_tape` as a function decorator
2 parents a8ee848 + fb3b5bd commit 5f46e16

File tree

1 file changed

+38
-19
lines changed

1 file changed

+38
-19
lines changed

pyadjoint/tape.py

+38-19
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import os
33
import re
44
import threading
5-
from contextlib import contextmanager
6-
from functools import wraps
5+
from contextlib import contextmanager, ContextDecorator
76
from itertools import chain
87
from typing import Optional, Iterable
98
from abc import ABC, abstractmethod
@@ -28,9 +27,13 @@ def continue_annotation():
2827
return _annotation_enabled
2928

3029

31-
class set_working_tape(object):
32-
"""A context manager whithin which a new tape is set as the working tape.
33-
This context manager can also be used in an imperative manner.
30+
class set_working_tape(ContextDecorator):
31+
"""Set a new tape as the working tape.
32+
33+
This class can be used in three ways:
34+
1) as a free function to replace the working tape,
35+
2) as a context manager within which a new tape is set as the working tape,
36+
3) as a function decorator so that the new tape is set only inside the function.
3437
3538
Example usage:
3639
@@ -48,6 +51,23 @@ class set_working_tape(object):
4851
4952
with set_working_tape() as tape:
5053
...
54+
55+
3) Set the local tape inside a decorated function.
56+
The two functions below are equivalent:
57+
58+
.. highlight:: python
59+
.. code-block:: python
60+
61+
@set_working_tape()
62+
def decorated_function(*args, **kwargs):
63+
# do something here
64+
return ReducedFunctional(functional, control)
65+
66+
def context_function(*args, **kwargs):
67+
with set_working_tape():
68+
# do something here
69+
return ReducedFunctional(functional, control)
70+
5171
"""
5272

5373
def __init__(self, tape=None, **tape_kwargs):
@@ -68,8 +88,8 @@ def __exit__(self, *args):
6888
_working_tape = self.old_tape
6989

7090

71-
class stop_annotating(object):
72-
"""A context manager within which annotation is stopped.
91+
class stop_annotating(ContextDecorator):
92+
"""A context manager and function decorator within which annotation is stopped.
7393
7494
Args:
7595
modifies (OverloadedType or list[OverloadedType]): One or more
@@ -82,17 +102,23 @@ class stop_annotating(object):
82102
modified variables at the end of the context manager. """
83103

84104
def __init__(self, modifies=None):
85-
global _annotation_enabled
86105
self.modifies = modifies
87-
self._orig_annotation_enabled = _annotation_enabled
106+
# the `no_annotations` context manager could be nested,
107+
# so we need a stack to keep track of the original states.
108+
self._orig_annotation_enabled = []
88109

89110
def __enter__(self):
90111
global _annotation_enabled
112+
if self.modifies and len(self._orig_annotation_enabled) != 0:
113+
raise ValueError(
114+
"Cannot use `modifies` argument if `stop_annotating` is nested,"
115+
" e.g. if used as the `no_annotations` decorator.")
116+
self._orig_annotation_enabled.append(_annotation_enabled)
91117
_annotation_enabled = False
92118

93119
def __exit__(self, *args):
94120
global _annotation_enabled
95-
_annotation_enabled = self._orig_annotation_enabled
121+
_annotation_enabled = self._orig_annotation_enabled.pop()
96122
if self.modifies is not None:
97123
try:
98124
self.modifies.create_block_variable()
@@ -101,15 +127,8 @@ def __exit__(self, *args):
101127
var.create_block_variable()
102128

103129

104-
def no_annotations(function):
105-
"""Decorator to turn off annotation for the decorated function."""
106-
107-
@wraps(function)
108-
def wrapper(*args, **kwargs):
109-
with stop_annotating():
110-
return function(*args, **kwargs)
111-
112-
return wrapper
130+
no_annotations = stop_annotating()
131+
"""Decorator to turn off annotation for the decorated function."""
113132

114133

115134
def annotate_tape(kwargs=None):

0 commit comments

Comments
 (0)