class topk(_topk_iOS15):
    r"""
    iOS16 version of ``topk``.

    Extends the iOS15 op with two optional const inputs.

    Additional Parameters
    ---------------------
    sort: const<bool> (Optional)
        * Defaults to ``True``.
        * If ``True``, the top-k elements are themselves sorted.
          Otherwise, no particular ordering is guaranteed.
    return_indices: const<bool> (Optional)
        * Defaults to ``True``.
        * If ``True``, returns both values and indices. Otherwise, returns
          only the top-k values.

    Returns
    -------
    tensor<\*?, T>
        * Values of the top/bottom ``k`` elements.
    tensor<\*?, int32>
        * Only returned when ``return_indices = True``.
        * Indices of the top/bottom ``k`` elements along ``axis``.

    Attributes
    ----------
    T: fp32, int32
    """

    input_spec = _topk_iOS15.input_spec + InputSpec(
        sort=BoolInputType(const=True, optional=True),
        return_indices=BoolInputType(const=True, optional=True),
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def default_inputs(self):
        # Layer the two new defaults on top of the inherited iOS15 defaults.
        new_defaults = DefaultInputs(sort=True, return_indices=True)
        return super().default_inputs() + new_defaults

    def type_inference(self):
        values_type, indices_type = super().type_inference()
        # The indices output is dropped entirely when it was not requested.
        return (values_type, indices_type) if self.return_indices.val else values_type

    @precondition(allow=VALUE)
    def value_inference(self):
        values, indices = super().value_inference()
        return (values, indices) if self.return_indices.val else values
class tf_make_list(Operation):
    """
    Create a list whose element shape may be unknown at construction time.

    If ``elem_shape`` is absent or has no (symbolic) value, the resulting list
    type has ``types.unknown`` elements; otherwise elements are tensors of
    ``dtype`` with shape ``elem_shape``.

    Raises
    ------
    ValueError
        If ``elem_shape`` is known and ``dtype`` is not a recognized dtype string.
    """

    input_spec = InputSpec(
        init_length=IntInputType(optional=True),
        dynamic_length=BoolInputType(optional=True),
        elem_shape=TensorInputType(const=True, optional=True),
        dtype=StringInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            init_length=1,
            dynamic_length=True,
            dtype="fp32",
        )

    def __init__(self, **kwargs):
        super(tf_make_list, self).__init__(**kwargs)

    def type_inference(self):
        length = self.init_length.val
        shape_var = self.elem_shape

        if shape_var is None or shape_var.sym_val is None:
            # Element shape not known here: produce an untyped list.
            return types.list(
                types.unknown,
                init_length=length,
                dynamic_length=self.dynamic_length.val,
            )

        elem_dtype = types.string_to_builtin(self.dtype.val)
        if elem_dtype is None:
            raise ValueError("Unsupported dtype {}".format(self.dtype.val))

        return types.list(
            types.tensor(elem_dtype, shape_var.sym_val),
            init_length=length,
            dynamic_length=self.dynamic_length.val,
        )
class Pooling(Operation):
    """
    Shared base class for the pooling ops.

    Provides the common input spec, defaults, and output-shape inference.
    The output keeps the first two dimensions of ``x`` and replaces the
    spatial dimensions with the pooled sizes.

    Raises (during type inference)
    ------------------------------
    ValueError
        * ``pad_type`` is not ``valid``, ``same`` or ``custom``.
        * ``ceil_mode`` is True and there are more than 2 spatial dims.
        * ``ceil_mode`` is True and ``pad_type`` is ``same``.
        * ``ceil_mode`` is True and ``pad`` is asymmetric.
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        kernel_sizes=IntTensorInputType(const=True),
        strides=IntTensorInputType(const=True, optional=True),
        pad_type=StringInputType(const=True),
        pad=IntTensorInputType(const=True, optional=True),
        ceil_mode=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        spatial_rank = self.x.rank - 2
        return DefaultInputs(
            strides=[1] * spatial_rank,
            pad=[0] * (2 * spatial_rank),
            ceil_mode=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        kernel = self.kernel_sizes.val
        in_shape = self.x.shape
        spatial_rank = len(in_shape) - 2

        if self.strides is None:
            stride_vals = [1] * spatial_rank
        else:
            stride_vals = self.strides.val

        padding_mode = "valid" if self.pad_type is None else self.pad_type.val.lower()
        if padding_mode not in ("valid", "same", "custom"):
            raise ValueError("Unrecognized value of pad_type : {}".format(padding_mode))

        pad_vals = self.pad.val if self.pad is not None else None
        use_ceil = self.ceil_mode.val

        if use_ceil:
            # ceil_mode is restricted to 1-D/2-D pooling, non-"same" padding,
            # and symmetric explicit padding.
            if spatial_rank > 2:
                raise ValueError('pool: ceil_mode only supported for 1D or 2D pool')
            if padding_mode == "same":
                raise ValueError("ceil_mode must be False when pad_type==same")
            if pad_vals is not None and any(
                pad_vals[2 * i] != pad_vals[2 * i + 1] for i in range(spatial_rank)
            ):
                raise ValueError("Padding must be symmetric if ceil_mode is True")

        out_spatial = spatial_dimensions_out_shape(
            pad_type=padding_mode,
            input_shape=in_shape[2:],
            kernel_shape=kernel,
            strides=stride_vals,
            custom_pad=pad_vals,
            ceil_mode=use_ceil,
        )
        return types.tensor(self.x.dtype, tuple(list(in_shape[:2]) + out_spatial))
class custom_torch_sparse_matmul(Operation):
    """
    Custom sparse matrix-multiplication op, bound to the ``SparseMatMul``
    custom layer through ``bindings``.

    Output shape is inferred as for ``op(x) @ op(y)``. Here ``op`` transposes
    its operand when the corresponding ``transpose_*`` flag is set.
    Assumes both ``x`` and ``y`` are rank-2.
    """

    # Defining input spec for current op
    input_spec = InputSpec(
        x=TensorInputType(),
        y=TensorInputType(),
        transpose_x=BoolInputType(const=True, optional=True),
        transpose_y=BoolInputType(const=True, optional=True),
        x_is_sparse=BoolInputType(const=True, optional=True),
        y_is_sparse=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            transpose_x=False,
            transpose_y=False,
            x_is_sparse=False,
            y_is_sparse=False,
        )

    # Specifying binding for custom op for specifying inputs,
    # parameters required for creating custom op to be synced with Swift API
    bindings = {
        "class_name": "SparseMatMul",
        "input_order": ["x", "y"],
        "parameters": ["transpose_x", "transpose_y", "x_is_sparse", "y_is_sparse"],
        "description": "Custom Sparse MatMul Layer",
    }

    def __init__(self, **kwargs):
        # Zero-argument super() works whether or not this class is nested in a
        # test class, unlike a hard-coded ``TestCustomOp.<name>`` reference.
        super().__init__(**kwargs)

    def type_inference(self):
        x_type = self.x.dtype
        x_shape = self.x.shape
        y_shape = self.y.shape
        # Transposing a rank-2 operand swaps its two dims. The ?_is_sparse flags
        # presumably affect only the storage format, not the output shape.
        out_rows = x_shape[1] if self.transpose_x.val else x_shape[0]
        out_cols = y_shape[0] if self.transpose_y.val else y_shape[1]
        return types.tensor(x_type, [out_rows, out_cols])
class torch_upsample_bilinear(Operation):
    r"""
    Upsample the spatial dimensions (last two dimensions) of the input
    using bilinear interpolation.

    It corresponds to the ``torch.nn.functional.interpolate`` function with
    ``mode=bilinear``, ``recompute_scale_factor=True``, and input with
    flexible shape.

    source: https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#interpolate

    Parameters
    ----------
    x: tensor<[\*D, H1, W1],T> (Required)
        * Must be at least rank ``3``.
    output_height: i32 or fp32
        * Output height for the height dimension.
    output_width: i32 or fp32
        * Output width for the width dimension.
    align_corners: const<bool> (Optional)
        * Defaults to ``True``.
        * The ``align_corners`` parameter for the original torch op.

    Returns
    -------
    tensor<[\*D, H2, W2],T>
        * Tensor with same type as the input.
        * ``H2`` = output_height
        * ``W2`` = output_width
        * During type inference, ``H2`` and ``W2`` are new symbols because
          the output sizes need not be compile-time constants.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        output_height=IntOrFloatInputType(),
        output_width=IntOrFloatInputType(),
        align_corners=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            align_corners=True,
        )

    def __init__(self, **kwargs):
        super(torch_upsample_bilinear, self).__init__(**kwargs)

    def type_inference(self):
        if self.x.rank < 3:
            raise ValueError(
                'input to the "torch_upsample_bilinear" op must have rank at least 3'
            )
        ret_shape = list(self.x.shape)
        # Width symbol is created first, then height (kept in this order so
        # symbol allocation is unchanged).
        ret_shape[-1] = get_new_symbol()
        ret_shape[-2] = get_new_symbol()
        return types.tensor(self.x.dtype, ret_shape)
class ReductionAxes(Operation):
    """
    Base class for reduction ops that reduce ``x`` over a set of ``axes``.

    Subclasses provide the NumPy reduction through ``get_operator``.

    Parameters
    ----------
    x: tensor (Required)
        * Input tensor.
    axes: const tensor<[K], i32> (Optional)
        * Axes to reduce; negative values count from the end.
        * If omitted, every axis is reduced.
    keep_dims: const<bool> (Optional)
        * Defaults to ``False``.
        * If ``True``, reduced axes are kept with size 1; otherwise they are removed.

    Returns
    -------
    tensor, or a scalar of ``x``'s dtype when every axis is removed.

    Raises (during type inference)
    ------------------------------
    ValueError
        If an axis is outside ``[-rank, rank)``.
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        axes=IntTensorInputType(const=True, optional=True),
        keep_dims=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            axes=None,
            keep_dims=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        x_type = self.x.dtype
        rank = self.x.rank
        axes = self.axes.val if self.axes is not None else None
        if axes is None:
            axes = range(rank)

        # Validate, normalize negative axes, and drop duplicates. Without this,
        # an axis given twice (e.g. [1, -2] on a rank-3 input) would pop an
        # unrelated dimension, and an out-of-range negative axis would wrap.
        normalized_axes = set()
        for axis in axes:
            axis = int(axis)
            if not -rank <= axis < rank:
                raise ValueError(
                    "axis {} is out of range for input of rank {}".format(axis, rank)
                )
            normalized_axes.add(axis + rank if axis < 0 else axis)

        reduced_shape = list(self.x.shape)
        if self.keep_dims.val:
            for axis in normalized_axes:
                reduced_shape[axis] = 1
        else:
            # Delete back to front so the remaining indices stay valid.
            for axis in sorted(normalized_axes, reverse=True):
                reduced_shape.pop(axis)

        if not reduced_shape:
            return x_type  # scalar
        return types.tensor(x_type, tuple(reduced_shape))

    @precondition(allow=VALUE)
    def value_inference(self):
        axes = tuple(self.axes.val) if self.axes is not None else None
        return self.get_operator()(self.x.val, axis=axes, keepdims=self.keep_dims.val)

    def get_operator(self):
        """Return the NumPy reduction callable; must be overridden by subclasses."""
        raise NotImplementedError()
class argsort(Operation):
    r"""
    Return the indices that sort the input tensor along a given axis.

    Parameters
    ----------
    x: <\*?, T> (Required)
        * Input tensor.
    axis: const<i32> (Optional)
        * Defaults to ``-1`` (the last dimension).
        * Axis along which to sort.
    ascending: const<bool> (Optional)
        * Defaults to ``False``, which sorts in descending order.
        * ``True`` sorts in ascending order.

    Returns
    -------
    tensor<\*?, int32>
        * Indices of the sorted values; same shape as ``x``.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        axis=IntInputType(const=True, optional=True),
        ascending=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            axis=-1,
            ascending=False,
        )

    def __init__(self, **kwargs):
        super(argsort, self).__init__(**kwargs)

    def type_inference(self):
        # Output has the input's shape but always holds int32 indices.
        return types.tensor(types.int32, self.x.shape)

    @precondition(allow=VALUE)
    def value_inference(self):
        # np.argsort sorts ascending only, so negate the data to obtain the
        # descending order that is MIL's default.
        data = self.x.val if self.ascending.val else -self.x.val
        return np.argsort(data, axis=self.axis.val)
class resample(_resample_iOS15):
    """
    iOS16 version of ``resample``.

    In this version, ``coordinates`` accepts int32, float32, or float16 values.
    Type inference is delegated unchanged to the iOS15 op.
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        coordinates=ScalarOrTensorInputType(
            type_domain=(np.int32, np.float32, np.float16)
        ),
        sampling_mode=StringInputType(const=True),
        padding_mode=StringInputType(const=True),
        padding_value=FloatInputType(const=True),
        coordinates_mode=StringInputType(const=True),
        align_corners=BoolInputType(const=True),
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        # Output type rules are inherited from the iOS15 op.
        return super().type_inference()
class ReductionAxis(Operation):
    """
    Base class for reduction ops that reduce ``x`` along a single ``axis``.

    Subclasses provide the NumPy reduction through ``get_operator``.

    Parameters
    ----------
    x: tensor (Required)
        * Input tensor.
    axis: const<i32> (Optional)
        * Defaults to ``-1``; negative values count from the end.
    keep_dims: const<bool> (Optional)
        * Defaults to ``False``.
        * If ``True``, the reduced axis is kept with size 1; otherwise it is removed.

    Raises (during type inference)
    ------------------------------
    ValueError
        If ``axis`` is outside ``[-rank, rank)``.
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        axis=IntInputType(const=True, optional=True),
        keep_dims=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            axis=-1,
            keep_dims=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _find_reduced_shape(self):
        """Return ``x.shape`` as a list with ``axis`` removed (or set to 1 if keep_dims)."""
        reduced_shape = list(self.x.shape)
        rank = len(reduced_shape)
        axis = self.axis.val
        if not -rank <= axis < rank:
            raise ValueError(
                "axis {} is out of range for input of rank {}".format(axis, rank)
            )
        axis = axis if axis >= 0 else axis + rank
        if self.keep_dims.val:
            reduced_shape[axis] = 1
        else:
            reduced_shape.pop(axis)
        return reduced_shape

    def type_inference(self):
        x_type = self.x.dtype
        # NOTE(review): output dtype mirrors x. If a subclass's operator returns
        # indices (argmax-style), int32 may be intended here -- confirm.
        reduced_shape = self._find_reduced_shape()
        return types.tensor(x_type, tuple(reduced_shape))

    @precondition(allow=VALUE)
    def value_inference(self):
        tmp = self.get_operator()(self.x.val, axis=self.axis.val)
        reduced_shape = self._find_reduced_shape()
        if self.keep_dims.val:
            # The operator is called without keepdims, so restore the size-1 axis.
            tmp = np.reshape(tmp, reduced_shape)
        return tmp

    def get_operator(self):
        """Return the NumPy reduction callable; must be overridden by subclasses."""
        raise NotImplementedError()
class custom_topk(Operation):
    """
    Custom top-k op bound to the ``TopK`` custom layer through ``bindings``.

    Declares two outputs:

    * a tensor of ``x``'s dtype;
    * an int32 tensor.

    Both have ``x``'s shape with ``axis`` resized to ``k``.

    Raises (during type inference)
    ------------------------------
    ValueError
        If ``k`` exceeds a non-symbolic size of ``axis``.
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        k=IntInputType(const=True, optional=True),
        axis=IntInputType(const=True, optional=True),
        sorted=BoolInputType(const=True, optional=True),
    )

    bindings = {
        "class_name": "TopK",
        "input_order": ["x"],
        "parameters": ["k", "axis", "sorted"],
        "description": "Top K Custom layer",
    }

    def default_inputs(self):
        return DefaultInputs(
            k=1,
            axis=-1,
            sorted=False,
        )

    def __init__(self, **kwargs):
        super(custom_topk, self).__init__(**kwargs)

    def type_inference(self):
        k = self.k.val
        axis = self.axis.val
        in_shape = self.x.shape
        axis_len = in_shape[axis]

        # Only a concrete axis size can be checked against k.
        if not is_symbolic(axis_len) and k > axis_len:
            msg = "K={} is greater than size of the given axis={}"
            raise ValueError(msg.format(k, axis))

        out_shape = list(in_shape)
        out_shape[axis] = k
        return types.tensor(self.x.dtype, out_shape), types.tensor(types.int32, out_shape)
class TfLSTMBase(Operation):
    """
    Common LSTM inputs shared by BlockLSTMCell and BlockLSTM.

    Defaults: ``forget_bias=1.0`` and ``use_peephole=False``.
    """

    input_spec = InputSpec(
        c_prev=TensorInputType(),  # [batch, hidden_dim]
        h_prev=TensorInputType(),  # [batch, hidden_dim]
        # weight: [input_dim + hidden_dim, 4*hidden_dim] (icfo layout)
        weight=TensorInputType(const=True),
        forget_bias=FloatInputType(const=True, optional=True),
        # cell_clip == None implies not using cell clip
        cell_clip=FloatInputType(const=True, optional=True),
        # If use_peephole == False, weight_peep_* is ignored
        use_peephole=BoolInputType(const=True, optional=True),
        weight_peep_i=TensorInputType(const=True, optional=True),  # [hidden_dim,]
        weight_peep_f=TensorInputType(const=True, optional=True),  # [hidden_dim,]
        weight_peep_o=TensorInputType(const=True, optional=True),  # [hidden_dim,]
        bias=TensorInputType(const=True),  # [4*hidden_dim] (icfo layout)
    )

    def default_inputs(self):
        return DefaultInputs(
            forget_bias=1.,
            use_peephole=False,
        )

    def _check_peephole_weights(self):
        """Raise ValueError if peepholes are enabled but a weight_peep_* input is missing."""
        if not self.use_peephole.val:
            return
        peephole_names = ("weight_peep_i", "weight_peep_f", "weight_peep_o")
        # Generator keeps the original short-circuit order (i, then f, then o).
        if any(getattr(self, name) is None for name in peephole_names):
            raise ValueError(
                "weight_peep_* cannot be None when use_peephole is True")
class cond(Operation):
    """
    Perform a conditional execution.

    The return types must be identical between the true and false branches.

    Parameters
    ----------
    pred: tensor<[], bool> (Required)
        * 0-D tensor (scalar) predicate to switch between true and false branches.
    _true_fn: function (Required)
        * A Python function that executes if ``pred`` evaluates to ``True``.
        * It must take zero inputs and return one or more values whose types
          become the operation's return types.
    _false_fn: function (Required)
        * A Python function that executes if ``pred`` evaluates to ``False``.
        * It must take zero inputs, and its return types must match those of
          the true branch.

    Returns
    -------
    tuple
        * Python tuple of ``Variables`` from one of the branches.

    Raises (during type inference)
    ------------------------------
    ValueError
        If the branches return a different number of outputs, or
        incompatible output types.
    """

    input_spec = InputSpec(
        pred=BoolInputType(),
        _true_fn=PyFunctionInputType(),
        _false_fn=PyFunctionInputType(),
    )

    def __init__(self, **kwargs):
        super(cond, self).__init__(**kwargs)

    def _build_branch_block(self, block_name, fn_var):
        """Trace the zero-arg function ``fn_var.val`` into a new nested Block and return it."""
        with Block(name=block_name, outer_op=self) as block:
            ret_vars = fn_var.val()
            # Normalize the return value (single var, tuple, or list) to a list.
            if isinstance(ret_vars, tuple):
                ret_vars = list(ret_vars)
            if not isinstance(ret_vars, list):
                ret_vars = [ret_vars]
            block.set_outputs(ret_vars)
        return block

    def build_nested_blocks(self):
        # blocks[0] is the true branch and blocks[1] the false branch;
        # type_inference and value_inference rely on this order.
        self.blocks.append(self._build_branch_block(self.name + "_true", self._true_fn))
        self.blocks.append(self._build_branch_block(self.name + "_false", self._false_fn))

    def type_inference(self):
        true_ret_vars = self.blocks[0].outputs
        false_ret_vars = self.blocks[1].outputs
        # zip() would silently ignore extra outputs, so check the counts explicitly.
        if len(true_ret_vars) != len(false_ret_vars):
            raise ValueError(
                "true branch returns {} outputs but false branch returns {}".format(
                    len(true_ret_vars), len(false_ret_vars)))
        # Verify true_ret_vars has the same types as false_ret_vars
        for vt, vf in zip(true_ret_vars, false_ret_vars):
            if not is_compatible_type(vt.sym_type, vf.sym_type):
                msg = ("true branch output {} type {} mismatch false branch" +
                       " output type {}")
                raise ValueError(
                    msg.format(vt.name, vt.sym_type.__type_info__(),
                               vf.sym_type.__type_info__()))

        return tuple(v.sym_type for v in true_ret_vars)

    def value_inference(self):
        if self.pred.val is None:
            raise NotImplementedError()
        if self.pred.val:
            return [v.val for v in self.blocks[0].outputs]
        return [v.val for v in self.blocks[1].outputs]
class lstm(Operation):
    r"""
    Single long short-term memory (LSTM) sequence.

    .. math::
       i_t = \rm{recurrent\_activation}(W_{ii} x_t + B_{ii} + W_{hi} h_{t-1} + B_{hi})

    .. math::
       f_t = \rm{recurrent\_activation}(W_{if} x_t + B_{if} + W_{hf} h_{t-1} + B_{hf})

    .. math::
       z_t = \rm{cell\_activation}(W_{iz} x_t + B_{iz} + W_{hz} h_{t-1} + B_{hz})

    .. math::
       o_t = \rm{recurrent\_activation}(W_{io} x_t + B_{io} + W_{ho} h_{t-1} + B_{ho})

    .. math::
       c_t = f_t * c_{t-1} + i_t * z_t

    .. math::
       h_t = o_t * \rm{activation}(c_t)

    Where:

    * ``i_t``, ``f_t``, ``o_t``, and ``z_t`` are the input, forget, output,
      and cell gates, respectively, at time ``t``.
    * ``c_t`` is the cell state at time ``t``.
    * ``h_t`` is the hidden state at time ``t``.
    * ``W_{ii}``, ``W_{if}``, ``W_{io}``, and ``W_{iz}`` are input weights for
      the input, forget, output, and cell gates, respectively.
    * ``W_{hi}``, ``W_{hf}``, ``W_{ho}``, and ``W_{hz}`` are recurrent weights
      for the input, forget, output, and cell gates, respectively.

    Parameters
    ----------
    x: <s, b, I, T> (Required)
        * ``s`` is the sequence length, ``b`` is the batch size, and ``I`` is
          the input dimension.

    initial_h: <b, DIRECTION*H, T> (Required)
        * Initial hidden state. ``DIRECTION = 1`` for uni-directional, ``2``
          for bi-directional LSTM.
        * ``H`` denotes hidden size.
        * ``[b, :H]`` and ``[b, H:]`` represent the forward and reverse
          direction values, respectively.

    initial_c: <b, DIRECTION*H, T> (Required)
        * Initial cell state.
        * Format is the same as ``initial_h``.

    weight_ih: const<4*H, I, T> (Required)
        * Input-hidden weight matrix.
        * Weight tensor should be in order of
          ``[input_gate, forget_gate, output_gate, cell_gate]``.
        * If direction=="bidirectional", this is applied in the forward direction.
        * If direction=="forward" or "reverse", these weights are used.

    weight_hh: const<4*H, H, T> (Required)
        * Hidden-hidden weight matrix.
        * Weight tensor should be in order of
          ``[input_gate, forget_gate, output_gate, cell_gate]``.
        * If direction=="bidirectional", this is applied in the forward direction.
        * If direction=="forward" or "reverse", these weights are used.

    bias: const<4*H, T> (Optional) [Default all 0s]
        * bias = input-hidden bias + hidden-hidden bias
        * If direction=="bidirectional", this is applied in the forward direction.
        * If direction=="forward" or "reverse", this bias is used.

    peephole: const<3*H, T> (Optional, default to 0)
        * Weight tensor for peephole.
        * Order is ``[input_gate, forget_gate, output_gate]``.
        * Shape of each peephole vector is ``(H,)`` (``H`` is hidden size).
        * If direction=="bidirectional", this is applied in the forward direction.
        * If direction=="forward" or "reverse", these weights are used.

    weight_ih_back: const<4*H, I, T> (Optional)
        * Input-hidden weight matrix for the backward direction of a
          bidirectional LSTM.
        * Weight tensor should be in order of
          ``[input_gate, forget_gate, output_gate, cell_gate]``.
        * Must be provided for a bidirectional LSTM.
        * This is only used when ``direction`` is "bidirectional".
        * For direction="reverse", use ``weight_ih`` instead.

    weight_hh_back: const<4*H, H, T> (Optional)
        * Hidden-hidden weight matrix for the backward direction of a
          bidirectional LSTM.
        * Weight tensor should be in order of
          ``[input_gate, forget_gate, output_gate, cell_gate]``.
        * Must be provided for a bidirectional LSTM.
        * This is only used when ``direction`` is "bidirectional".
        * For direction="reverse", use ``weight_hh`` instead.

    bias_back: const<4*H, T> (Optional) [Default all 0s]
        * bias = input-hidden bias + hidden-hidden bias.
        * Bias of the backward direction for a bidirectional LSTM.
        * This is only used when ``direction`` is "bidirectional".
        * For direction="reverse", use ``bias`` instead.

    peephole_back: const<3*H, T> (Optional, default to 0)
        * Weight tensor for peephole in the backward direction of a
          bidirectional LSTM.
        * Order is ``[input_gate, forget_gate, output_gate]``.
        * Shape of each peephole vector is ``(H,)`` (``H`` is hidden size).
        * This is only used when ``direction`` is "bidirectional".
        * For direction="reverse", use ``peephole`` instead.

    direction: const<str> (Optional) [Default=forward]
        * One of the following: ``forward``, ``reverse``, or ``bidirectional``.
        * Must match ``DIRECTION`` in the initial states and weight parameters.

    output_sequence: const<bool> (Optional) [Default=False]
        * Outputs every step if ``True``.

    recurrent_activation: const<str> (Optional) [Default=sigmoid]
        * Activation applied on the input, forget, and output gates.

    cell_activation: const<str> (Optional) [Default=tanh]
        * Activation applied on the cell gate.

    activation: const<str> (Optional) [Default=tanh]
        * Activation applied to the cell state ``c_t`` when computing ``h_t``.

    clip: const<fp32> (optional) [Default=None]
        * Cell gate is clipped to ``[-clip, +clip]``.

    Returns
    -------
    <s, b, DIRECTION*H, T> or <1, b, DIRECTION*H, T>
        * If ``output_sequence == True`` (hidden states from every step):
          ``<s, b, DIRECTION*H, T>``.
        * Else ``<1, b, DIRECTION*H, T>`` (hidden states of the final step).
    <b, DIRECTION*H, T>
        * Hidden states of the final step.
    <b, DIRECTION*H, T>
        * Memory state of the final step.

    Raises (during type inference)
    ------------------------------
    ValueError
        * ``x`` is not rank 3.
        * ``direction`` is not supported.
        * A weight matrix is not rank 2 or not shaped ``<4*H, *>``.
        * Back weights are missing for a bidirectional LSTM.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        initial_h=TensorInputType(),
        initial_c=TensorInputType(),
        weight_ih=TensorInputType(const=True),  # ifoz layout,
        weight_hh=TensorInputType(const=True),  # ifoz layout
        bias=TensorInputType(const=True, optional=True),  # ifoz layout
        peephole=TensorInputType(const=True, optional=True),  # ifo layout
        weight_ih_back=TensorInputType(const=True, optional=True),  # ifoz layout,
        weight_hh_back=TensorInputType(const=True, optional=True),  # ifoz layout
        bias_back=TensorInputType(const=True, optional=True),  # ifoz layout
        peephole_back=TensorInputType(const=True, optional=True),  # ifo layout
        direction=StringInputType(const=True, optional=True),
        output_sequence=BoolInputType(const=True, optional=True),
        recurrent_activation=StringInputType(const=True, optional=True),
        cell_activation=StringInputType(const=True, optional=True),
        activation=StringInputType(const=True, optional=True),
        clip=FloatInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            bias=None,
            direction="forward",
            output_sequence=False,
            recurrent_activation="sigmoid",
            cell_activation="tanh",
            activation="tanh",
            peephole=None,
            clip=None)

    def __init__(self, **kwargs):
        super(lstm, self).__init__(**kwargs)

    def type_inference(self):
        if self.x.rank != 3:
            raise ValueError(
                "Invalid input shape. Expecting Rank 3 input, got {}".format(
                    self.x.rank))
        sequence_length, batch_size, _ = self.x.shape

        def weight_shape_check(wt_ih, wt_hh):
            """Ensure both weights are rank 2 and stack 4 gates of size H along dim 0."""
            if wt_ih.rank != 2 or wt_hh.rank != 2:
                raise ValueError(
                    "Expecting Rank 2 input, got weight_ih rank: {}, weight_hh rank: {}".format(
                        wt_ih.rank, wt_hh.rank))
            hidden_size = wt_hh.shape[1]
            # Exact comparison: floor division would accept e.g. 4*H + 1 rows.
            if wt_hh.shape[0] != 4 * hidden_size or wt_ih.shape[0] != 4 * hidden_size:
                raise ValueError(
                    "Incorrect weight matrix: hidden dim size mismatch. "
                    "Provided weight_ih {}, weight_hh {}. Expecting <4*H, H>".format(
                        wt_ih.shape, wt_hh.shape))

        direction = self.direction.val
        valid_directions = {"forward", "reverse", "bidirectional"}
        if direction not in valid_directions:
            raise ValueError(
                "Direction {} not supported. Supported directions: {}".format(
                    direction, valid_directions))

        weight_shape_check(self.weight_ih, self.weight_hh)
        if direction == "bidirectional":
            if self.weight_ih_back is None or self.weight_hh_back is None:
                raise ValueError(
                    "weight_ih_back and weight_hh_back must be provided "
                    "when direction is 'bidirectional'")
            weight_shape_check(self.weight_ih_back, self.weight_hh_back)

        hidden_size = self.weight_hh.shape[1]
        num_directions = 2 if direction == "bidirectional" else 1
        out_seq_len = sequence_length if self.output_sequence.val else 1

        output_shape = [out_seq_len, batch_size, num_directions * hidden_size]
        output_h_shape = [batch_size, num_directions * hidden_size]
        output_c_shape = [batch_size, num_directions * hidden_size]
        return (
            types.tensor(self.x.dtype, tuple(output_shape)),
            types.tensor(self.x.dtype, tuple(output_h_shape)),
            types.tensor(self.x.dtype, tuple(output_c_shape)),
        )
class non_maximum_suppression(Operation):
    """
    Apply non-maximum suppression (NMS) to box coordinates based on their
    intersection-over-union (IoU).

    NMS selects a subset of bounding boxes in descending order of score. It
    removes boxes that have a high IoU overlap with previously selected boxes.

    Parameters
    ----------
    boxes: tensor<[n, B, 4], T> (Required)
        * Box coordinates on which to perform NMS.
    scores: tensor<[n, B, K], T> (Required)
        * Scores for each of the boxes.
    iou_threshold: const<T> (Required)
        * The IoU threshold above which boxes are suppressed. NMS removes all
          overlapping boxes with ``IoU > iou_threshold``.
    score_threshold: const<T> (Required)
        * Before IoU suppression is performed, boxes with class scores below
          this threshold are rejected.
    max_boxes: const<i32> (Required)
        * Maximum number of boxes to select. If fewer boxes survive, the
          output is padded up to this number.
    per_class_suppression: const<bool> (Optional)
        * Defaults to ``False``.
        * If ``True``, suppression is performed independently within the
          boxes of each class.

    Returns
    -------
    tensor<[n, max_boxes, 4], T>
        * Coordinates of selected boxes.
    tensor<[n, max_boxes, K], T>
        * Scores of selected boxes.
    tensor<[n, max_boxes], i32>
        * Indices of selected boxes.
    tensor<[n], i32>
        * Number of boxes selected for each batch.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        boxes=TensorInputType(),
        scores=TensorInputType(),
        iou_threshold=FloatInputType(const=True),
        score_threshold=FloatInputType(const=True),
        max_boxes=IntInputType(const=True),
        per_class_suppression=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(per_class_suppression=False)

    def __init__(self, **kwargs):
        super(non_maximum_suppression, self).__init__(**kwargs)

    def type_inference(self):
        box_dtype = self.boxes.dtype
        score_dtype = self.scores.dtype
        # scores is [n, B, K]; the box count B is not needed for output shapes.
        batch, _, num_classes = self.scores.shape
        keep = self.max_boxes.val

        selected_boxes = types.tensor(box_dtype, (batch, keep, 4))
        selected_scores = types.tensor(score_dtype, (batch, keep, num_classes))
        selected_indices = types.tensor(types.int32, (batch, keep))
        num_selected = types.tensor(types.int32, (batch,))
        return selected_boxes, selected_scores, selected_indices, num_selected
class avg_pool(Pooling):
    r"""
    Perform average pooling. Supports 1-D, 2-D, and 3-D pool (1, 2, or 3
    spatial dimensions).

    Parameters
    ----------
    x: tensor<[n,C_in,\*D_in],T> (Required)
        * ``3 <= rank <= 5``.
        * ``D_in`` are spatial dimensions, ``1 <= len(D_in) <= 3``.
        * ``C_in`` is the number of input channels or depth dimensions.
        * ``n`` is the batch dimension.

    kernel_sizes: const tensor<[K],i32> (Required)
        * The size of the window for each spatial dimension ``D_in`` of the
          input tensor.
        * ``K == len(D_in)``

    strides: const tensor<[S],i32> (Optional, default to all 1s)
        * Stride along each of the spatial dimensions.
        * ``S == len(D_in)``.

    pad_type: const str (Required)
        Must be one of ``valid``, ``same`` or ``custom``.

        * ``valid``: No padding. This is equivalent to custom pad with
          ``pad[i] = 0, for all i``.
        * ``same`` : This is equivalent to custom pad with
          ``pad[2*i] + pad[2*i+1] = kernel_size[i]``.
        * ``custom``: Specify custom padding in the parameter pad.

    pad: const<[P],i32> (Optional. Default to all 0s)
        * ``pad`` represents the number of elements to pad before and after
          each dimension: ``pad[2*i], pad[2*i+1]`` are the pad size before
          and after spatial dimension ``i``.
        * ``P = 2 * len(D_in)``.
        * ``pad`` should be specified if and only if ``pad_type == custom``.

    exclude_padding_from_average: const tensor<[], bool> (Optional, default to False)
        * If ``True``, padded values (0s) are excluded from the denominator
          count when computing the average over the kernel window.

    ceil_mode: const<bool>
        * Same as PyTorch's ``ceil`` mode.
        * ``ceil`` is used instead of floor in calculating the output size.
        * Optional, defaults to ``False``.
        * Only applicable when ``pad_type`` is ``valid`` or ``custom``.
        * Only supported for 1-D and 2-D pooling.
        * When ``ceil_mode`` is True, padding must be symmetric; that is, if
          specified, ``pad[2*i] == pad[2*i+1]`` must hold.

    Returns
    -------
    tensor<[n, C_in, \*D_out],T>
        * Same rank as ``x``; the batch and channel dimensions are unchanged.
        * When ``ceil_mode = False``:
            * ``D_out[i] = floor[(D_in[i] + pad[2*i] + pad[2*i+1] - kernel_sizes[i]) /
              strides[i]] +1, for i = 0, .., len(D_in) - 1``. When all
              parameters involved are integers, this is mathematically the same as:
            * ``D_out[i] = ceil [(D_in[i] + pad[2*i] + pad[2*i+1] - kernel_sizes[i] + 1) /
              strides[i]], for i = 0, .., len(D_in) - 1``.
        * When ``ceil_mode = True``:
            * ``D_out[i] = ceil[(D_in[i] + pad[2*i] + pad[2*i+1] - kernel_sizes[i]) /
              strides[i]] +1, for i = 0, .., len(D_in) - 1``
            * If ``(D_out[i] - 1) * strides[i] >= D_in[i] + pad[2*i] and
              (pad[2*i] + pad[2*i+1] > 0)`` then ``D_out[i] = D_out[i] - 1``.
            * The first equation is the same as:
                * ``D_out[i] = floor[(D_in[i] + pad[2*i] + pad[2*i+1] -
                  kernel_sizes[i] + strides[i] - 1) / strides[i]] +1,
                  for i = 0, .., len(D_in) - 1``

    Attributes
    ----------
    T: fp16, fp32

    See Also
    --------
    l2_pool, max_pool
    """

    # exclude_padding_from_average is listed ahead of the shared Pooling inputs.
    input_spec = (
        InputSpec(
            exclude_padding_from_average=BoolInputType(const=True, optional=True))
        + Pooling.input_spec
    )

    def default_inputs(self):
        return super().default_inputs() + \
            DefaultInputs(
                exclude_padding_from_average=False,
            )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class make_list(Operation):
    """
    Create a list of tensor elements. The elements should have the same shape.
    The list is similar to an auto-resizing array.

    Parameters
    ----------
    init_length: <i32> (Optional)
        * Initial length for the list. If ``dynamic_length`` is ``False``,
          ``init_length`` is the fixed length of the list throughout runtime.
        * Default is ``1``.

    dynamic_length: <bool> (Optional)
        * If ``True``, the list length may change at runtime. If ``False``,
          the length stays fixed at ``init_length``.
        * Default is ``True``.

    elem_shape: Tuple[const<i32> or const<str>] (Required)
        * Shape of each element, given as a tuple of const values.
        * Integer entries are fixed dimensions.
        * String entries are symbolic dimensions: an existing symbol with that
          name is reused; otherwise a new symbol is created.

    dtype: const<str> (Optional)
        * Element tensor's ``dtype``.
        * Default is ``fp32``.

    Returns
    -------
    List[*]

    Raises (during type inference)
    ------------------------------
    ValueError
        If ``dtype`` is not a recognized dtype string, or an ``elem_shape``
        entry is not const.
    """

    input_spec = InputSpec(
        init_length=IntInputType(optional=True),
        dynamic_length=BoolInputType(const=True, optional=True),
        elem_shape=TupleInputType(),
        dtype=StringInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            init_length=1,
            dynamic_length=True,
            dtype="fp32",
        )

    def __init__(self, **kwargs):
        super(make_list, self).__init__(**kwargs)

    def type_inference(self):
        builtin_dtype = types.string_to_builtin(self.dtype.val)
        if builtin_dtype is None:
            raise ValueError("Unsupported dtype {}".format(self.dtype.val))

        # Map each const tuple entry to a shape dim: ints are kept as-is, and
        # strings become (existing or new) symbols.
        elem_shape_sym = []
        for s_var in self.elem_shape:
            # s is str or int
            s = s_var.val
            if s is None:
                msg = ("make_list elem_shape must be tuple of const. "
                       "Tuple elem {} is not const.")
                raise ValueError(msg.format(s_var.name))

            if isinstance(s, str):
                try:
                    symbol = get_existing_symbol(s)
                except ValueError:
                    # Must be a new symbol
                    symbol = get_new_symbol(s)
                elem_shape_sym.append(symbol)
            else:
                elem_shape_sym.append(s)

        elem_type = types.tensor(builtin_dtype, elem_shape_sym)
        return types.list(
            elem_type,
            init_length=self.init_length.val,
            dynamic_length=self.dynamic_length.val,
        )
class gru(Operation):
    r"""
    Gated recurrent unit (GRU).

    .. math::
       r_t = \rm{recurrent\_activation}(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})

    .. math::
       z_t = \rm{recurrent\_activation}(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})

    .. math::
       o_t = \rm{activation}(W_{io} x_t + b_{io} + r_t * (W_{ho} h_{t-1} + b_{ho}))

    .. math::
       h_t = (1 - z_t) * o_t + z_t * h_{t-1}

    Where:

    * ``W_{ir}``, ``W_{io}``, and ``W_{iz}`` are the input-hidden weights for
      the reset, output, and update gate, respectively.
    * ``W_{h[r|o|z]}`` are recurrent weights on hidden state to reset, output,
      update gate.
    * ``h_t`` is the hidden state at time ``t``.
    * ``x_t`` is the input at time ``t``.
    * ``h_{t-1}`` is the hidden state of the layer at time ``t-1`` or the initial
      hidden state at time ``0``.
    * ``r_t``, ``o_t``, and ``z_t`` are the reset, new, and update gates, respectively.
    * ``*`` is elementwise product.

    Parameters
    ----------
    x: <s, b, I, T> (Required)
        * ``s`` is the sequence length, ``b`` is the batch size, and ``I`` is the
          input dimension.

    initial_h: <b, H, T> (Required)
        * ``H`` denotes hidden size.

    weight_ih: const<3*H, I, T> (Required) - Weight matrix
        * ``weight_ih = [W_{ir} | W_{io} | W_{iz}]`` where ``[a|b]`` denotes column
          concatenation and ``[a, b]`` denotes row concatenation. ``W_{ir}``,
          ``W_{io}``, and ``W_{iz}`` have shape ``(H, I)``.
        * This is used when direction="forward" or "reverse".

    weight_hh: const<3*H, H, T> (Required) - Weight matrix
        * ``weight_hh = [W_{hr} | W_{ho} | W_{hz}]``: ``W_{hr}``, ``W_{ho}``, and
          ``W_{hz}`` have shape ``(H, H)``.
        * This is used when direction="forward" or "reverse".

    bias: const<3*H, T> (Optional) [Default all 0s]
        * ``bias[0]`` are input-hidden and hidden-hidden bias.
        * ``3*H`` are biases for ``[b_{ir} + b_{hz}, b_{io}]``.
        * This is used when direction="forward" or "reverse".

    direction: const<str> (Optional) [Default=forward]
        * Either ``forward`` or ``reverse``.

    output_sequence: const<bool> (Optional) [Default=False]
        * Outputs every step if ``True``.

    recurrent_activation: const<str> (Optional) [Default=sigmoid]
        * Activation applied on update and reset gate.

    activation: const<str> (Optional) [Default=tanh]
        * Activation applied on output gate.

    Returns
    -------
    <s, b, H, T> or <1, b, H, T>
        * If ``output_sequence == True`` (hidden states from every step):
          ``<s, b, H, T>``.
        * Else ``<1, b, H, T>`` (hidden states of the final step).
    <b, H, T>
        * Hidden states of the final step.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        initial_h=TensorInputType(),
        weight_ih=TensorInputType(const=True),
        weight_hh=TensorInputType(const=True),
        bias=TensorInputType(const=True, optional=True),
        direction=StringInputType(const=True, optional=True),
        output_sequence=BoolInputType(const=True, optional=True),
        recurrent_activation=StringInputType(const=True, optional=True),
        activation=StringInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            bias=None,
            direction="forward",
            output_sequence=False,
            recurrent_activation="sigmoid",
            activation="tanh",
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        """
        Validate input/weight ranks, direction and the hidden-size relation
        between the weight matrices, then return the output types.

        Raises
        ------
        ValueError
            If ``x`` is not rank 3, a weight is not rank 2, ``direction`` is
            unsupported, or ``weight_hh`` is not of shape ``<3*H, H>``.
        """
        if self.x.rank != 3:
            # `rank` is an int; the message reports it directly.
            raise ValueError(
                "Invalid input shape. Expecting Rank 3 input, got {}".format(self.x.rank)
            )
        sequence_length, batch_size, _ = self.x.shape

        if self.weight_ih.rank != 2:
            raise ValueError(
                "Invalid weight shape. Expecting Rank 2 input, got {}".format(self.weight_ih.rank)
            )
        if self.weight_hh.rank != 2:
            raise ValueError(
                "Invalid weight shape. Expecting Rank 2 input, got {}".format(self.weight_hh.rank)
            )

        # weight_hh stacks the three gate matrices along dim 0: <3*H, H>.
        hidden_dim, hidden_size = self.weight_hh.shape

        direction = self.direction.val
        valid_directions = {"forward", "reverse"}
        if direction not in valid_directions:
            raise ValueError(
                "Direction {} not supported. Supported directions: {}".format(
                    direction, valid_directions
                )
            )

        dim_factor = 3
        if hidden_size != (hidden_dim // dim_factor):
            # Format the message itself (not the exception object).
            raise ValueError(
                "Incorrect weight matrix: hidden dim size mismatch. "
                "Provided weight_ih {}, weight_hh {}. Expecting weight_hh <3*H, H>".format(
                    self.weight_ih.shape, self.weight_hh.shape
                )
            )

        out_seq_len = sequence_length if self.output_sequence.val else 1
        output_shape = [out_seq_len, batch_size, hidden_size]
        output_h_shape = [batch_size, hidden_size]
        return (
            types.tensor(self.x.dtype, tuple(output_shape)),
            types.tensor(self.x.dtype, tuple(output_h_shape)),
        )
class topk(Operation):
    r"""
    Returns a tensor containing top or bottom ``k`` values and the corresponding
    indices of the input tensor along a given axis.

    Parameters
    ----------
    x: <\*?, T> (Required)
        * Input tensor.

    k: const<i32> (Optional)
        * Default to ``1``.
        * Number of values/indices to be computed along each axis.

    axis: const<i32> (Optional)
        * Defaults to ``-1`` (last dimension).
        * Axis to perform the operation.

    ascending: const<bool> (Optional)
        * Default to ``False``, sort in descending order. ``True`` to sort in
          ascending order.

    Returns
    -------
    tensor<\*?, T>
        * Values of top/bottom ``k`` elements.

    tensor<\*?, int32>
        * Indices of the top/bottom ``k`` elements along axis.

    Attributes
    ----------
    T: fp32, int32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        k=IntInputType(const=True, optional=True),
        axis=IntInputType(const=True, optional=True),
        ascending=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            k=1,
            axis=-1,
            ascending=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        """
        Return ``(values_type, indices_type)``; both have ``x``'s shape with the
        ``axis`` dimension replaced by ``k``.

        Raises
        ------
        ValueError
            If ``axis`` is out of range for ``x``, or if ``k`` exceeds the
            (non-symbolic) size of that axis.
        """
        x_type = self.x.dtype
        x_shape = self.x.shape
        k = self.k.val
        axis = self.axis.val
        rank = len(x_shape)

        # Validate up front so a bad axis is a clear error, not an IndexError.
        if not -rank <= axis < rank:
            raise ValueError(
                "axis={} is out of range for input of rank {}".format(axis, rank)
            )

        if not is_symbolic(x_shape[axis]) and k > x_shape[axis]:
            msg = "K={} is greater than size of the given axis={}"
            raise ValueError(msg.format(k, axis))

        ret_shape = list(x_shape)
        ret_shape[axis] = k
        return types.tensor(x_type, ret_shape), types.tensor(types.int32, ret_shape)

    @precondition(allow=VALUE)
    def value_inference(self):
        """Compute ``(values, indices)`` of the top/bottom ``k`` entries with numpy."""
        x = self.x.val
        axis = self.axis.val
        # argsort is ascending; negate the input to get descending order.
        # Only one sort is performed either way.
        sort_key = x if self.ascending.val else -x
        indices = np.argsort(sort_key, axis=axis)

        slc = [slice(None)] * self.x.rank
        slc[axis] = slice(0, self.k.val)
        indices = indices[tuple(slc)]
        values = np.take_along_axis(x, indices, axis=axis)
        return values, indices
class rnn(Operation):
    """
    Recurrent neural network (RNN).

    .. math::
       h_t = activation(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})

    Where:

    * ``W_{ih}`` is input weight.
    * ``W_{hh}`` is hidden/recurrent weight.
    * ``h_t`` is the hidden state at time ``t``.
    * ``x_t`` is the input at time ``t``.
    * ``h_{t-1}`` is the hidden state of the layer at time ``t-1`` or the initial
      hidden state at time ``0``.

    Parameters
    ----------
    x: <s, b, I, T> (Required)
        * ``s`` is the sequence length, ``b`` is the batch size, and ``I`` is the
          input dimension.

    initial_h: <b, H, T> (Required)
        * ``H`` denotes hidden size.

    weight_ih: const<H, I, T> (Required) - Input-hidden weight matrix

    weight_hh: const<H, H, T> (Required) - Hidden-hidden weight matrix

    bias: const<H, T> (Optional) [Default all 0s]
        * bias for input-hidden and hidden-hidden

    direction: const<str> (Optional) [Default=forward]
        * Either ``forward`` or ``reverse``.

    output_sequence: const<bool> (Optional) [Default=False]
        * Outputs every step if ``True``.

    activation: const<str> (Optional) [Default=tanh]
        * Supported activation functions: ``relu``, ``tanh``, ``sigmoid``,
          ``sigmoid_hard``, ``scaled_tanh``, and ``linear``.

    Returns
    -------
    <s, b, H, T> or <1, b, H, T>
        * If ``output_sequence == True`` (hidden states from every step):
          ``<s, b, H, T>``.
        * Else ``<1, b, H, T>`` (hidden states of the final step).
    <b, H, T>
        * Hidden states of the final step.

    Attributes
    ----------
    T: fp32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        initial_h=TensorInputType(),
        weight_ih=TensorInputType(const=True),
        weight_hh=TensorInputType(const=True),
        bias=TensorInputType(const=True, optional=True),
        direction=StringInputType(const=True, optional=True),
        output_sequence=BoolInputType(const=True, optional=True),
        activation=StringInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            bias=None,
            direction="forward",
            output_sequence=False,
            activation="tanh",
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        """
        Validate input/weight ranks and direction, then return the output types.

        Raises
        ------
        ValueError
            If ``x`` is not rank 3, either weight is not rank 2, or
            ``direction`` is unsupported.
        """
        if self.x.rank != 3:
            # `rank` is an int; report it directly.
            raise ValueError(
                "Invalid input shape. Expecting Rank 3 input, got {}".format(self.x.rank)
            )
        sequence_length, batch_size, _ = self.x.shape

        if self.weight_ih.rank != 2 or self.weight_hh.rank != 2:
            # Format the message itself (not the exception object).
            raise ValueError(
                "Invalid weight shape. Expecting Rank 2 input, got weight_ih {}, weight_hh {}".format(
                    self.weight_ih.rank, self.weight_hh.rank
                )
            )

        hidden_size, _ = self.weight_ih.shape

        direction = self.direction.val
        valid_directions = {"forward", "reverse"}
        if direction not in valid_directions:
            raise ValueError(
                "Direction {} not supported. Supported directions: {}".format(
                    direction, valid_directions
                )
            )

        out_seq_len = sequence_length if self.output_sequence.val else 1
        output_shape = [out_seq_len, batch_size, hidden_size]
        output_h_shape = [batch_size, hidden_size]
        return (
            types.tensor(self.x.dtype, tuple(output_shape)),
            types.tensor(self.x.dtype, tuple(output_h_shape)),
        )
class concat(Operation):
    """
    Concatenates tensors along a dimension.

    Parameters
    ----------
    values: Tuple[tensor<[d0, d1, ..., d_axis_i, ..., d_n],T>]  (Required)
        * The number of dimensions of the input tensors must match, and all
          dimensions except ``axis`` must be equal.
        * The tensors may be variadic, but the number of tensors must be
          determined at compile time (i.e. a tuple).

    axis: const<int32> (Required)
        * The dimension along which to concatenate. Must be in the range
          ``[-rank(values[i]), rank(values[i]))`` for all ``i``.

    interleave: const<bool> (Optional, Default=False)
        * If true, concatenate the inputs by interleaving them.
        * If true, all the inputs to this op must have the exact same shape.

    Examples
    --------

    .. sourcecode:: python

        in1 : shape (3, 2), value = [[1, 2], [3, 4], [5, 6]]
        in2 : shape (3, 2), value = [[7, 8], [9, 10], [11, 12]]
        axis = 0

        if interleave = False (default)
        output : shape (6, 2)
        output[0:3, :] = in1
        output[3:6, :] = in2
        value = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]

        if interleave = True
        output : shape (6, 2)
        output[0::2, :] = in1
        output[1::2, :] = in2
        value = [[1, 2], [7, 8], [3, 4], [9, 10], [5, 6], [11, 12]]

    Returns
    -------
    tensor<[d0, d1,...d_axis_out, ..., d_n],T>
        * Where ``d_axis_out = sum(d_axis_i)``.

    Attributes
    ----------
    T: fp32, int32
    """

    input_spec = InputSpec(
        values=TupleInputType(),
        axis=IntInputType(const=True),
        interleave=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            interleave=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def type_inference(self):
        """
        Validate ranks, axis, dtypes and shapes of ``values`` and return the
        concatenated tensor type.

        Raises
        ------
        ValueError
            On empty ``values``, rank mismatch, out-of-range axis, incompatible
            dtypes, or mismatched dimensions.
        """
        if len(self.values) == 0:
            raise ValueError("Concat {} got 0 values".format(self.name))

        # Validate values have the same rank
        rank = self.values[0].rank
        for v in self.values:
            if v.rank != rank:
                msg = "Input {} has rank {} != other inputs rank {}"
                raise ValueError(msg.format(v.name, v.rank, rank))

        # Normalize the concat axis and check it lies within [-rank, rank).
        concat_axis = self.axis.val
        if concat_axis < 0:
            concat_axis += rank
        if rank > 0 and (concat_axis < 0 or concat_axis >= rank):
            msg = "In {} of op_type {}: axis out of bound for input " + "(rank {})"
            raise ValueError(msg.format(self.name, self.op_type, rank))

        # Validate primitive types are compatible
        dtype = self.values[0].dtype
        for v in self.values[1:]:
            new_dtype = promoted_primitive_type(v.dtype, dtype)
            if new_dtype is None:
                msg = "Incompatible primitive types concat: {} vs {}"
                raise ValueError(msg.format(v.dtype, dtype))
            dtype = new_dtype

        # Validate that non-axis dimensions match (and, with interleave, that
        # every dimension matches). Symbolic dimensions are not compared.
        retshape = list(self.values[0].shape)
        for v in self.values[1:]:
            for i in range(rank):
                if is_symbolic(retshape[i]) or is_symbolic(v.shape[i]):
                    continue
                if i != concat_axis and retshape[i] != v.shape[i]:
                    msg = 'Dimension mismatch in {} ("{}"): shapes {} vs. {}'
                    raise ValueError(
                        msg.format(self.op_type, self.name, retshape, v.shape)
                    )
                if self.interleave.val and retshape[i] != v.shape[i]:
                    msg = 'Dimension mismatch in {} ("{}"): shapes {} vs. {}. ' \
                          'All inputs must have same shape when \'interleave\' option is True.'
                    raise ValueError(
                        msg.format(self.op_type, self.name, retshape, v.shape)
                    )

        # Length of the concat dim; any symbolic contribution makes the total a
        # fresh symbol. Rank-0 inputs contribute 1 each.
        concat_dim_len = 0
        for v in self.values:
            if len(v.shape) == 0:
                taxis = 1
            else:
                taxis = v.shape[concat_axis]
            if is_symbolic(taxis):
                concat_dim_len = get_new_symbol()
                break
            concat_dim_len += taxis

        if len(retshape) == 0:
            retshape = [concat_dim_len]
        else:
            retshape[concat_axis] = concat_dim_len

        return types.tensor(dtype, retshape)

    @precondition(allow=VALUE | SYMBOL | NONE)
    def value_inference(self):
        """
        Compute a (possibly symbolic) value for the output, or ``None`` when an
        input has a symbolic shape or more than 10 elements without a known value.
        Honors ``interleave``.
        """
        values = []
        for v in self.values:
            if v.sym_val is not None:
                values.append(v.sym_val)
                continue
            if v.rank == 0:
                values.append(get_new_symbol())
                continue
            if any_symbolic(v.shape):
                values.append(None)
                continue

            # we support value inference when number of elements for each tensor is less than 10
            shape = v.shape
            num_element = np.prod(shape)
            if num_element > 10:
                values.append(None)
                continue

            symbolic_tensor = [get_new_symbol() for _ in range(num_element)]
            symbolic_tensor = np.reshape(np.array(symbolic_tensor), shape)
            values.append(symbolic_tensor)

        if any(val is None for val in values):
            return None

        axis = self.axis.val
        if not isinstance(values[0], np.ndarray) or values[0].shape == ():
            return np.stack(values, axis=axis)

        if self.interleave.val:
            # All inputs share one shape (enforced in type_inference). Stacking
            # on a new axis right after `axis` and merging the two axes yields
            # in0[0], in1[0], ..., in0[1], in1[1], ... along `axis`.
            if axis < 0:
                axis += values[0].ndim
            stacked = np.stack(values, axis=axis + 1)
            out_shape = list(values[0].shape)
            out_shape[axis] *= len(values)
            return np.reshape(stacked, out_shape)

        return np.concatenate(values, axis=axis)
class cumsum(Operation):
    r"""
    Returns the cumulative sum of the input along the given axis.

    Parameters
    ----------
    x: tensor<\*?, T> (Required)
        * Input tensor.
    axis: const<i32> (Optional)
        * default to ``0``.
        * Axis for which the cumulative sum is computed.
    exclusive: const<bool> (Optional)
        * Default to ``False``.
        * When set to ``False``, inclusive cumsum is computed, that is the first
          element of the output is identical to the first element in the input.
        * When set to ``True``, exclusive cumsum is computed, which makes the first
          element of output to ``0``.
    reverse: const<bool> (Optional)
        * Default to ``False``.
        * When set to ``True``, perform cumsum in the reverse order.

    Returns
    -------
    tensor<\*?, T>
        * Same type and shape as the input tensor.

    Attributes
    ----------
    T: fp32, int32
    """

    input_spec = InputSpec(
        x=TensorInputType(),
        axis=IntInputType(const=True, optional=True),
        exclusive=BoolInputType(const=True, optional=True),
        reverse=BoolInputType(const=True, optional=True),
    )

    def default_inputs(self):
        return DefaultInputs(
            axis=0,
            exclusive=False,
            reverse=False,
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @precondition(allow=VALUE)
    def value_inference(self):
        """Compute the (optionally exclusive / reversed) cumulative sum with numpy."""
        data = np.copy(self.x.val)
        axis = self.axis.val
        reverse = self.reverse.val
        exclusive = self.exclusive.val
        if reverse:
            data = np.flip(data, axis=axis)
        # Accumulate in the input dtype: without `dtype`, np.cumsum promotes
        # small integer types (e.g. int32) to the platform integer, which would
        # disagree with type_inference (output type == input type).
        data = np.cumsum(data, axis=axis, dtype=data.dtype)
        if exclusive:
            # Shift right by one along `axis`: prepend a zero slice, then keep
            # the first `length` entries so the output shape equals the input's.
            length = data.shape[axis]
            zero_shape = list(data.shape)
            zero_shape[axis] = 1
            padded = np.concatenate(
                (np.zeros(zero_shape, dtype=data.dtype), data), axis=axis
            )
            data = np.take(padded, np.arange(length), axis=axis)
        if reverse:
            data = np.flip(data, axis=axis)
        return data

    def type_inference(self):
        """
        Return the input's type unchanged.

        Raises
        ------
        ValueError
            If ``axis`` is outside ``[-1, rank - 1]``.
        """
        # Check range of axis
        if self.axis.val < -1 or self.axis.val > self.x.rank - 1:
            raise ValueError(
                "axis should be in the range [-1, {}]".format(self.x.rank - 1)
            )

        return self.x.sym_type
class cond(Operation):
    """
    Perform a conditional execution. The return types must be identical
    between the true and false branches.

    Parameters
    ----------
    pred: tensor<[], bool> (Required)
        * 0-D tensor (scalar) predicate to switch between true and false branches.

    _true_fn: function (Required)
        * A Python function that executes if ``pred`` evaluates to ``True``.
        * It must take no inputs, and return one or more values whose type
          becomes the operation's return type.

    _false_fn: function (Required)
        * A Python function that executes if ``pred`` evaluates to ``False``.
        * It must take no inputs, and have return types that match those of
          the true branch.

    _existing_blocks: list[Block] (Optional)
        * Python list of ``Block``.
        * For internal use only. When converting a milproto, we already got
          existing blocks, and the ``build_nested_blocks`` function can use
          them directly.
        * When ``_existing_blocks`` is set, ``_true_fn`` and ``_false_fn``
          must be dummy functions that return ``None``.

    Returns
    -------
    tuple
        * Python tuple of ``Variables`` from one of the branches.
    """

    input_spec = InputSpec(
        pred=BoolInputType(),
        _true_fn=PyFunctionInputType(),
        _false_fn=PyFunctionInputType(),
        _existing_blocks=InternalInputType(optional=True),
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _build_branch_block(self, name_suffix, branch_fn):
        """Trace ``branch_fn`` inside a new nested block named ``self.name + name_suffix``."""
        with Block(name=self.name + name_suffix, outer_op=self) as block:
            ret_vars = branch_fn()
            # Normalize a single var or a tuple of vars to a list of outputs.
            if isinstance(ret_vars, tuple):
                ret_vars = list(ret_vars)
            if not isinstance(ret_vars, list):
                ret_vars = [ret_vars]
            block.set_outputs(ret_vars)
        return block

    def build_nested_blocks(self):
        # If the front end is milproto, we already have the well constructed
        # true/false blocks, so set self.blocks directly. We also check that
        # _true_fn and _false_fn are both dummy functions (return None).
        if self._existing_blocks is not None and self._existing_blocks.val is not None:
            assert self._true_fn.val([]) is None
            assert self._false_fn.val([]) is None
            self.blocks = self._existing_blocks.val
            return

        # blocks[0] is the true branch, blocks[1] the false branch.
        self.blocks.append(self._build_branch_block("_true", self._true_fn.val))
        self.blocks.append(self._build_branch_block("_false", self._false_fn.val))

    def type_inference(self):
        """
        Return the true branch's output types after checking that the false
        branch produces the same number of outputs with compatible types.

        Raises
        ------
        ValueError
            If the branches differ in output count or in any output type.
        """
        true_ret_vars = self.blocks[0].outputs
        false_ret_vars = self.blocks[1].outputs
        # zip() would silently truncate; a count mismatch is an invalid cond.
        if len(true_ret_vars) != len(false_ret_vars):
            raise ValueError(
                "cond {}: true branch returns {} outputs but false branch returns {}".format(
                    self.name, len(true_ret_vars), len(false_ret_vars)
                )
            )
        # Verify true_ret_vars has the same types as false_ret_vars
        for vt, vf in zip(true_ret_vars, false_ret_vars):
            if not is_compatible_type(vt.sym_type, vf.sym_type):
                msg = ("true branch output {} type {} mismatch false branch"
                       + " output type {}")
                raise ValueError(
                    msg.format(vt.name, vt.sym_type.__type_info__(),
                               vf.sym_type.__type_info__()))

        return tuple(v.sym_type for v in true_ret_vars)

    def value_inference(self):
        """Return the values of the branch selected by a compile-time-known ``pred``."""
        if self.pred.val is None:
            raise NotImplementedError()
        if self.pred.val:
            return [v.val for v in self.blocks[0].outputs]
        return [v.val for v in self.blocks[1].outputs]