Exemplo n.º 1
0
 def type_inference(self):
     """Validate the ``updates`` shape against ``data`` and ``indices``.

     ``indices.shape[-1]`` is the number of leading ``data`` dimensions each
     index tuple addresses, so ``updates`` must be shape-compatible with
     ``indices.shape[:-1] + data.shape[indices.shape[-1]:]``.

     Returns:
         The symbolic type of ``data``.

     Raises:
         AssertionError: if ``indices.shape[-1]`` exceeds the rank of ``data``,
             or if ``updates`` has an incompatible shape.
     """
     # Explicit raises instead of ``assert`` so the checks still run under
     # ``python -O``; AssertionError is kept so callers catching it still work.
     index_depth = self.indices.shape[-1]
     if index_depth > self.data.rank:
         raise AssertionError(
             "indices.shape[-1] ({}) must not exceed the rank of data ({})".format(
                 index_depth, self.data.rank))
     expected_updates_shape = tuple(self.indices.shape[:-1] +
                                    self.data.shape[index_depth:])
     if not is_compatible_symbolic_vector(self.updates.shape,
                                          expected_updates_shape):
         raise AssertionError(
             "Updates shape {} is incorrect. It should be {}.".format(
                 self.updates.shape, expected_updates_shape))
     return self.data.sym_type
Exemplo n.º 2
0
 def type_inference(self):
     """Validate the ``updates`` shape against ``data`` and ``indices``.

     ``updates`` must be shape-compatible with
     ``indices.shape[:-1] + data.shape[indices.shape[-1]:]``, and
     ``indices.shape[-1]`` may not exceed the rank of ``data``.

     Returns:
         The symbolic type of ``data``.

     Raises:
         AssertionError: if either shape constraint is violated. The message
             now names the offending shapes instead of being empty.
     """
     index_depth = self.indices.shape[-1]
     if index_depth > self.data.rank:
         raise AssertionError(
             "indices.shape[-1] ({}) must not exceed the rank of data ({})".format(
                 index_depth, self.data.rank
             )
         )
     expected_updates_shape = tuple(
         self.indices.shape[:-1] + self.data.shape[index_depth:]
     )
     if not is_compatible_symbolic_vector(
         self.updates.shape, expected_updates_shape
     ):
         raise AssertionError(
             "Updates shape {} is incorrect. It should be {}.".format(
                 self.updates.shape, expected_updates_shape
             )
         )
     return self.data.sym_type
Exemplo n.º 3
0
    def test_cast_with_symbolic_value(self):
        """Check ``mb.cast`` on the output of ``mb.shape`` for a partly symbolic input.

        The cast result's ``sym_val`` must be compatible with ``[<symbol>, 1]``.
        """
        placeholder_shape = [get_new_symbol(), 1]
        input_placeholders = dict(x=mb.placeholder(shape=placeholder_shape))

        def build(x):
            # Cast the (symbolic) shape tensor of x straight to int32.
            return mb.cast(x=mb.shape(x=x), dtype="int32")

        with Function(input_placeholders) as ssa_func:
            casted = build(**ssa_func.inputs)
            actual = casted.sym_val
            expected = [get_new_symbol(), 1]
            assert is_compatible_symbolic_vector(actual, expected)
Exemplo n.º 4
0
    def type_inference(self):
        """Validate the operand shapes along ``axis`` and return the output type.

        Requirements checked:
          * ``indices`` and ``updates`` have compatible shapes;
          * ``indices`` has the same rank as ``data``;
          * ``indices`` matches ``data`` on every dimension except ``axis``.

        Returns:
            The symbolic type of ``data``.

        Raises:
            IndexError: if ``axis`` is outside ``[-rank, rank)`` of ``data``.
            AssertionError: if any requirement above is violated.
        """
        if self.axis.val < -self.data.rank or self.axis.val >= self.data.rank:
            raise IndexError(
                "Axis value {} is out of bounds for {} node {}".format(
                    self.axis.val, self.op_type, self.name))

        axis = self.axis.val
        # Normalize a negative axis to its non-negative equivalent.
        axis = axis if axis >= 0 else axis + self.data.rank

        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for callers that catch it.
        if not is_compatible_symbolic_vector(self.indices.shape,
                                             self.updates.shape):
            raise AssertionError(
                "Indices shape {} and updates shape {} are incompatible.".format(
                    self.indices.shape, self.updates.shape))
        if self.data.rank != self.indices.rank:
            raise AssertionError(
                "Data rank {} and indices rank {} must be equal.".format(
                    self.data.rank, self.indices.rank))
        for i in range(self.data.rank):
            if i == axis:
                continue
            if self.data.shape[i] != self.indices.shape[i]:
                raise AssertionError(
                    "Data shape {} and indices shape {} must match on every "
                    "dimension except axis {}.".format(
                        self.data.shape, self.indices.shape, axis))

        return self.data.sym_type
Exemplo n.º 5
0
    def type_inference(self):
        """Validate the ``updates`` shape for a scatter along ``axis``.

        ``updates`` must be shape-compatible with ``data`` whose ``axis``
        dimension is replaced by the full shape of ``indices``:
        ``data.shape[:axis] + indices.shape + data.shape[axis + 1:]``.

        Returns:
            The symbolic type of ``data``.

        Raises:
            IndexError: if ``axis`` is outside ``[-rank, rank)`` of ``data``.
            AssertionError: if ``updates`` has an incompatible shape.
        """
        if self.axis.val < -self.data.rank or self.axis.val >= self.data.rank:
            raise IndexError(
                "Axis value {} is out of bounds for {} node {}".format(
                    self.axis.val, self.op_type, self.name))

        axis = self.axis.val
        # Normalize a negative axis to its non-negative equivalent.
        axis = axis if axis >= 0 else axis + self.data.rank
        expected_updates_shape = tuple(self.data.shape[:axis] +
                                       self.indices.shape +
                                       self.data.shape[axis + 1:])

        # Explicit raise instead of ``assert`` so the check still runs under
        # ``python -O``; the message is only built when it is needed.
        if not is_compatible_symbolic_vector(self.updates.shape,
                                             expected_updates_shape):
            raise AssertionError(
                "Updates shape {} is incorrect. It should be {}.".format(
                    self.updates.shape, expected_updates_shape))

        return self.data.sym_type
Exemplo n.º 6
0
 def type_inference(self):
     """Check that ``updates`` matches the shape of the sliced region of ``data``.

     Absent optional inputs default to a stride of 1 and all-False masks
     (one entry per ``data`` dimension). The expected region shape comes from
     ``_solve_slice_by_index_shape``.

     Returns:
         The symbolic type of ``data``.

     Raises:
         ValueError: if ``updates`` is not shape-compatible with the slice.
     """
     begin = self.begin.val
     end = self.end.val
     rank = self.data.rank

     def _mask_or_default(mask_var):
         # Optional mask inputs fall back to "no dimension masked".
         return [False] * rank if mask_var is None else mask_var.val

     stride = [1] * rank if self.stride is None else self.stride.val
     slice_shape = _solve_slice_by_index_shape(
         self.data.shape,
         begin,
         end,
         stride,
         _mask_or_default(self.begin_mask),
         _mask_or_default(self.end_mask),
         _mask_or_default(self.squeeze_mask),
     )
     expected_updates_shape = tuple(slice_shape)
     if is_compatible_symbolic_vector(expected_updates_shape, self.updates.shape):
         return self.data.sym_type
     raise ValueError(
         "The updates tensor should have shape {}. Got {}".format(
             expected_updates_shape, self.updates.shape
         )
     )
Exemplo n.º 7
0
    def type_inference(self):
        """Infer the output type of stacking ``values`` along a new ``axis``.

        All tensors in ``values`` must have mutually compatible shapes. The
        output shape is a reference input shape with ``len(values)`` inserted
        at ``axis``. The output dtype is taken from the first tensor.

        Returns:
            A ``types.tensor`` describing the stacked result.

        Raises:
            ValueError: if ``values`` is empty, if an input shape is
                incompatible with the others, or if ``axis`` is outside
                ``[-(rank + 1), rank]`` where ``rank`` is the input rank.
        """
        num_tensors = len(self.values)
        if num_tensors == 0:
            raise ValueError("Cannot stack 0 tensor")

        # Use the first input whose shape has no symbolic dims as the
        # reference (presumably to keep the output shape as concrete as
        # possible); fall back to the first input's shape otherwise.
        t_shape = next(
            (v.shape for v in self.values if not any_symbolic(v.shape)),
            self.values[0].shape)

        # compare all shape
        for t in self.values:
            if not is_compatible_symbolic_vector(t.shape, t_shape):
                msg = "Component tensor {} has shape {}, others have {}"
                raise ValueError(msg.format(t.name, t.shape, t_shape))

        # ``axis`` indexes the output, whose rank is one more than the inputs'.
        # ``list.insert`` resolves negative indices against the input shape,
        # so e.g. axis=-1 would land before the last input dim instead of
        # after it, and out-of-range values would be silently clamped.
        # Normalize against the output rank and bounds-check instead.
        out_rank = len(t_shape) + 1
        axis = self.axis.val
        if axis < -out_rank or axis >= out_rank:
            raise ValueError(
                "Axis value {} is out of bounds for stacking tensors of "
                "rank {}".format(axis, out_rank - 1))
        if axis < 0:
            axis += out_rank

        ret_shape = list(t_shape)
        ret_shape.insert(axis, num_tensors)
        return types.tensor(self.values[0].dtype, ret_shape)