Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,57 @@ def __repr__(self):
def __str__(self):
return super(Data, self._local).__str__()

def transpose(self, *axes):
"""
Return a view of ``self`` with permuted axes.

Overridden so that ``_decomposition``, ``_modulo`` (and the convenience
flag ``_is_distributed``) are permuted to match the new axis ordering,
rather than copied verbatim from ``self`` as ``__array_finalize__``
would otherwise leave them. Without this, a subsequent slice on the
transposed view (e.g. ``f.data.T[::2, ::2]``) is computed against the
wrong per-axis decomposition and silently returns a wrong-shaped
result (see issue #2187).
"""
# Accept the same axis-spec forms as ``numpy.ndarray.transpose``:
# no args, a single ``None``, a single tuple/list, or per-arg.
if len(axes) == 1:
axes = as_tuple(axes[0])
new_order = (
tuple(range(self.ndim - 1, -1, -1)) if not axes
else tuple(ax % self.ndim for ax in axes)
)

ret = super().transpose(*axes)
ret._decomposition = tuple(self._decomposition[i] for i in new_order)
ret._modulo = tuple(self._modulo[i] for i in new_order)
ret._is_distributed = any(d is not None for d in ret._decomposition)
return ret

def swapaxes(self, axis1, axis2):
"""
Return a view of ``self`` with ``axis1`` and ``axis2`` swapped, with
``_decomposition`` / ``_modulo`` swapped in the same way (see
:meth:`transpose`).
"""
axis1 = axis1 % self.ndim
axis2 = axis2 % self.ndim
ret = super().swapaxes(axis1, axis2)
order = list(range(self.ndim))
order[axis1], order[axis2] = order[axis2], order[axis1]
ret._decomposition = tuple(self._decomposition[i] for i in order)
ret._modulo = tuple(self._modulo[i] for i in order)
ret._is_distributed = any(d is not None for d in ret._decomposition)
return ret

@property
def T(self):
"""
The transposed array. Overridden so the C-level ``ndarray.T`` shortcut
also permutes the per-axis metadata (see :meth:`transpose`).
"""
return self.transpose()

@_check_idx
def __getitem__(self, glb_idx, comm_type, gather_rank=None):
loc_idx = self._index_glb_to_loc(glb_idx)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,78 @@ def test_indexing_into_sparse(self):
sf.data[1:-1, 0] = np.arange(8)
assert np.all(sf.data[1:-1, 0] == np.arange(8))

def test_slice_after_transpose(self):
"""
Slicing a ``Data`` view that has been transposed (via ``.T``,
``transpose`` or ``swapaxes``) must use the new axis ordering for
per-axis metadata. Previously the metadata was copied verbatim from
the un-transposed array, so a subsequent slice was computed against
the wrong decomposition and silently returned a wrong-shaped result
(see issue #2187).
"""
grid = Grid(shape=(4, 6))
f = Function(name='f', grid=grid)
f.data[:] = np.arange(24).reshape((4, 6)).astype(np.float32)
ref = np.array(f.data)

# ``.T`` (C-level shortcut) then slice
assert np.array_equal(f.data.T[::2, ::2], ref.T[::2, ::2])

# Equivalent: slice then ``.T``
assert np.array_equal(f.data[::2, ::2].T, ref[::2, ::2].T)

# Explicit ``transpose`` call -- same behavior as ``.T``
assert np.array_equal(f.data.transpose()[::2, ::2],
ref.transpose()[::2, ::2])

# ``swapaxes`` between non-conforming dims
assert np.array_equal(f.data.swapaxes(0, 1)[::2, ::2],
ref.swapaxes(0, 1)[::2, ::2])

# 3D transpose with an explicit axis order, then per-axis slice
grid3 = Grid(shape=(2, 4, 6))
g = Function(name='g3', grid=grid3)
g.data[:] = np.arange(48).reshape((2, 4, 6)).astype(np.float32)
ref3 = np.array(g.data)

assert np.array_equal(g.data.T[::2, ::2, ::2], ref3.T[::2, ::2, ::2])
assert np.array_equal(g.data.transpose((1, 0, 2))[::2, ::1, ::3],
ref3.transpose((1, 0, 2))[::2, ::1, ::3])

def test_transpose_permutes_data_metadata(self):
"""
After a transpose-like operation, ``_decomposition`` and ``_modulo``
must be permuted to match the new axis order so that subsequent
``__getitem__`` translations use the right per-axis ranges.
"""
grid = Grid(shape=(4, 6))
f = Function(name='f', grid=grid)

original_decomp = f.data._decomposition
assert len(original_decomp) == 2

# ``.T`` reverses everything
tdata = f.data.T
assert tdata._decomposition == original_decomp[::-1]
assert tdata._modulo == f.data._modulo[::-1]

# ``transpose()`` with no args == ``.T``
tdata2 = f.data.transpose()
assert tdata2._decomposition == original_decomp[::-1]

# ``swapaxes`` swaps the two named axes
sdata = f.data.swapaxes(0, 1)
assert sdata._decomposition == (original_decomp[1], original_decomp[0])

# Explicit axis-order
grid3 = Grid(shape=(2, 4, 6))
g = Function(name='g3', grid=grid3)
gdec = g.data._decomposition
perm = g.data.transpose((1, 2, 0))
assert perm._decomposition == (gdec[1], gdec[2], gdec[0])
assert perm._modulo == (g.data._modulo[1], g.data._modulo[2],
g.data._modulo[0])

@pytest.mark.parallel(mode=1)
def test_indexing_into_sparse_subfunc_singlempi(self, mode):
grid = Grid(shape=(4, 4))
Expand Down
Loading