예제 #1
0
파일: test_ops.py 프로젝트: gyc567/devito
    def test_create_ops_arg_function(self, read):
        """Check that create_ops_arg builds an ops_arg_dat call for a single
        accessible, passing the dat, one element per point, the stencil, the
        C type string of the dtype and the read/write flag selected by
        ``read``."""
        accessible = OpsAccessible('u', np.float32, read)
        u_dat = OpsDat('u_dat')
        u_stencil = OpsStencil('stencil')
        accessible_info = AccessibleInfo(accessible, None, None)

        result = create_ops_arg(accessible, {'u': accessible_info},
                                {'u': u_dat}, {accessible: u_stencil})

        assert type(result) == namespace['ops_arg_dat']

        # Expected trailing arguments: the quoted C type name and the access flag.
        expected_ctype = Literal('"%s"' % dtype_to_cstr(accessible.dtype))
        if read:
            expected_flag = namespace['ops_read']
        else:
            expected_flag = namespace['ops_write']

        assert result.args == (u_dat, 1, u_stencil, expected_ctype, expected_flag)
예제 #2
0
    def test_create_ops_arg_function(self, read):
        """Check every field of the OpsArg built by create_ops_arg for a single
        float32 accessible: its OPS call type, the dat it refers to, the
        number of elements per point, the quoted C dtype and the read/write
        flag selected by ``read``."""

        u = OpsAccessible('u', dtype=np.float32, read_only=read)
        dat = OpsDat('u_dat')
        stencil = OpsStencil('stencil')
        info = AccessibleInfo(u, None, None)

        ops_arg = create_ops_arg(u, {'u': info}, {'u': dat}, {u: stencil})

        assert ops_arg.ops_type == namespace['ops_arg_dat']
        assert ops_arg.ops_name == OpsDat('u_dat')
        assert ops_arg.elements_per_point == 1
        assert ops_arg.dtype == Literal('"%s"' % dtype_to_cstr(u.dtype))

        # Bind the expected flag first. Written inline as
        # ``assert x == a if read else b``, the conditional would bind looser
        # than ``==``. The write case would then only assert that ``b`` is
        # truthy and never compare rw_flag at all.
        expected_rw_flag = (namespace['ops_read'] if read
                            else namespace['ops_write'])
        assert ops_arg.rw_flag == expected_rw_flag
예제 #3
0
    def new_ops_arg(self, indexed, is_write):
        """
        Create an OpsAccess node representing ``indexed`` in OPS form.

        The first time a given array is seen, a new OpsAccessible is created
        and registered in ``self.ops_args`` and ``self.ops_params``. For a
        TimeFunction, "a given array" means a given time slot of it. Later
        accesses with the same identifier reuse that accessible. Each distinct
        list of space-dimension shifts is recorded once in
        ``self.ops_args_accesses`` for that accessible.

        Parameters
        ----------
        indexed : Indexed
            Indexed object using devito representation.
        is_write : bool
            Whether this access writes to ``indexed``. It is only used when
            the accessible is first created, to set ``read_only``.

        Returns
        -------
        OpsAccess
            The OPS accessible symbol paired with the list of shifts of the
            indices that run over SpaceDimensions.
        """
        is_time_function = indexed.function.is_TimeFunction

        # Build the OPS arg identifier. For a TimeFunction, the time variable
        # and shift are appended, so different time slots (e.g. t and t + 1)
        # become distinct OPS arguments. The time index is only split when it
        # is actually needed.
        if is_time_function:
            time_dim_index = indexed.indices[TimeFunction._time_position]
            time_index = split_affine(time_dim_index)
            time_var, time_shift = time_index.var, time_index.shift
            ops_arg_id = '%s%s%s' % (indexed.name, time_var, time_shift)
        else:
            time_var = time_shift = None
            ops_arg_id = indexed.name

        if ops_arg_id not in self.ops_args:
            # NOTE(review): read_only is fixed by the first access seen. A later
            # write to the same identifier does not clear it; confirm intended.
            symbol_to_access = OpsAccessible(
                ops_arg_id,
                dtype=indexed.dtype,
                read_only=not is_write
            )

            accessible_info = AccessibleInfo(
                symbol_to_access,
                time_var,
                time_shift,
                indexed.function.name)

            self.ops_args[ops_arg_id] = accessible_info
            self.ops_params.append(symbol_to_access)
        else:
            symbol_to_access = self.ops_args[ops_arg_id].accessible

        # Keep the shift of every index whose variable is a SpaceDimension,
        # splitting each index only once.
        space_indices = [
            affine.shift for affine in map(split_affine, indexed.indices)
            if isinstance(affine.var, SpaceDimension)
        ]

        # Presumably ops_args_accesses yields an empty list for unseen
        # accessibles (e.g. a defaultdict(list)); verify against __init__.
        accesses = self.ops_args_accesses[symbol_to_access]
        if space_indices not in accesses:
            accesses.append(space_indices)

        return OpsAccess(symbol_to_access, space_indices)