def type_inference(self):
    """Infer the output type of ``matmul``.

    The last two dims of each operand are contracted matrix-style and all
    leading (batch) dims are broadcast with ``broadcast_shapes``. A rank-1
    ``x`` is promoted to ``(1, K)`` and a rank-1 ``y`` to ``(K, 1)``. The
    promoted dimension is removed from the result again, at the position it
    was inserted (so batch dims of the other operand are kept).

    Returns:
        ``types.tensor(x.dtype, shape)``, or the bare ``x.dtype`` when both
        operands are rank 1 (the result is a scalar).

    Raises:
        ValueError: if a rank-1 operand is asked to be transposed, or if the
            contracted dimensions are both concrete and differ.
    """
    x_type = self.x.dtype
    x_shape = list(self.x.shape)
    y_shape = list(self.y.shape)
    x_rank = len(x_shape)
    y_rank = len(y_shape)

    if x_rank == 1 and self.transpose_x.val:
        msg = "Op {} (matmul): x is rank 1, but transpose_x is True, which is not allowed."
        raise ValueError(msg.format(self.name))
    if y_rank == 1 and self.transpose_y.val:
        msg = "Op {} (matmul): y is rank 1, but transpose_y is True, which is not allowed."
        raise ValueError(msg.format(self.name))

    if self.transpose_x.val:
        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
    if self.transpose_y.val:
        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]

    # Promote rank-1 operands to rank 2 so the contraction below is uniform.
    if x_rank == 1:
        x_shape = [1] + x_shape
    if y_rank == 1:
        y_shape = y_shape + [1]

    if not (x_shape[-1] == y_shape[-2]
            or is_symbolic(x_shape[-1])
            or is_symbolic(y_shape[-2])):
        msg = "Op {} (matmul): x {}, y {} are not broadcastable"
        raise ValueError(msg.format(self.name, self.x.shape, self.y.shape))

    ret_shape = list(broadcast_shapes(x_shape[:-2], y_shape[:-2]))
    ret_shape += [x_shape[-2], y_shape[-1]]

    # Remove the dims introduced by promotion. Handle x first (index -2) so
    # that the y removal (index -1) is still valid when both were rank 1.
    if x_rank == 1:
        ret_shape.pop(-2)
    if y_rank == 1:
        ret_shape.pop(-1)

    if len(ret_shape) == 0:
        # vector . vector -> scalar
        return x_type
    return types.tensor(x_type, tuple(ret_shape))
def type_inference(self):
    """Infer the output type of a slice described by ``begin`` and ``size``.

    ``begin`` and ``size`` must be 1-D with one entry per dimension of ``x``.
    A ``size`` entry of -1 yields ``x.shape[i] - begin[i]`` when ``begin`` is
    known at compile time, otherwise a fresh symbol. If ``size`` itself is
    unknown, every output dimension is a fresh symbol.

    Raises:
        ValueError: if ``begin``/``size`` are not 1-D or their length differs
            from the rank of ``x``.
    """
    if self.begin.rank != 1:
        raise ValueError(
            "begin should be 1-D tensor, got {}-D tensor instead".format(
                self.begin.rank))
    if self.size.rank != 1:
        raise ValueError(
            "size should be 1-D tensor, got {}-D tensor instead".format(
                self.size.rank))
    # These are plain ints, so format them directly (calling len() on them
    # would raise TypeError and hide the real error).
    if self.x.rank != self.begin.shape[0]:
        raise ValueError(
            "Length of begin {} doesn't equal to input rank {}.".format(
                self.begin.shape[0], self.x.rank))
    if self.x.rank != self.size.shape[0]:
        raise ValueError(
            "Length of size {} doesn't equal to input rank {}.".format(
                self.size.shape[0], self.x.rank))

    x_shape = self.x.shape
    if self.size.sym_val is None:
        ret_shape = tuple(get_new_symbol() for _ in range(self.x.rank))
        return types.tensor(self.x.dtype, ret_shape)

    ret_shape = []
    for idx, s in enumerate(self.size.sym_val):
        if is_symbolic(s) or s != -1:
            ret_shape.append(s)
        elif self.begin.sym_val is not None:
            # size == -1: the slice runs from begin to the end of the axis.
            ret_shape.append(x_shape[idx] - self.begin.sym_val[idx])
        else:
            ret_shape.append(get_new_symbol())
    return types.tensor(self.x.dtype, tuple(ret_shape))
def parse_from_attr(self):
    """Populate ``self.datatype`` from the node attribute dict ``self.attr``.

    Precedence:
      1. ``value``: the datatype is the class of that value.
      2. ``_output_shapes``: if the first shape is non-empty, a tensor type
         built from the first of ``dtype``/``T``/``Tparams`` that is present.
         If the shape is empty, fall back to ``dtype`` when present.
      3. ``shape`` (``dtype`` is required): a scalar ``dtype`` for an empty
         shape, else ``types.tensor(dtype, shape)``.
      4. ``dtype`` alone.
    If none apply, ``self.datatype`` is left untouched.

    Raises:
        NotImplementedError: ``_output_shapes`` has a non-empty shape but no
            dtype-like attribute is present.
    """
    attr = self.attr
    if "value" in attr:
        self.datatype = attr["value"].__class__
    elif "_output_shapes" in attr:
        output_shapes = attr["_output_shapes"]
        if output_shapes[0] is not None and len(output_shapes[0]) > 0:
            shape = tuple(output_shapes[0])
            for key in ("dtype", "T", "Tparams"):
                if key in attr:
                    self.datatype = types.tensor(attr[key], shape)
                    break
            else:
                # Format the whole message in one step; previously `%` bound
                # only to str(attr) and the message construction itself failed.
                raise NotImplementedError(
                    "Op-(%s) %s not implemented\nWith attribute:%s"
                    % (self.op, self.name, attr)
                )
        elif "dtype" in attr:
            self.datatype = attr["dtype"]
    elif "shape" in attr:
        shape = attr["shape"]
        assert "dtype" in attr, "'shape' attribute requires a 'dtype' attribute"
        if len(shape) == 0:
            self.datatype = attr["dtype"]
        else:
            self.datatype = types.tensor(attr["dtype"], shape)
    elif "dtype" in attr:
        self.datatype = attr["dtype"]
def type_inference(self):
    """Infer the output type of a tile-style op driven by ``reps``.

    Output dim ``i`` is ``x.shape[i] * reps[i]``. A ``reps`` shorter than the
    rank of ``x`` is left-padded with 1s. Unknown ``reps`` (or any symbolic
    factor/dimension) produces fresh symbols.

    Raises:
        ValueError: if ``len(reps)`` is 0 or exceeds the rank of ``x``, or a
            concrete entry of ``reps`` is not positive.
    """
    dtype = self.x.dtype
    in_dims = np.array(self.x.shape)
    rank = self.x.rank
    multiples = self.reps.sym_val

    if multiples is None:
        return types.tensor(dtype, tuple(get_new_symbol() for _ in range(rank)))

    n_reps = len(multiples)
    if not 0 < n_reps <= rank:
        msg = ("Length of the reps ({}) must be at least 1, and "
               "not greater than the rank of the input x ({})")
        raise ValueError(msg.format(n_reps, rank))

    if n_reps < rank:
        # Leading dims that reps does not mention are left untiled.
        multiples = [1] * (rank - n_reps) + list(multiples)

    dims = []
    for dim, factor in zip(in_dims, multiples):
        factor_is_symbolic = is_symbolic(factor)
        if not factor_is_symbolic and factor <= 0:
            raise ValueError(
                "All entries of reps parameter must be greater than 0")
        if factor_is_symbolic or is_symbolic(dim):
            dims.append(get_new_symbol())
        else:
            dims.append(factor * dim)
    return types.tensor(dtype, tuple(dims))
def type_inference(self):
    """Infer the output type of a broadcasting binary op on ``x`` and ``y``.

    The primitive type is ``promoted_primitive_type`` of both operands, then
    passed through ``self.get_dtype``. When both operands are tensors, their
    shapes are combined with ``broadcast_shapes``. When only one is a tensor,
    its shape is used. When neither is, the scalar dtype is returned.

    Raises:
        ValueError: if the operand primitive types cannot be promoted.
    """
    lhs_type = self.x.sym_type
    rhs_type = self.y.sym_type

    promoted = promoted_primitive_type(lhs_type, rhs_type)
    if promoted is None:
        raise ValueError(
            "Incompatible primitive types in broadcast operation")
    out_dtype = self.get_dtype(promoted)

    lhs_is_tensor = types.is_tensor(lhs_type)
    rhs_is_tensor = types.is_tensor(rhs_type)

    if lhs_is_tensor and rhs_is_tensor:
        out_shape = broadcast_shapes(list(lhs_type.get_shape()),
                                     list(rhs_type.get_shape()))
        return types.tensor(out_dtype, out_shape)
    if lhs_is_tensor:
        return types.tensor(out_dtype, lhs_type.get_shape())
    if rhs_is_tensor:
        return types.tensor(out_dtype, rhs_type.get_shape())
    # Scalar op scalar.
    return out_dtype
def type_inference(self):
    """Infer the output types of a single-direction recurrent op.

    ``x`` is unpacked as ``(sequence_length, batch_size, input_size)``. The
    hidden size is ``weight_ih.shape[0]``.

    Returns:
        A pair of tensors in ``x.dtype``:
        ``[seq_len if output_sequence else 1, batch, hidden]`` and
        ``[batch, hidden]``.

    Raises:
        ValueError: if ``x`` is not rank 3, either weight is not rank 2, or
            ``direction`` is not ``forward``/``reverse``.
    """
    if self.x.rank != 3:
        # rank is an int; len() on it would raise TypeError instead.
        raise ValueError(
            "Invalid input shape. Expecting Rank 3 input, got {}".format(
                self.x.rank))
    sequence_length, batch_size, _ = self.x.shape

    if self.weight_ih.rank != 2 or self.weight_hh.rank != 2:
        # Format the message, not the exception object.
        raise ValueError(
            "Invalid weight shape. Expecting Rank 2 input, got weight_ih {}, weight_hh {}".format(
                self.weight_ih.rank, self.weight_hh.rank))

    hidden_size, _ = self.weight_ih.shape

    direction = self.direction.val
    valid_directions = {"forward", "reverse"}
    if direction not in valid_directions:
        raise ValueError(
            "Direction {} not supported. Supported directions: {}".format(
                direction, valid_directions))

    out_seq_len = sequence_length if self.output_sequence.val else 1
    output_shape = [out_seq_len, batch_size, hidden_size]
    output_h_shape = [batch_size, hidden_size]
    return (
        types.tensor(self.x.dtype, tuple(output_shape)),
        types.tensor(self.x.dtype, tuple(output_h_shape)),
    )
def type_inference(self):
    """Infer the output type of a transposed convolution.

    Layout, per this op's own comments: ``x`` is ``[N, C_in, *spatial]`` and
    ``weight`` is ``[C_in, C_out / groups, *kernel]``. If ``output_shape`` is
    given, it is returned after checking N and C_out. Otherwise spatial sizes
    follow TensorFlow's transposed-conv shape logic for the ``same``,
    ``valid`` and ``custom`` pad types, using dilated kernel sizes.

    Raises:
        ValueError: bias length != C_out; C_in not divisible by ``groups``;
            ``pad`` missing for ``custom``; or an unrecognized ``pad_type``.
    """
    in_shape = self.x.shape
    f_shape = self.weight.shape
    kernel_shape = f_shape[2:]
    spatial_dim_rank = len(in_shape) - 2
    N = in_shape[0]
    C_in = in_shape[1]  # channels are dim 1; dim 0 is the batch
    groups = self.groups.val
    C_out = f_shape[1] * groups

    # Use the Var's static shape so a non-const bias does not crash on .val.
    if self.bias is not None and self.bias.shape[0] != C_out:
        msg = "# of bias values {} not equal to # output channels {}"
        raise ValueError(msg.format(self.bias.shape[0], C_out))
    # C_out is a multiple of groups by construction. The constraint that can
    # actually be violated is on the input channels.
    if not is_symbolic(C_in) and C_in % groups != 0:
        msg = "# of input channels {} not divisible by groups {}"
        raise ValueError(msg.format(C_in, groups))

    # If output shape is given, return it
    if self.output_shape is not None:
        output_shape = self.output_shape.val
        assert output_shape[0] == N
        assert output_shape[1] == C_out
        return types.tensor(self.x.dtype, tuple(output_shape))

    strides = self.strides.val
    dilations = self.dilations.val
    # Effective (dilated) kernel extent per spatial dim.
    kernel_shape = [
        (kernel_shape[r] - 1) * dilations[r] + 1
        for r in range(spatial_dim_rank)
    ]

    D_in = in_shape[2:]  # spatial dimensions

    # Deconv's output shape is non-deterministic, we follow TF shape logic here.
    pad_type = self.pad_type.val
    if pad_type == "same":
        d_out_shape = [strides[r] * D_in[r] for r in range(spatial_dim_rank)]
    elif pad_type == "valid":
        d_out_shape = [
            strides[r] * (D_in[r] - 1) + kernel_shape[r]
            for r in range(spatial_dim_rank)
        ]
    elif pad_type == "custom":
        if self.pad is None:
            raise ValueError("self.pad must exist if pad_type is custom")
        pad = self.pad.val
        d_out_shape = [
            strides[r] * (D_in[r] - 1)
            + kernel_shape[r]
            - pad[2 * r]
            - pad[2 * r + 1]
            for r in range(spatial_dim_rank)
        ]
    else:
        # Previously fell through and failed with NameError on d_out_shape.
        raise ValueError("Unrecognized value of pad_type : {}".format(pad_type))

    retshape = [N, C_out] + d_out_shape
    return types.tensor(self.x.dtype, tuple(retshape))
def type_inference(self):
    """Infer the type of a tensor of ``value``'s dtype shaped by ``shape``.

    - If the length of ``shape`` is itself symbolic, the rank is unknown and
      a single variadic symbol is used.
    - If the length is known but the values are not, each dim is a fresh
      symbol.
    - Otherwise the shape is ``shape.sym_val``.

    All three paths use ``value.dtype``; previously the first two hard-coded
    fp32.
    """
    dtype = self.value.dtype
    if any_symbolic(self.shape.shape):
        # We can't infer any shape if shape has variable length.
        return types.tensor(dtype, (get_new_variadic_symbol(),))

    # shape has fixed length here.
    if self.shape.sym_val is None:
        ret_shape = tuple(get_new_symbol() for _ in range(self.shape.shape[0]))
        return types.tensor(dtype, ret_shape)

    return types.tensor(dtype, tuple(self.shape.sym_val.tolist()))
def type_inference(self):
    """Infer the output type of a reshape of ``x`` to ``shape``.

    - If the length of ``shape`` is symbolic, the output rank is unknown and
      a single variadic symbol is used.
    - If the length is known but the values are unknown, every dim is a
      fresh symbol.
    - Otherwise the type comes from ``self._get_type_val()``.
    """
    target = self.shape

    if any_symbolic(target.shape):
        # We can't infer any shape if shape has variable length.
        return types.tensor(self.x.dtype, (get_new_variadic_symbol(),))

    if target.sym_val is not None:
        out_type, _ = self._get_type_val()
        return out_type

    # Rank is fixed but the individual dims are unknown.
    unknown_dims = tuple(get_new_symbol() for _ in range(target.shape[0]))
    return types.tensor(self.x.dtype, unknown_dims)
def type_inference(self):
    """Infer the four output types of the box-suppression op.

    ``scores`` is unpacked as ``(n_batch, _, n_score)`` and ``max_boxes.val``
    bounds the per-batch output count.

    Returns:
        A 4-tuple:
        ``boxes.dtype [n_batch, max_boxes, 4]``,
        ``scores.dtype [n_batch, max_boxes, n_score]``,
        ``int32 [n_batch, max_boxes]``,
        ``int32 [n_batch]``.
        Presumably the last two are the selected indices and per-batch valid
        counts; confirm against the op spec.
    """
    n_batch, _, n_score = self.scores.shape
    max_boxes = self.max_boxes.val

    out_boxes = types.tensor(self.boxes.dtype, (n_batch, max_boxes, 4))
    out_scores = types.tensor(self.scores.dtype, (n_batch, max_boxes, n_score))
    out_indices = types.tensor(types.int32, (n_batch, max_boxes))
    out_counts = types.tensor(types.int32, (n_batch,))
    return out_boxes, out_scores, out_indices, out_counts
def type_inference(self):
    """Infer the (values, indices) output types of a top-k selection.

    Both outputs copy ``x``'s shape with the ``axis`` dimension replaced by
    ``k``. Values keep ``x.dtype``; indices are int32.

    Raises:
        ValueError: if ``k`` exceeds a concrete size of the chosen axis.
    """
    in_shape = self.x.shape
    k = self.k.val
    axis = self.axis.val

    axis_len = in_shape[axis]
    if not is_symbolic(axis_len) and k > axis_len:
        msg = "K={} is greater than size of the given axis={}"
        raise ValueError(msg.format(k, axis))

    out_shape = list(in_shape)
    out_shape[axis] = k
    values_type = types.tensor(self.x.dtype, out_shape)
    indices_type = types.tensor(types.int32, out_shape)
    return values_type, indices_type
def type_inference(self):
    """Infer the output types of a 4-gate recurrent op (forward/reverse/bidirectional).

    ``x`` is ``(sequence_length, batch_size, input_size)``. Each weight pair
    must be rank 2 with ``weight_hh`` of shape ``<4*H, H>`` and
    ``weight_ih`` having ``4*H`` rows. Bidirectional mode additionally
    checks ``weight_ih_back``/``weight_hh_back`` and doubles the feature dim.

    Returns:
        Three tensors in ``x.dtype``:
        ``[seq_len or 1, batch, D*H]``, ``[batch, D*H]``, ``[batch, D*H]``,
        where ``D`` is 2 for bidirectional and 1 otherwise.

    Raises:
        ValueError: on a bad input rank, bad weight ranks/shapes, or an
            unsupported direction.
    """
    if self.x.rank != 3:
        raise ValueError(
            "Invalid input shape. Expecting Rank 3 input, got {}".format(
                self.x.rank))
    sequence_length, batch_size, _ = self.x.shape

    def weight_shape_check(wt_ih, wt_hh):
        # Gates are stacked along dim 0, so both matrices need exactly 4*H rows.
        if wt_ih.rank != 2 or wt_hh.rank != 2:
            raise ValueError(
                "Expecting Rank 2 input, got weight_ih rank: {}, weight_hh rank: {}".format(
                    wt_ih.rank, wt_hh.rank))
        hidden_size = wt_hh.shape[1]
        # Exact check: floor division would accept e.g. 4*H + 1 rows.
        if wt_hh.shape[0] != 4 * hidden_size or wt_ih.shape[0] != 4 * hidden_size:
            raise ValueError(
                "Incorrect weight matrix: hidden dim size mismatch. "
                "Provided weight_ih {}, weight_hh {}. Expecting <4*H, H>".format(
                    wt_ih.shape, wt_hh.shape))

    direction = self.direction.val
    valid_directions = {"forward", "reverse", "bidirectional"}
    if direction not in valid_directions:
        raise ValueError(
            "Direction {} not supported. Supported directions: {}".format(
                direction, valid_directions))

    weight_shape_check(self.weight_ih, self.weight_hh)
    if direction == "bidirectional":
        weight_shape_check(self.weight_ih_back, self.weight_hh_back)

    hidden_size = self.weight_hh.shape[1]
    num_directions = 2 if direction == "bidirectional" else 1
    out_seq_len = sequence_length if self.output_sequence.val else 1
    feature_dim = num_directions * hidden_size

    output_shape = [out_seq_len, batch_size, feature_dim]
    output_h_shape = [batch_size, feature_dim]
    output_c_shape = [batch_size, feature_dim]
    return (
        types.tensor(self.x.dtype, tuple(output_shape)),
        types.tensor(self.x.dtype, tuple(output_h_shape)),
        types.tensor(self.x.dtype, tuple(output_c_shape)),
    )
def type_inference(self):
    """Infer the output types of a 3-gate recurrent op (forward/reverse).

    ``x`` is ``(sequence_length, batch_size, input_size)``. ``weight_hh`` is
    unpacked as ``(hidden_dim, hidden_size)`` and must satisfy
    ``hidden_dim == 3 * hidden_size``.

    Returns:
        Two tensors in ``x.dtype``: ``[seq_len or 1, batch, H]`` and
        ``[batch, H]``.

    Raises:
        ValueError: on a bad input rank, weights not of rank 2, an
            unsupported direction, or a weight_hh shape other than ``<3*H, H>``.
    """
    if self.x.rank != 3:
        raise ValueError(
            "Invalid input shape. Expecting Rank 3 input, got {}".format(
                self.x.rank))
    sequence_length, batch_size, _ = self.x.shape

    if self.weight_ih.rank != 2:
        raise ValueError(
            "Invalid weight shape. Expecting Rank 2 input, got {}".format(
                self.weight_ih.rank))
    if self.weight_hh.rank != 2:
        raise ValueError(
            "Invalid weight shape. Expecting Rank 2 input, got {}".format(
                self.weight_hh.rank))

    hidden_dim, hidden_size = self.weight_hh.shape

    direction = self.direction.val
    valid_directions = {"forward", "reverse"}
    if direction not in valid_directions:
        raise ValueError(
            "Direction {} not supported. Supported directions: {}".format(
                direction, valid_directions))

    dim_factor = 3
    # Exact check: floor division would accept e.g. 3*H + 1 rows.
    if hidden_dim != dim_factor * hidden_size:
        raise ValueError(
            "Incorrect weight matrix: hidden dim size mismatch. "
            "Provided weight_ih {}, weight_hh {}. Expecting <3*H, H>".format(
                self.weight_ih.shape, self.weight_hh.shape))

    out_seq_len = sequence_length if self.output_sequence.val else 1
    output_shape = [out_seq_len, batch_size, hidden_size]
    output_h_shape = [batch_size, hidden_size]
    return (
        types.tensor(self.x.dtype, tuple(output_shape)),
        types.tensor(self.x.dtype, tuple(output_h_shape)),
    )
def type_inference(self):
    """Infer the output type of ``layer_norm``: same dtype and shape as ``x``.

    ``axes`` may be negative and is normalized against the rank of ``x``.
    The optional ``gamma`` and ``beta`` are each checked with
    ``layer_norm._is_compatible_shape`` against the dims of ``x`` selected
    by ``axes`` (in increasing dim order).

    Raises:
        ValueError: if an axis is out of range, or if gamma/beta has an
            incompatible shape.
    """
    rank = self.x.rank
    # check valid axes
    positive_axes = [axis + rank if axis < 0 else axis for axis in self.axes.val]
    if not all(0 <= axis < rank for axis in positive_axes):
        raise ValueError("axes must in the range of [-x.rank, x.rank-1].")

    # check shape of gamma and beta
    normalized_shape = [
        self.x.shape[i] for i in range(rank) if i in positive_axes
    ]
    if self.gamma is not None and not layer_norm._is_compatible_shape(
            list(self.gamma.shape), normalized_shape):
        raise ValueError(
            "Expect shape {} for gamma, but get shape {} instead".format(
                normalized_shape, self.gamma.shape))
    # Validate beta against its own shape (previously gamma's shape was used,
    # which also crashed when gamma was None).
    if self.beta is not None and not layer_norm._is_compatible_shape(
            list(self.beta.shape), normalized_shape):
        raise ValueError(
            "Expect shape {} for beta, but get shape {} instead".format(
                normalized_shape, self.beta.shape))

    return types.tensor(self.x.dtype, tuple(self.x.shape))
def type_inference(self):
    """Infer the output type of ``resample``.

    ``x`` and ``coordinates`` must both be rank 4 and agree on dim 0 when
    both are concrete. The last dim of ``coordinates`` must be 2 when
    concrete. The output copies ``x``'s shape, with dims 2 and 3 replaced by
    ``coordinates.shape[1]`` and ``coordinates.shape[2]``.

    Raises:
        ValueError: on a rank, batch-size or last-dimension mismatch.
    """
    x_rank = self.x.rank
    if x_rank != 4:
        raise ValueError(
            'input "x" to the "resample" op must be a rank 4 tensor. '
            "Got rank {} tensor of shape {}".format(x_rank, self.x.shape))

    coord_rank = self.coordinates.rank
    if coord_rank != 4:
        raise ValueError(
            'input "coordinates" to the "resample" op must be a rank 4 tensor. '
            "Got rank {} tensor of shape {}".format(
                coord_rank, self.coordinates.shape))

    in_shape = self.x.shape
    grid_shape = self.coordinates.shape

    in_batch, grid_batch = in_shape[0], grid_shape[0]
    batch_is_known = not (is_symbolic(in_batch) or is_symbolic(grid_batch))
    if batch_is_known and in_batch != grid_batch:
        raise ValueError(
            'input "x" and "coordinates" to the "resample" must agree on '
            "dimension of batch size: {} vs. {}".format(in_batch, grid_batch))

    last_dim = grid_shape[-1]
    if not is_symbolic(last_dim) and last_dim != 2:
        raise ValueError(
            'input "coordinates" to the "resample" op last dimension must be 2. '
            "Got {} for last dimension".format(last_dim))

    out_shape = list(in_shape)
    # Output height / width come from the coordinate grid.
    out_shape[2:4] = [grid_shape[1], grid_shape[2]]
    return types.tensor(self.x.dtype, tuple(out_shape))
def type_inference(self):
    """Infer the output type of ``crop_resize``.

    The output shape is
    ``[roi.shape[0], x.shape[0], x.shape[1], target_height, target_width]``
    in ``x.dtype``.

    Raises:
        ValueError: if ``x`` is not rank 4, ``roi`` is not rank 5, or
            ``sampling_mode`` is not one of the recognized modes.
    """
    if self.x.rank != 4:
        raise ValueError(
            'input to the "crop_resize" op must be of rank 4. Provided {}'.
            format(self.x.rank))

    if self.roi.rank != 5:
        raise ValueError(
            'ROI input to the "crop_resize" op must be of rank 5, provided {}'
            .format(self.roi.rank))

    valid_modes = {
        "STRICT_ALIGN_CORNERS",
        "ALIGN_CORNERS",
        "UNALIGN_CORNERS",
        "DEFAULT",
        "OFFSET_CORNERS",
    }
    sampling_mode = self.sampling_mode.val
    if sampling_mode not in valid_modes:
        # Report the mode string itself rather than the Var wrapper.
        raise ValueError(
            '"crop_resize" op: unrecognized sampling mode "{}"'.format(
                sampling_mode))

    # ret_shape: [N] + [B, C, h_out, w_out]
    N, B, C = self.roi.shape[0], self.x.shape[0], self.x.shape[1]
    ret_shape = [N, B, C, self.target_height.val, self.target_width.val]
    return types.tensor(self.x.dtype, ret_shape)
def parse_tensor(t):
    """Convert a TensorFlow ``TensorProto`` ``t`` into a MIL type instance.

    The element type comes from ``parse_type(t.dtype)`` and the shape from
    ``parse_shape(t.tensor_shape)``. Values are taken from the first
    non-empty typed repeated field, checked in the order half, float,
    double, int, int64, bool, uint32, uint64; the last two only when the
    proto defines them. ``tensor_content`` is not read here.

    Returns:
        A scalar instance of the parsed type when the rank is known and 0
        (``.val`` is the first value), otherwise an instance of
        ``types.tensor(typ, shape)`` with ``.shape`` set. ``.val`` is set
        only when a value field was populated.
    """
    typ = parse_type(t.dtype)
    shape = parse_shape(t.tensor_shape)

    retval = None
    if len(t.half_val) > 0:
        np_dtype = _TF_TO_NP[t.dtype]
        if _np.dtype(np_dtype) == _np.float16:
            # TensorProto.half_val stores the raw IEEE fp16 bit pattern of each
            # element widened to int32. Reinterpret the bits; a numeric cast
            # would turn e.g. 1.0 (0x3C00) into 15360.0.
            retval = _np.array(t.half_val, dtype=_np.uint16).view(_np.float16)
        else:
            retval = _np.array(t.half_val, dtype=np_dtype)
    else:
        for field in ("float_val", "double_val", "int_val", "int64_val",
                      "bool_val", "uint32_val", "uint64_val"):
            # uint32_val / uint64_val are absent from older protos.
            values = getattr(t, field, None)
            if values is not None and len(values) > 0:
                retval = _np.array(values, dtype=_TF_TO_NP[t.dtype])
                break

    if not t.tensor_shape.unknown_rank and len(shape) == 0:
        retobj = typ()
        if retval is not None:
            retobj.val = retval[0]
    else:
        rettype = types.tensor(typ, tuple(shape))
        retobj = rettype()
        retobj.shape = shape
        if retval is not None:
            retobj.val = retval
    return retobj
def type_inference(self):
    """Infer the output type of a pad op.

    ``pad`` is a 1-D list of (before, after) pairs for the trailing
    ``len(pad) // 2`` dims of ``x``. With concrete pad values each padded
    dim grows by ``before + after``. With unknown pad values each padded dim
    becomes a fresh symbol. The same length validation applies in both cases.

    Raises:
        ValueError: if ``pad`` is not 1-D, the mode is unsupported, the pad
            length is odd, or it covers more dims than ``x`` has.
    """
    ret_shape = list(self.x.shape)
    pad = self.pad
    if len(pad.shape) != 1:
        raise ValueError("Pad should be a 1D tensor!")
    if self.mode and self.mode.val not in {'constant', 'reflect', 'replicate'}:
        raise ValueError("Pad mode should be one of {'constant', 'reflect', 'replicate'}")

    if pad.val is None:
        num_pad = pad.shape[0]
        if num_pad % 2 != 0:
            raise ValueError("Number of elements in the argument Pad must be divisible by 2.")
        num_dims = num_pad // 2
        if num_dims > len(ret_shape):
            raise ValueError("Number of dimensions specified through pad must less than or equal to rank of input x")
        # Pad amounts unknown: each padded trailing dim becomes a fresh symbol.
        for i in range(num_dims):
            ret_shape[-num_dims + i] = get_new_symbol()
    else:
        pad_vals = pad.val
        if len(pad_vals) % 2 != 0:
            raise ValueError("Number of elements in the argument Pad must be divisible by 2.")
        # reshape returns a new view, so the const value is never mutated.
        pairs = pad_vals.reshape(-1, 2)
        if pairs.shape[0] > len(ret_shape):
            raise ValueError("Number of dimensions specified through pad must less than or equal to rank of input x")
        num_dims = len(pairs)
        for i in range(num_dims):
            ret_shape[-num_dims + i] = ret_shape[-num_dims + i] + pairs[i][0] + pairs[i][1]

    return types.tensor(self.x.dtype, tuple(ret_shape))
def type_inference(self):
    """Infer the output type of an index-based strided slice of ``x``.

    Absent ``stride`` defaults to all 1s. Absent ``begin_mask``, ``end_mask``
    and ``squeeze_mask`` default to all False. The shape is computed by
    ``solve_slice_by_index_shape``. An empty result shape yields the bare
    scalar dtype.
    """
    def _val_or(var, default):
        # Optional inputs are None when not supplied.
        return default if var is None else var.val

    begin = self.begin.val
    end = self.end.val
    rank = self.x.rank

    ret_shape = solve_slice_by_index_shape(
        self.x.shape,
        begin,
        end,
        _val_or(self.stride, [1] * rank),
        _val_or(self.begin_mask, [False] * rank),
        _val_or(self.end_mask, [False] * rank),
        _val_or(self.squeeze_mask, [False] * rank),
    )

    if len(ret_shape) > 0:
        return types.tensor(self.x.dtype, tuple(ret_shape))
    # Every dim was squeezed away: scalar.
    return self.x.dtype
def type_inference(self):
    """Infer the list type produced by ``make_list``.

    ``dtype`` is a string converted with ``types.string_to_builtin``. Each
    ``elem_shape`` entry must be a const: ints are used as-is, and strings
    name a symbol, reusing an existing one or creating it.

    Raises:
        ValueError: on an unsupported dtype or a non-const elem_shape entry.
    """
    builtin_dtype = types.string_to_builtin(self.dtype.val)
    if builtin_dtype is None:
        raise ValueError("Unsupported dtype {}".format(self.dtype.val))

    def _resolve_dim(dim_var):
        # Map one elem_shape entry (int or symbol name) to a concrete dim.
        dim = dim_var.val
        if dim is None:
            msg = 'make_list elem_shape must be tuple of const. Tuple elem {} is not'
            raise ValueError(msg.format(dim_var.name))
        if not isinstance(dim, str):
            return dim
        try:
            return get_existing_symbol(dim)
        except ValueError:
            # Must be a new symbol
            return get_new_symbol(dim)

    elem_shape = [_resolve_dim(dim_var) for dim_var in self.elem_shape]
    elem_type = types.tensor(builtin_dtype, elem_shape)
    return types.list(
        elem_type,
        init_length=self.init_length.val,
        dynamic_length=self.dynamic_length.val,
    )
def type_inference(self):
    """Infer the output type of a pooling op.

    Spatial dims are ``x.shape[2:]``; ``strides`` default to 1 and
    ``pad_type`` defaults to ``valid``. When ``ceil_mode`` is set, pooling
    must be 1D/2D, ``pad_type`` must not be ``same``, and any custom padding
    must be symmetric. Output spatial dims come from
    ``spatial_dimensions_out_shape``.

    Raises:
        ValueError: on an unknown pad_type or a ceil_mode constraint violation.
    """
    x_shape = self.x.shape
    n_spatial = len(x_shape) - 2
    kernel = self.kernel_sizes.val
    strides = self.strides.val if self.strides is not None else [1] * n_spatial
    pad_type = self.pad_type.val.lower() if self.pad_type is not None else "valid"
    if pad_type not in ("valid", "same", "custom"):
        raise ValueError("Unrecognized value of pad_type : {}".format(pad_type))
    pad = self.pad.val if self.pad is not None else None
    ceil_mode = self.ceil_mode.val

    if ceil_mode:
        if n_spatial > 2:
            raise ValueError('pool: ceil_mode only supported for 1D or 2D pool')
        if pad_type == "same":
            raise ValueError("ceil_mode must be False when pad_type==same")
        if pad is not None and any(
                pad[2 * i] != pad[2 * i + 1] for i in range(n_spatial)):
            raise ValueError("Padding must be symmetric if ceil_mode is True")

    spatial_out = spatial_dimensions_out_shape(
        pad_type=pad_type,
        input_shape=x_shape[2:],
        kernel_shape=kernel,
        strides=strides,
        custom_pad=pad,
        ceil_mode=ceil_mode,
    )
    return types.tensor(self.x.dtype, tuple(list(x_shape[:2]) + spatial_out))
def type_inference(self):
    """Infer the output type of a pad op.

    ``pad`` is a 1-D list of (before, after) pairs for the trailing
    ``len(pad) // 2`` dims of ``x``. With concrete pad values each padded
    dim grows by ``before + after``. With unknown pad values each padded dim
    becomes a fresh symbol. The same length validation applies in both cases.

    Raises:
        ValueError: if ``pad`` is not 1-D, the mode is unsupported, the pad
            length is odd, or it covers more dims than ``x`` has.
    """
    ret_shape = list(self.x.shape)
    pad = self.pad
    if len(pad.shape) != 1:
        raise ValueError("Pad should be a 1D tensor!")
    if self.mode and self.mode.val not in {'constant', 'reflect', 'replicate'}:
        raise ValueError(
            "Pad mode should be one of {'constant', 'reflect', 'replicate'}"
        )

    if pad.val is None:
        num_pad = pad.shape[0]
        if num_pad % 2 != 0:
            raise ValueError("Number of elements in the argument Pad must be divisible by 2.")
        num_dims = num_pad // 2
        if num_dims > len(ret_shape):
            raise ValueError("Number of dimensions specified through pad must less than or equal to rank of input x")
        # Pad amounts unknown: each padded trailing dim becomes a fresh symbol.
        for i in range(num_dims):
            ret_shape[-num_dims + i] = get_new_symbol()
    else:
        pad_vals = pad.val
        # Validate before reshaping; an odd length would otherwise surface
        # as an opaque numpy reshape error.
        if len(pad_vals) % 2 != 0:
            raise ValueError("Number of elements in the argument Pad must be divisible by 2.")
        pairs = pad_vals.reshape(-1, 2)
        if pairs.shape[0] > len(ret_shape):
            raise ValueError("Number of dimensions specified through pad must less than or equal to rank of input x")
        num_dims = len(pairs)
        for i in range(num_dims):
            ret_shape[-num_dims + i] = ret_shape[-num_dims + i] + pairs[i][0] + pairs[i][1]

    return types.tensor(self.x.dtype, tuple(ret_shape))
def type_inference(self):
    """Infer the output type of ``l2_norm``: same dtype and shape as ``x``.

    Raises:
        ValueError: if ``x`` has rank below 3.
    """
    rank = self.x.rank
    if rank >= 3:
        return types.tensor(self.x.dtype, tuple(self.x.shape))
    raise ValueError(
        "Input rank of l2_norm must be at least 3. Got {}".format(rank))
def type_inference(self):
    """Infer the output type of a gather-along-axis op.

    ``x`` and ``indices`` must share a rank, and ``axis`` must lie in
    ``[-rank, rank)``. Every non-axis dim must match wherever both sides are
    concrete. The output has ``x.dtype`` and ``indices``' shape.

    Raises:
        ValueError: on a rank mismatch or a concrete non-axis dim mismatch.
        IndexError: if ``axis`` is out of bounds.
    """
    x_rank = self.x.rank
    if x_rank != self.indices.rank:
        # Adjacent literals: the old backslash continuation inside the string
        # leaked source indentation into the message.
        raise ValueError(
            "Rank mismatch between input and indices. "
            "Input rank: {}, indices rank: {}".format(x_rank, self.indices.rank)
        )

    axis = self.axis.val
    if axis < -x_rank or axis >= x_rank:
        raise IndexError(
            "Axis value {} is out of bounds for {} node {}".format(
                axis, self.op_type, self.name
            )
        )
    if axis < 0:
        axis += x_rank

    for i, (x_dim, idx_dim) in enumerate(zip(self.x.shape, self.indices.shape)):
        # Symbolic dims cannot be compared statically; skip them. Raise
        # instead of assert so the check survives `python -O`.
        if i == axis or is_symbolic(x_dim) or is_symbolic(idx_dim):
            continue
        if x_dim != idx_dim:
            raise ValueError(
                "Dimension mismatch between input and indices at dim {}: "
                "{} vs {}".format(i, x_dim, idx_dim)
            )

    return types.tensor(self.x.dtype, self.indices.shape)
def type_inference(self):
    """Infer the output type of flattening ``x`` to 2-D around ``axis``.

    The output shape is ``[prod(shape[:axis]), prod(shape[axis:])]`` using
    Python slicing semantics for ``axis``. Symbolic dims multiply
    symbolically.
    """
    import functools
    import operator

    shape = list(self.x.shape)
    axis = self.axis.val
    # Start the product from the int 1: np.prod([]) returns float 1.0, which
    # produced a float dimension when the slice was empty.
    dim_pre_axis = functools.reduce(operator.mul, shape[:axis], 1)
    dim_post_axis = functools.reduce(operator.mul, shape[axis:], 1)
    return types.tensor(self.x.dtype, (dim_pre_axis, dim_post_axis))
def type_inference(self):
    """Infer the output type of a one-hot encoding of ``indices``.

    A dimension of size ``one_hot_vector_size`` is inserted into
    ``indices``' shape at ``axis``. The output rank is ``rank + 1``, so a
    negative axis is taken relative to that. An unknown depth becomes a
    fresh symbol. The output dtype is that of ``on_value``.

    Raises:
        TypeError: if ``on_value`` and ``off_value`` dtypes differ.
        IndexError: if ``axis`` is outside ``[-rank - 1, rank]``.
        ValueError: if a concrete depth is negative.
    """
    on_type = self.on_value.dtype
    off_type = self.off_value.dtype
    if on_type != off_type:
        raise TypeError(
            "Parameters on_value and off_value must have same input types."
        )

    rank = self.indices.rank
    axis = self.axis.val
    if axis < -rank - 1 or axis > rank:
        raise IndexError(
            "Axis value {} is out of bounds for {} node {}".format(
                axis, self.op_type, self.name))

    depth_value = self.one_hot_vector_size.sym_val
    if depth_value is None:
        depth_value = get_new_symbol()
    elif not is_symbolic(depth_value) and depth_value < 0:
        # A symbolic depth cannot be range-checked statically.
        raise ValueError(
            "Parameter one_hot_vector_size must be non-negative")

    indices_shape = list(self.indices.shape)
    cut = axis + rank + 1 if axis < 0 else axis
    retshape = indices_shape[:cut] + [depth_value] + indices_shape[cut:]
    return types.tensor(on_type, retshape)
def type_inference(self):
    """Infer the output type of ``concat`` over ``self.values``.

    All inputs must share a rank and have promotable primitive types.
    Non-axis dims must match where both are concrete. The concat dim is the
    sum of the inputs' extents (rank-0 inputs count as 1), or a fresh symbol
    if any extent is symbolic. Concatenating rank-0 inputs yields a 1-D
    tensor.

    Raises:
        ValueError: on no inputs, a rank mismatch, an out-of-range axis,
            incompatible dtypes, or a non-axis dim mismatch.
    """
    values = self.values
    if len(values) == 0:
        raise ValueError("Concat {} got 0 values".format(self.name))

    # Validate values have the same rank
    rank = values[0].rank
    for v in values:
        if v.rank != rank:
            msg = "Input {} has rank {} != other inputs rank {}"
            raise ValueError(msg.format(v.name, v.rank, rank))

    # Normalize the axis; valid range is [-rank, rank).
    axis = self.axis.val
    if axis < 0:
        axis += rank
    if rank > 0 and not 0 <= axis < rank:
        msg = "In {} of op_type {}: axis out of bound for input (rank {})"
        raise ValueError(msg.format(self.name, self.op_type, rank))

    # Validate primitive types are compatible
    dtype = values[0].dtype
    for v in values[1:]:
        promoted = promoted_primitive_type(v.dtype, dtype)
        if promoted is None:
            msg = "Incompatible primitive types concat: {} vs {}"
            raise ValueError(msg.format(v.dtype, dtype))
        dtype = promoted

    # validate that non-axis dimensions match
    ret_shape = list(values[0].shape)
    for v in values[1:]:
        for i, (ref_dim, dim) in enumerate(zip(ret_shape, v.shape)):
            if i == axis or is_symbolic(ref_dim) or is_symbolic(dim):
                continue
            if ref_dim != dim:
                msg = 'Dimension mismatch in {} ("{}"): shapes {} vs. {}'
                raise ValueError(
                    msg.format(self.op_type, self.name, ret_shape, v.shape))

    # Length of the concat dim.
    total = 0
    for v in values:
        extent = v.shape[axis] if len(v.shape) > 0 else 1
        if is_symbolic(extent):
            total = get_new_symbol()
            break
        total += extent

    if ret_shape:
        ret_shape[axis] = total
    else:
        ret_shape = [total]
    return types.tensor(dtype, ret_shape)
def type_inference(self):
    """Infer the output types of ``split``.

    Returns one tensor type per split (count from
    ``self._get_num_splits_and_sizes()``). Each copies ``x``'s shape with
    the ``axis`` dim set to the corresponding split size.
    """
    num_splits, sizes = self._get_num_splits_and_sizes()
    base_shape = list(self.x.shape)
    axis = self.axis.val

    shapes = [list(base_shape) for _ in range(num_splits)]
    for idx, size in enumerate(sizes):
        shapes[idx][axis] = size

    return tuple(types.tensor(self.x.dtype, shape) for shape in shapes)
def type_inference(self):
    """Infer the output type of the illustrative custom matmul.

    The result is ``[x.shape[0], y.shape[1]]`` in ``x.dtype``. Valid input
    shapes are assumed, and the transpose / sparsity parameters are not
    consulted. A complete implementation would take them into account.
    """
    out_rows = self.x.shape[0]
    out_cols = self.y.shape[1]
    return types.tensor(self.x.dtype, [out_rows, out_cols])
def type_inference(self):
    """Infer the output type of ``torch_upsample_nearest_neighbor``.

    Leading dims are copied from ``x``; the last two dims become fresh
    symbols because the output size is not derived here.

    Raises:
        ValueError: if ``x`` has rank below 3.
    """
    if self.x.rank < 3:
        raise ValueError(
            'input to the "torch_upsample_nearest_neighbor" op must have rank at least 3'
        )
    # Create the last-dim symbol first, then the second-to-last, to keep the
    # symbol allocation order unchanged.
    last_dim = get_new_symbol()
    second_last_dim = get_new_symbol()
    out_shape = list(self.x.shape)[:-2] + [second_last_dim, last_dim]
    return types.tensor(self.x.dtype, out_shape)