예제 #1
0
 def __init__(self, *args, **kwargs):
     """Build a struct from a single dict positional argument or from kwargs.

     Member values are normalized: lists/tuples become ``Matrix`` and dicts
     become ``Struct``. Outside Python scope each member is additionally
     wrapped with ``impl.expr_init``.

     A new entries dict is always built, so a dict passed by the caller is
     never modified.

     Raises:
         TaichiSyntaxError: if arguments are neither a single dict nor
             keyword arguments only.
     """
     if len(args) == 1 and not kwargs and isinstance(args[0], dict):
         source = args[0]
     elif len(args) == 0:
         source = kwargs
     else:
         raise TaichiSyntaxError(
             "Custom structs need to be initialized using either dictionary or keyword arguments"
         )
     python_scope = in_python_scope()
     entries = {}
     # converts lists to matrices and dicts to structs
     for k, v in source.items():
         if isinstance(v, (list, tuple)):
             v = Matrix(v)
         elif isinstance(v, dict):
             v = Struct(v)
         entries[k] = v if python_scope else impl.expr_init(v)
     self.entries = entries
     self.register_members()
     self.in_python_scope = python_scope
예제 #2
0
 def fill(self, val):
     """Set every element of this scalar field to ``val``.

     In Taichi scope this calls ``taichi._funcs.field_fill_taichi_scope``;
     in Python scope it calls ``taichi._kernels.fill_tensor``. Both helpers
     are imported at call time rather than at module level.
     """
     if not in_python_scope():
         from taichi._funcs import field_fill_taichi_scope  # pylint: disable=C0415
         field_fill_taichi_scope(self, val)
         return
     from taichi._kernels import fill_tensor  # pylint: disable=C0415
     fill_tensor(self, val)
예제 #3
0
파일: struct.py 프로젝트: Leonz5288/taichi
 def cast(self, struct):
     """Return a new ``Struct`` whose members are converted to this type.

     Compound members are delegated to their own ``cast``. Scalar members
     are converted with ``ops.cast`` in Taichi scope; in Python scope they
     become ``int`` when the dtype is in ``integer_types`` and ``float``
     otherwise.

     Raises:
         TaichiSyntaxError: if ``struct`` does not have exactly this type's
             member names.
     """
     # sanity check members
     if self.members.keys() != struct.entries.keys():
         raise TaichiSyntaxError(
             "Incompatible arguments for custom struct members!")
     converted = {}
     for name, member_type in self.members.items():
         value = struct.entries[name]
         if isinstance(member_type, CompoundType):
             converted[name] = member_type.cast(value)
         elif not in_python_scope():
             converted[name] = ops.cast(value, member_type)
         elif member_type in integer_types:
             converted[name] = int(value)
         else:
             converted[name] = float(value)
     return Struct(converted)
예제 #4
0
파일: struct.py 프로젝트: kokizzu/taichi
 def __setitem__(self, key, value):
     """Assign ``value`` to the struct member ``key``.

     * ``SNodeHostAccess`` members are written through their accessor's
       ``setter`` using the member's stored ``key``.
     * In Taichi scope the member is simply rebound to ``value``.
     * In Python scope, ``Struct``/``Matrix`` members are updated in place
       via ``_set_entries``; any other member only accepts a
       ``numbers.Number``.

     Raises:
         TypeError: in Python scope, when assigning a non-number to a
             scalar member.
     """
     member = self.entries[key]
     if isinstance(member, SNodeHostAccess):
         member.accessor.setter(value, *member.key)
         return
     if not in_python_scope():
         self.entries[key] = value
         return
     if isinstance(member, (Struct, Matrix)):
         member._set_entries(value)
         return
     if not isinstance(value, numbers.Number):
         raise TypeError(
             "A number is expected when assigning struct members"
         )
     self.entries[key] = value
예제 #5
0
 def cast(self, struct, in_place=False):
     """Convert the members of ``struct`` to this type's member dtypes.

     Args:
         struct: the struct to convert.
         in_place: when False (default), ``struct.copy()`` is converted and
             returned; when True, ``struct`` itself is modified.

     Returns:
         The converted struct (``struct`` itself when ``in_place`` is True).

     Raises:
         TaichiSyntaxError: if ``struct`` does not have exactly this type's
             member names.
     """
     target = struct if in_place else struct.copy()
     # sanity check members
     if self.members.keys() != target.entries.keys():
         raise TaichiSyntaxError(
             "Incompatible arguments for custom struct members!")
     for name, member_type in self.members.items():
         value = target.entries[name]
         if isinstance(member_type, CompoundType):
             target.entries[name] = member_type.cast(value)
         elif not in_python_scope():
             # module-level ``cast`` helper (not this method)
             target.entries[name] = cast(value, member_type)
         elif member_type in ti.integer_types:
             target.entries[name] = int(value)
         else:
             target.entries[name] = float(value)
     return target
예제 #6
0
    def __setattr__(self, attr_name, values):
        """Handle multi-component attribute assignment.

        When ``attr_name`` has more than one character and every character
        belongs to one of ``_VectorType._KEYMAP_SET``'s key groups
        (presumably component-letter groups such as ``"xyzw"`` — confirm),
        each character is mapped to its index in that group and the
        corresponding element of ``values`` is stored there. Any other
        attribute falls through to normal ``object`` attribute assignment.

        Raises:
            Exception: if the number of values differs from the number of
                components named.
        """
        if len(attr_name) > 1:
            for key_group in _VectorType._KEYMAP_SET:
                if not all(ch in key_group for ch in attr_name):
                    continue

                if len(attr_name) != len(values):
                    raise Exception("values does not match the attribute")

                # len(attr_name) > 1 and lengths match, so at least two
                # components are always written before returning.
                for ch, value in zip(attr_name, values):
                    if in_python_scope():
                        self[key_group.index(ch)] = value
                    else:
                        self(key_group.index(ch))._assign(value)
                return

        super().__setattr__(attr_name, values)
예제 #7
0
 def __init__(self, entries):
     """Create a struct from a dict of member name -> value.

     ``entries`` is stored as-is (not copied). The scope in which the struct
     was created is recorded in ``self.in_python_scope``.

     Raises:
         TypeError: if ``entries`` is not a dict. (An ``assert`` was used
             previously, which is skipped entirely under ``python -O``.)
     """
     if not isinstance(entries, dict):
         raise TypeError(
             f"Struct entries must be a dict, got {type(entries).__name__}")
     self.entries = entries
     self.register_members()
     self.in_python_scope = in_python_scope()
예제 #8
0
 def __ipow__(self, other):
     """In-place ``**=``, assigning ``ops.pow(self, other)`` to ``self``.

     Only handled in Taichi scope. In Python scope ``NotImplemented`` is
     returned, so Python falls back to the binary operator and rebinds
     the name instead.
     """
     if not in_python_scope():
         self._assign(ops.pow(self, other))
         return self
     return NotImplemented
예제 #9
0
 def __irshift__(self, other):
     """In-place ``>>=``, assigning ``ops.bit_shr(self, other)`` to ``self``.

     Only handled in Taichi scope. In Python scope ``NotImplemented`` is
     returned so Python falls back to the binary operator.
     """
     if not in_python_scope():
         self._assign(ops.bit_shr(self, other))
         return self
     return NotImplemented
예제 #10
0
 def __ifloordiv__(self, other):
     """In-place ``//=``, assigning ``ops.floordiv(self, other)`` to ``self``.

     Only handled in Taichi scope. In Python scope ``NotImplemented`` is
     returned so Python falls back to the binary operator.
     """
     if not in_python_scope():
         self._assign(ops.floordiv(self, other))
         return self
     return NotImplemented
예제 #11
0
 def __ior__(self, other):
     """In-place ``|=``, performed through ``self._atomic_or(other)``.

     Only handled in Taichi scope. In Python scope ``NotImplemented`` is
     returned so Python falls back to the binary operator.
     """
     if not in_python_scope():
         self._atomic_or(other)
         return self
     return NotImplemented
예제 #12
0
파일: matrix.py 프로젝트: quadpixels/taichi
    def __init__(self,
                 n=1,
                 m=1,
                 dt=None,
                 shape=None,
                 offset=None,
                 empty=False,
                 layout=Layout.AOS,
                 needs_grad=False,
                 keep_raw=False,
                 disable_local_tensor=False,
                 rows=None,
                 cols=None):
        """Construct a matrix or vector.

        Accepted forms, in the order they are checked:

        * ``rows=`` / ``cols=`` (deprecated): built via ``Matrix.rows`` /
          ``Matrix.cols``.
        * ``empty=True`` (deprecated): an ``n`` x ``m`` matrix whose entries
          are all ``None``.
        * ``n`` is a list / tuple / ``np.ndarray``: a column vector for flat
          input, or a matrix for nested input (one inner sequence per row).
          Inside a kernel, when ``disable_local_tensor`` is False and the
          config's ``dynamic_index`` is on, the entries are backed by a local
          tensor created with ``impl.expr_init_local_tensor``.
        * integers ``n``, ``m`` with ``dt is None``: a local ``n`` x ``m``
          matrix whose entries come from ``impl.expr_init(None)``.
        * integers ``n``, ``m`` with ``dt`` given (deprecated): forwards to
          ``Matrix.field``.

        For the ``empty``, list-based and local forms, ``self.entries`` is a
        flat row-major list of length ``self.n * self.m``.
        """
        self.local_tensor_proxy = None
        self.any_array_access = None
        self.grad = None

        def _skip_local_tensor():
            # True when plain entries should be used instead of a local tensor.
            # Evaluation order matches the original inline condition.
            return (in_python_scope() or disable_local_tensor
                    or not ti.current_cfg().dynamic_index)

        def _check_dynamic_index_support():
            # Fail early on backends without the dynamic_index extension.
            if not ti.is_extension_supported(ti.cfg.arch,
                                             ti.extension.dynamic_index):
                raise Exception('Backend ' + str(ti.cfg.arch) +
                                ' doesn\'t support dynamic index')

        def _infer_dt(sample):
            # Default int/float runtime type chosen from a sample element.
            if isinstance(sample, int):
                return impl.get_runtime().default_ip
            if isinstance(sample, float):
                return impl.get_runtime().default_fp
            raise Exception(
                'dt required when using dynamic_index for local tensor')

        # construct from rows or cols (deprecated)
        if rows is not None or cols is not None:
            warning(
                "ti.Matrix(rows=[...]) or ti.Matrix(cols=[...]) is deprecated, use ti.Matrix.rows([...]) or ti.Matrix.cols([...]) instead.",
                DeprecationWarning,
                stacklevel=2)
            if rows is not None and cols is not None:
                raise Exception("cannot specify both rows and columns")
            self.dt = dt
            mat = Matrix.cols(cols) if cols is not None else Matrix.rows(rows)
            self.n = mat.n
            self.m = mat.m
            self.entries = mat.entries
            return

        elif empty:
            warning(
                "ti.Matrix(n, m, empty=True) is deprecated, use ti.Matrix.empty(n, m) instead",
                DeprecationWarning,
                stacklevel=2)
            self.dt = dt
            # Record the shape and keep entries flat (row-major), like the
            # list-based branch below; a nested list without n/m is unusable.
            self.n = n
            self.m = m
            self.entries = [None] * (n * m)
            return

        elif isinstance(n, (list, tuple, np.ndarray)):
            if len(n) == 0:
                mat = []
            elif isinstance(n[0], Matrix):
                raise Exception(
                    'cols/rows required when using list of vectors')
            elif not isinstance(n[0], Iterable):
                # Flat input -> (len(n) x 1) column vector.
                if not impl.inside_kernel() or keep_raw:
                    mat = [[x] for x in n]
                elif _skip_local_tensor():
                    # wrap potential constants with Expr
                    mat = [[expr.Expr(x)] for x in n]
                else:
                    _check_dynamic_index_support()
                    if dt is None:
                        dt = _infer_dt(n[0])
                    self.local_tensor_proxy = impl.expr_init_local_tensor(
                        [len(n)], dt,
                        expr.make_expr_group([expr.Expr(x) for x in n]))
                    mat = [[
                        ti.local_subscript_with_offset(
                            self.local_tensor_proxy, (i, ), (len(n), ))
                    ] for i in range(len(n))]
            else:
                # Nested input -> one row per inner sequence.
                if _skip_local_tensor():
                    mat = [list(r) for r in n]
                else:
                    _check_dynamic_index_support()
                    if dt is None:
                        dt = _infer_dt(n[0][0])
                    self.local_tensor_proxy = impl.expr_init_local_tensor(
                        [len(n), len(n[0])], dt,
                        expr.make_expr_group(
                            [expr.Expr(x) for row in n for x in row]))
                    # Row-major order: i outer, j inner.
                    mat = [[
                        ti.local_subscript_with_offset(
                            self.local_tensor_proxy, (i, j),
                            (len(n), len(n[0])))
                        for j in range(len(n[0]))
                    ] for i in range(len(n))]
            self.n = len(mat)
            self.m = len(mat[0]) if mat else 1
            self.entries = [x for row in mat for x in row]

        else:
            if dt is None:
                # create a local matrix with specific (n, m)
                self.entries = [impl.expr_init(None) for _ in range(n * m)]
                self.n = n
                self.m = m
            else:
                # construct global matrix (deprecated)
                warning(
                    "Declaring global matrices using `ti.Matrix(n, m, dt, shape)` is deprecated, "
                    "use `ti.Matrix.field(n, m, dtype, shape)` instead",
                    DeprecationWarning,
                    stacklevel=2)
                mat = Matrix.field(n=n,
                                   m=m,
                                   dtype=dt,
                                   shape=shape,
                                   offset=offset,
                                   needs_grad=needs_grad,
                                   layout=layout)
                self.n = mat.n
                self.m = mat.m
                self.entries = mat.entries
                self.grad = mat.grad

        if self.n * self.m > 32:
            warning(
                f'Taichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested.'
                ' Matrices/vectors will be automatically unrolled at compile-time for performance.'
                ' So the compilation time could be extremely long if the matrix size is too big.'
                ' You may use a field to store a large matrix like this, e.g.:\n'
                f'    x = ti.field(ti.f32, ({self.n}, {self.m})).\n'
                ' See https://taichi.readthedocs.io/en/stable/tensor_matrix.html#matrix-size'
                ' for more details.',
                UserWarning,
                stacklevel=2)