diff --git a/devito/data/data.py b/devito/data/data.py index 7da3b0eaf8..0a4d671e93 100644 --- a/devito/data/data.py +++ b/devito/data/data.py @@ -208,6 +208,57 @@ def __repr__(self): def __str__(self): return super(Data, self._local).__str__() + def transpose(self, *axes): + """ + Return a view of ``self`` with permuted axes. + + Overridden so that ``_decomposition``, ``_modulo`` (and the convenience + flag ``_is_distributed``) are permuted to match the new axis ordering, + rather than copied verbatim from ``self`` as ``__array_finalize__`` + would otherwise leave them. Without this, a subsequent slice on the + transposed view (e.g. ``f.data.T[::2, ::2]``) is computed against the + wrong per-axis decomposition and silently returns a wrong-shaped + result (see issue #2187). + """ + # Accept the same axis-spec forms as ``numpy.ndarray.transpose``: + # no args, a single ``None``, a single tuple/list, or per-arg. + if len(axes) == 1: + axes = as_tuple(axes[0]) + new_order = ( + tuple(range(self.ndim - 1, -1, -1)) if not axes + else tuple(ax % self.ndim for ax in axes) + ) + + ret = super().transpose(*axes) + ret._decomposition = tuple(self._decomposition[i] for i in new_order) + ret._modulo = tuple(self._modulo[i] for i in new_order) + ret._is_distributed = any(d is not None for d in ret._decomposition) + return ret + + def swapaxes(self, axis1, axis2): + """ + Return a view of ``self`` with ``axis1`` and ``axis2`` swapped, with + ``_decomposition`` / ``_modulo`` swapped in the same way (see + :meth:`transpose`). + """ + axis1 = axis1 % self.ndim + axis2 = axis2 % self.ndim + ret = super().swapaxes(axis1, axis2) + order = list(range(self.ndim)) + order[axis1], order[axis2] = order[axis2], order[axis1] + ret._decomposition = tuple(self._decomposition[i] for i in order) + ret._modulo = tuple(self._modulo[i] for i in order) + ret._is_distributed = any(d is not None for d in ret._decomposition) + return ret + + @property + def T(self): + """ + The transposed array. Overridden so the C-level ``ndarray.T`` shortcut + also permutes the per-axis metadata (see :meth:`transpose`). + """ + return self.transpose() + @_check_idx def __getitem__(self, glb_idx, comm_type, gather_rank=None): loc_idx = self._index_glb_to_loc(glb_idx) diff --git a/tests/test_data.py b/tests/test_data.py index f5d0e8d177..fb09d4ef47 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -211,6 +211,78 @@ def test_indexing_into_sparse(self): sf.data[1:-1, 0] = np.arange(8) assert np.all(sf.data[1:-1, 0] == np.arange(8)) + def test_slice_after_transpose(self): + """ + Slicing a ``Data`` view that has been transposed (via ``.T``, + ``transpose`` or ``swapaxes``) must use the new axis ordering for + per-axis metadata. Previously the metadata was copied verbatim from + the un-transposed array, so a subsequent slice was computed against + the wrong decomposition and silently returned a wrong-shaped result + (see issue #2187). + """ + grid = Grid(shape=(4, 6)) + f = Function(name='f', grid=grid) + f.data[:] = np.arange(24).reshape((4, 6)).astype(np.float32) + ref = np.array(f.data) + + # ``.T`` (C-level shortcut) then slice + assert np.array_equal(f.data.T[::2, ::2], ref.T[::2, ::2]) + + # Equivalent: slice then ``.T`` + assert np.array_equal(f.data[::2, ::2].T, ref[::2, ::2].T) + + # Explicit ``transpose`` call -- same behavior as ``.T`` + assert np.array_equal(f.data.transpose()[::2, ::2], + ref.transpose()[::2, ::2]) + + # ``swapaxes`` between non-conforming dims + assert np.array_equal(f.data.swapaxes(0, 1)[::2, ::2], + ref.swapaxes(0, 1)[::2, ::2]) + + # 3D transpose with an explicit axis order, then per-axis slice + grid3 = Grid(shape=(2, 4, 6)) + g = Function(name='g3', grid=grid3) + g.data[:] = np.arange(48).reshape((2, 4, 6)).astype(np.float32) + ref3 = np.array(g.data) + + assert np.array_equal(g.data.T[::2, ::2, ::2], ref3.T[::2, ::2, ::2]) + assert np.array_equal(g.data.transpose((1, 0, 2))[::2, ::1, ::3], + ref3.transpose((1, 0, 2))[::2, ::1, ::3]) + + def test_transpose_permutes_data_metadata(self): + """ + After a transpose-like operation, ``_decomposition`` and ``_modulo`` + must be permuted to match the new axis order so that subsequent + ``__getitem__`` translations use the right per-axis ranges. + """ + grid = Grid(shape=(4, 6)) + f = Function(name='f', grid=grid) + + original_decomp = f.data._decomposition + assert len(original_decomp) == 2 + + # ``.T`` reverses everything + tdata = f.data.T + assert tdata._decomposition == original_decomp[::-1] + assert tdata._modulo == f.data._modulo[::-1] + + # ``transpose()`` with no args == ``.T`` + tdata2 = f.data.transpose() + assert tdata2._decomposition == original_decomp[::-1] + + # ``swapaxes`` swaps the two named axes + sdata = f.data.swapaxes(0, 1) + assert sdata._decomposition == (original_decomp[1], original_decomp[0]) + + # Explicit axis-order + grid3 = Grid(shape=(2, 4, 6)) + g = Function(name='g3', grid=grid3) + gdec = g.data._decomposition + perm = g.data.transpose((1, 2, 0)) + assert perm._decomposition == (gdec[1], gdec[2], gdec[0]) + assert perm._modulo == (g.data._modulo[1], g.data._modulo[2], + g.data._modulo[0]) + @pytest.mark.parallel(mode=1) def test_indexing_into_sparse_subfunc_singlempi(self, mode): grid = Grid(shape=(4, 4))