def type_inference(self):
    """Validate ``indices``/``updates`` against ``data`` and return the output type.

    With ``k = indices.shape[-1]``, the leading ``k`` dimensions of ``data``
    are treated as indexed by ``indices``.  ``updates`` is therefore expected
    to have shape ``indices.shape[:-1] + data.shape[k:]``.

    Returns:
        ``self.data.sym_type``: the output has the same type as ``data``.

    Raises:
        AssertionError: if ``k`` exceeds ``data.rank``, or if ``updates.shape``
            is not compatible with the expected shape.
    """
    # Explicit raises (not ``assert``) so validation still runs under ``python -O``.
    index_depth = self.indices.shape[-1]
    if index_depth > self.data.rank:
        raise AssertionError(
            "indices.shape[-1] ({}) must not exceed data rank ({})".format(
                index_depth, self.data.rank
            )
        )
    expected_updates_shape = self.indices.shape[:-1] + self.data.shape[index_depth:]
    if not is_compatible_symbolic_vector(
        self.updates.shape, tuple(expected_updates_shape)
    ):
        raise AssertionError(
            "Updates shape {} is not compatible with expected shape {}".format(
                self.updates.shape, tuple(expected_updates_shape)
            )
        )
    return self.data.sym_type
def type_inference(self):
    """Return the type of ``data`` once ``updates`` has been shape-checked.

    The last entry of ``indices.shape`` gives how many leading ``data``
    dimensions are indexed; ``updates`` must be compatible with the indices'
    batch dimensions followed by the remaining (un-indexed) ``data`` dims.

    Raises:
        AssertionError: when the index depth exceeds ``data``'s rank or when
            ``updates``' shape is incompatible with the expected shape.
    """
    depth = self.indices.shape[-1]
    if depth > self.data.rank:
        raise AssertionError

    batch_dims = self.indices.shape[:-1]
    trailing_dims = self.data.shape[depth:]
    wanted = tuple(batch_dims + trailing_dims)

    if is_compatible_symbolic_vector(self.updates.shape, wanted):
        return self.data.sym_type
    raise AssertionError
def test_cast_with_symbolic_value(self):
    """Cast the ``mb.shape`` of a symbolically-shaped input to int32.

    Asserts that the resulting ``sym_val`` is compatible (per
    ``is_compatible_symbolic_vector``) with ``[<new symbol>, 1]``.
    """
    # Leading dim is a fresh symbol, trailing dim is a static 1.
    sym_shape = [get_new_symbol(), 1]
    placeholders = {"x": mb.placeholder(shape=sym_shape)}

    def build(x):
        # Read x's shape as a tensor, then cast that shape tensor to int32.
        return mb.cast(x=mb.shape(x=x), dtype="int32")

    with Function(placeholders) as ssa_func:
        casted = build(**ssa_func.inputs)
        reference = [get_new_symbol(), 1]
        assert is_compatible_symbolic_vector(casted.sym_val, reference)
def type_inference(self):
    """Validate inputs for an axis-wise scatter and return the output type.

    Returns:
        ``self.data.sym_type``: the output has the same type as ``data``.

    Raises:
        IndexError: if ``axis`` is outside ``[-data.rank, data.rank)``.
        AssertionError: if ``indices`` and ``updates`` shapes are incompatible,
            if ``data`` and ``indices`` ranks differ, or if ``data`` and
            ``indices`` differ on any dimension other than ``axis``.
    """
    if self.axis.val < -self.data.rank or self.axis.val >= self.data.rank:
        raise IndexError(
            "Axis value {} is out of bounds for {} node {}".format(
                self.axis.val, self.op_type, self.name))

    axis = self.axis.val
    axis = axis if axis >= 0 else axis + self.data.rank

    # Explicit raises (not ``assert``) so validation still runs under ``python -O``.
    if not is_compatible_symbolic_vector(self.indices.shape, self.updates.shape):
        raise AssertionError(
            "indices shape {} is not compatible with updates shape {}".format(
                self.indices.shape, self.updates.shape))
    if self.data.rank != self.indices.rank:
        raise AssertionError(
            "data rank {} does not match indices rank {}".format(
                self.data.rank, self.indices.rank))
    for i in range(self.data.rank):
        # NOTE(review): this is a strict equality check, unlike the
        # compatibility check above; if dims can be symbolic, a symbol vs. a
        # different symbol/int will be rejected. Confirm this is intended.
        if i != axis and self.data.shape[i] != self.indices.shape[i]:
            raise AssertionError(
                "data shape {} and indices shape {} differ at dimension {}".format(
                    self.data.shape, self.indices.shape, i))
    return self.data.sym_type
def type_inference(self):
    """Validate ``updates`` for a scatter along ``axis`` and return the output type.

    ``updates`` is expected to have shape
    ``data.shape[:axis] + indices.shape + data.shape[axis + 1:]``.

    Returns:
        ``self.data.sym_type``: the output has the same type as ``data``.

    Raises:
        IndexError: if ``axis`` is outside ``[-data.rank, data.rank)``.
        AssertionError: if ``updates.shape`` is not compatible with the
            expected shape.
    """
    if self.axis.val < -self.data.rank or self.axis.val >= self.data.rank:
        raise IndexError(
            "Axis value {} is out of bounds for {} node {}".format(
                self.axis.val, self.op_type, self.name))

    axis = self.axis.val
    axis = axis if axis >= 0 else axis + self.data.rank
    # The single ``axis`` dimension of data is replaced by all of indices' dims.
    expected_updates_shape = (self.data.shape[:axis] + self.indices.shape +
                              self.data.shape[axis + 1:])

    # Explicit raise (not ``assert``) so validation still runs under ``python -O``.
    if not is_compatible_symbolic_vector(
            self.updates.shape, tuple(expected_updates_shape)):
        err = "Updates shape {} is incorrect. It should be {}.".format(
            self.updates.shape, expected_updates_shape)
        raise AssertionError(err)
    return self.data.sym_type
def type_inference(self):
    """Check that ``updates`` has the shape of the ``data`` slice it replaces.

    The slice is described by ``begin``/``end`` plus the optional ``stride``
    and ``begin_mask``/``end_mask``/``squeeze_mask`` inputs; when an optional
    input is absent, stride defaults to all ones and each mask to all False.

    Returns:
        ``self.data.sym_type``: the output has the same type as ``data``.

    Raises:
        ValueError: if ``updates``' shape is incompatible with the slice shape.
    """
    slice_begin = self.begin.val
    slice_end = self.end.val
    rank = self.data.rank

    def _value_or(optional_input, fallback):
        # Use the input's value when provided, otherwise the fallback list.
        if optional_input is None:
            return fallback
        return optional_input.val

    strides = _value_or(self.stride, [1] * rank)
    begin_flags = _value_or(self.begin_mask, [False] * rank)
    end_flags = _value_or(self.end_mask, [False] * rank)
    squeeze_flags = _value_or(self.squeeze_mask, [False] * rank)

    slice_shape = tuple(
        _solve_slice_by_index_shape(
            self.data.shape,
            slice_begin,
            slice_end,
            strides,
            begin_flags,
            end_flags,
            squeeze_flags,
        )
    )
    if is_compatible_symbolic_vector(slice_shape, self.updates.shape):
        return self.data.sym_type
    raise ValueError(
        "The updates tensor should have shape {}. Got {}".format(
            slice_shape, self.updates.shape
        )
    )
def type_inference(self):
    """Infer the output type of stacking ``self.values`` along ``self.axis``.

    All inputs must have mutually compatible shapes.  The output inserts a new
    dimension of size ``len(self.values)`` at ``axis``, so its rank is
    ``input_rank + 1`` and valid axes lie in ``[-(input_rank + 1), input_rank]``.
    A negative axis counts from the end of the *output* shape (``axis=-1``
    appends the new dimension last).

    Returns:
        A tensor type with the first input's dtype and the stacked shape.

    Raises:
        ValueError: if there are no inputs, or if an input's shape is not
            compatible with the reference shape.
        IndexError: if ``axis`` is outside ``[-(input_rank + 1), input_rank]``.
    """
    num_tensors = len(self.values)
    if num_tensors == 0:
        raise ValueError("Cannot stack 0 tensor")

    # Reference shape: the first input whose shape is fully non-symbolic,
    # falling back to the first input's shape if every shape is symbolic.
    t_shape = next(
        (value.shape for value in self.values if not any_symbolic(value.shape)),
        self.values[0].shape,
    )

    # Every input must be compatible with the reference shape.
    for t in self.values:
        if not is_compatible_symbolic_vector(t.shape, t_shape):
            msg = "Component tensor {} has shape {}, others have {}"
            raise ValueError(msg.format(t.name, t.shape, t_shape))

    # The output has one more dim than the inputs, so a negative axis must be
    # offset by (rank + 1).  Passing it unmodified to list.insert would put the
    # new dim one slot too early (axis=-1 would land before the last dim), and
    # list.insert would silently clamp out-of-range axes.
    rank = len(t_shape)
    axis = self.axis.val
    if axis < 0:
        axis += rank + 1
    if not 0 <= axis <= rank:
        raise IndexError(
            "Axis value {} is out of bounds for {} node {}".format(
                self.axis.val, self.op_type, self.name))

    ret_shape = list(t_shape)
    ret_shape.insert(axis, num_tensors)
    return types.tensor(self.values[0].dtype, ret_shape)