Skip to content

Commit c080573

Browse files
jorgensdmscroggs
andauthored
Fix number of constants. (#665)
* Fix number of constants. Resolves: FEniCS/dolfinx#3040 * Add test for constant renumbering (fails onmain) * Apply suggestions from code review Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com> --------- Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>
1 parent e0a61fd commit c080573

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

ffcx/ir/representation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def _compute_expression_ir(expression, index, prefix, analysis, options, visuali
680680
# Build offsets for Constants
681681
original_constant_offsets = {}
682682
_offset = 0
683-
for constant in ufl.algorithms.analysis.extract_constants(expression):
683+
for constant in ufl.algorithms.analysis.extract_constants(original_expression):
684684
original_constant_offsets[constant] = _offset
685685
_offset += np.prod(constant.ufl_shape, dtype=int)
686686

test/test_jit_expression.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# -*- coding: utf-8 -*-
2-
# Copyright (C) 2019-2022 Michal Habera and Jørgen S. Dokken
2+
# Copyright (C) 2019-2024 Michal Habera and Jørgen S. Dokken
33
#
44
# This file is part of FFCx.(https://www.fenicsproject.org)
55
#
66
# SPDX-License-Identifier: LGPL-3.0-or-later
77

88
import cffi
99
import numpy as np
10+
import pytest
1011

1112
import basix
1213
import basix.ufl
@@ -208,3 +209,45 @@ def exact_expr(x):
208209
exact = exact_expr(points.T)
209210

210211
assert np.allclose(exact, output)
212+
213+
214+
def test_grad_constant(compile_args):
215+
"""Test if numbering of constants are correct after UFL eliminates the constant inside the gradient."""
216+
c_el = basix.ufl.element("Lagrange", "triangle", 1, shape=(2, ))
217+
mesh = ufl.Mesh(c_el)
218+
219+
x = ufl.SpatialCoordinate(mesh)
220+
first_constant = ufl.Constant(mesh)
221+
second_constant = ufl.Constant(mesh)
222+
expr = second_constant * ufl.Dx(x[0]**2 + first_constant, 0)
223+
224+
dtype = np.float64
225+
points = np.array([[0.33, 0.25]], dtype=dtype)
226+
227+
obj, _, _ = ffcx.codegeneration.jit.compile_expressions(
228+
[(expr, points)], cffi_extra_compile_args=compile_args)
229+
230+
ffi = cffi.FFI()
231+
expression = obj[0]
232+
233+
c_type = "double"
234+
c_xtype = "double"
235+
236+
output = np.zeros(1, dtype=dtype)
237+
238+
# Define constants
239+
coords = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=dtype)
240+
u_coeffs = np.array([], dtype=dtype)
241+
consts = np.array([3, 7], dtype=dtype)
242+
entity_index = np.array([0], dtype=np.intc)
243+
quad_perm = np.array([0], dtype=np.dtype("uint8"))
244+
245+
expression.tabulate_tensor_float64(
246+
ffi.cast(f'{c_type} *', output.ctypes.data),
247+
ffi.cast(f'{c_type} *', u_coeffs.ctypes.data),
248+
ffi.cast(f'{c_type} *', consts.ctypes.data),
249+
ffi.cast(f'{c_xtype} *', coords.ctypes.data),
250+
ffi.cast('int *', entity_index.ctypes.data),
251+
ffi.cast('uint8_t *', quad_perm.ctypes.data))
252+
253+
assert output[0] == pytest.approx(consts[1] * 2 * points[0, 0])

0 commit comments

Comments
 (0)