def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
    """Build the im2col output shape tensor ``[N, C * kernel_h * kernel_w, -1]``.

    Batch and channel sizes are read from ``input`` at graph-build time via
    ``size`` ops, so the emitted shape stays valid for dynamic batch sizes;
    the trailing ``-1`` lets Reshape infer the number of sliding blocks.
    """
    n_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    c_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    # Each output channel row holds one kernel patch: C * kh * kw values.
    kernel_elems = g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    unfolded_c = g.op("Mul", c_dim, kernel_elems)
    return g.op(
        "Concat",
        g.op("Unsqueeze", n_dim, axes_i=[0]),
        g.op("Unsqueeze", unfolded_c, axes_i=[0]),
        g.op("Constant", value_t=torch.tensor([-1])),
        axis_i=0,
    )
def im2col(g, input, kernel_size, dilation, padding, stride):
    """Export aten::im2col (unfold) as an ONNX Gather/Transpose/Reshape chain.

    Args:
        g: ONNX graph context.
        input: 4-D input tensor value — assumed (N, C, H, W); see comment below.
        kernel_size, dilation, padding, stride: each an int[2] of (h, w) values.

    Returns:
        Graph value of shape (N, C * kernel_h * kernel_w, L) where L is the
        number of sliding blocks (inferred via -1 in the Reshape target shape).
    """
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]
    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    # Index tensors selecting the row/column windows of every sliding block.
    blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h, dilation_h, padding_h, stride_h)
    blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w, dilation_w, padding_w, stride_w)
    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
    #  [[[1., 2., 4., 5.],
    #    [2., 3., 5., 6.],
    #    [4., 5., 7., 8.],
    #    [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return g.op("Reshape", output, output_shape)
def reverse(x):
    """Reverse tensor ``x`` along a dimension using ONNX ReverseSequence.

    NOTE(review): this is a closure — ``g`` (graph context) and ``dim`` (the
    dimension to reverse) come from the enclosing scope, which is not visible
    in this chunk; confirm against the outer function.

    Strategy: move ``dim`` to the front via Transpose, flatten everything
    after the first two dims so ReverseSequence can operate with time_axis=0,
    reverse, then restore the original shape and dim order.
    """
    from torch.onnx.symbolic_opset9 import reshape, transpose, size
    y = transpose(g, x, 0, dim)
    # Remember the transposed shape so it can be restored after flattening.
    shape = g.op("Shape", y)
    # 0s in a Reshape target keep those dims; collapse the rest into -1.
    y = reshape(g, y, [0, 1, -1])
    n = size(g, y, g.op("Constant", value_t=torch.LongTensor([0])))
    y = g.op("ReverseSequence", y, n, batch_axis_i=1, time_axis_i=0)
    y = reshape(g, y, shape)
    y = transpose(g, y, 0, dim)
    return y
def masked_scatter(g, self, mask, source):
    """Export aten::masked_scatter via ONNX ScatterND."""
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size

    # Coordinates of every True element of the mask broadcast to self's shape.
    scatter_index = nonzero(g, expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten
    # and slice source tensor.
    flat_source = view(g, source, torch.LongTensor([-1]))
    needed = size(g, scatter_index, torch.LongTensor([0]))
    flat_source = sym_help._slice_helper(
        g,
        flat_source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=needed,
        dynamic_slice=True,
    )
    return g.op('ScatterND', self, scatter_index, flat_source)
def _prepare_onnx_paddings(g, input, pad):
    """Convert a PyTorch pad list into ONNX Pad order as an int64 1-D tensor.

    PyTorch pads are given innermost-dim first as
    [dim_n_begin, dim_n_end, dim_n-1_begin, ...] and may cover fewer than all
    dims; ONNX Pad wants [dim_0_begin, ..., dim_n_begin, dim_0_end, ..., dim_n_end]
    over every dim. The reordering below zero-extends, flips, and transposes
    to achieve that — the exact reshape/flip/transpose order is load-bearing.
    """
    # A scalar list (but not a prim::ListConstruct packed list) must first be
    # materialized as a single 1-D tensor before it can be concatenated.
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        # Rank unknown at export time — compute it dynamically in the graph.
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape and reverse order and collate first beginnings and then ends
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #               [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    padding_c = g.op(
        "Cast", paddings, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]
    )
    return padding_c
def masked_scatter(g, self, mask, source):
    """Export aten::masked_scatter using ONNX ScatterND."""
    # Gather the coordinates of every True position of the broadcast mask.
    idx = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten
    # and slice source tensor.
    flat = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    n_updates = opset9.size(g, idx, torch.LongTensor([0]))
    flat = symbolic_helper._slice_helper(
        g,
        flat,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=n_updates,
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, idx, flat)
def _len(g, self):
    """Export aten::len: sequence length for lists, else the size of dim 0."""
    is_sequence = _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence"
    if is_sequence:
        return g.op("SequenceLength", self)
    # Plain tensor: len() is the first dimension's size, squeezed to a scalar.
    dim0_size = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return sym_help._squeeze_helper(g, dim0_size, [0])
def diagonal(g, self, offset, dim1, dim2):
    """Export aten::diagonal via an EyeLike mask, reduction, and GatherND.

    Args:
        g: ONNX graph context.
        self: input tensor value; rank must be statically known.
        offset: diagonal offset (0 = main, >0 above, <0 below).
        dim1, dim2: the two dims spanning the 2-D planes to take diagonals of.

    Returns the If-node output: the gathered diagonals, or an empty (zero-size)
    tensor when the offset overruns the matrix (see overrun discussion below).
    """
    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # dim1 and dim2 appended as a dimension at the end of the shape
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to calculate values along diagonal
    # The mask consists of one values where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is greater than zero, set offset to zero as this aids in
    # calculation of selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    # NOTE(review): the dtype argument 4 to opset9.ones/zeros presumably
    # selects int64 per the exporter's scalar-type table — confirm.
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where offset value is greater than number of rows/columns
    # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )
    if_op = g.op("If", overrun_cond)
    if_node = if_op.node()

    # Then-branch (no overrun): gather the diagonal values.
    if_block = utils._add_block(if_node)
    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_block, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun_ = if_block.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    utils._add_output_to_block(if_block, final_non_overrun_)

    # Else-branch (overrun): emit an empty tensor of the gather shape.
    else_block = utils._add_block(if_node)
    final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
    utils._add_output_to_block(else_block, final_overrun_)
    return if_op
def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq, mode,
                  sparse, per_sample_weights, include_last_offset):
    """Export aten::embedding_bag to ONNX.

    Two strategies, chosen from statically-known shape information:
      1. If ``indices`` originated as a 2-D tensor (equal-length bags), one
         Gather + Reshape + reduction covers all bags at once.
      2. Otherwise, if ``offsets`` has a static length, emit one
         Slice/Gather/reduction per bag and Concat the results.
    ``mode``: 0 = sum, 1 = mean, otherwise max. ``sparse`` is unused in the
    export. Returns a 4-tuple matching aten::embedding_bag's outputs, with the
    last three as None (unused by torch.nn.EmbeddingBag).
    """
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')

    from torch.onnx.symbolic_opset9 import size, div, select

    # Check if initial indices was 2D. In functional.py:
    # offsets is set to torch.arange(0, indices.numel(), indices.size(1))
    # Then indices is reshaped to 1D: indices.reshape(-1)
    if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \
            and len(indices.node().inputs().__next__().type().sizes()) == 2:
        # Assert include_last_offset is False
        assert not include_last_offset
        embeddings = g.op("Gather", embedding_matrix, indices)
        # Bags-per-batch and (equal) bag length recovered from the flat sizes.
        dim_0 = size(g, offsets, g.op("Constant", value_t=torch.LongTensor([0])))
        dim_1 = div(
            g, size(g, indices, g.op("Constant", value_t=torch.LongTensor([0]))),
            dim_0)
        dim_2 = g.op("Constant", value_t=torch.LongTensor([-1]))

        shape = [dim_0, dim_1, dim_2]
        shape = g.op("Concat", *shape, axis_i=0)

        if not sym_help._is_none(per_sample_weights):
            per_sample_weights = g.op("Unsqueeze", per_sample_weights, axes_i=[1])
            embeddings = g.op("Mul", embeddings, per_sample_weights)

        embeddings = g.op("Reshape", embeddings, shape)
        if mode == 0:
            embeddings = g.op("ReduceSum", embeddings, axes_i=[1], keepdims_i=0)
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[1], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[1], keepdims_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return embeddings, None, None, None
    elif offsets.type().sizes() is not None:
        if include_last_offset:
            offset_len = offsets.type().sizes()[0] - 1
            offsets_extended = offsets
        else:
            offset_len = offsets.type().sizes()[0]
            # Append a sentinel so the last bag's Slice end is unbounded.
            # NOTE(review): ``maxsize`` presumably comes from a file-level
            # ``from sys import maxsize`` outside this chunk — confirm.
            offsets_extended = [
                offsets, g.op("Constant", value_t=torch.tensor([maxsize]))
            ]
            offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
        list_ = []
        for i in range(offset_len):
            # Bag i spans indices[offsets[i]:offsets[i + 1]].
            start_ = g.op("Unsqueeze",
                          select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
                          axes_i=[0])
            end_ = g.op("Unsqueeze",
                        select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)),
                        axes_i=[0])
            axes_ = g.op("Constant", value_t=torch.tensor([0]))
            indices_row = g.op("Slice", indices, start_, end_, axes_)

            embeddings = g.op("Gather", embedding_matrix, indices_row)
            if not sym_help._is_none(per_sample_weights):
                per_sample_weights_row = g.op("Slice", per_sample_weights,
                                              start_, end_, axes_)
                per_sample_weights_row = g.op("Unsqueeze", per_sample_weights_row,
                                              axes_i=[1])
                embeddings = g.op("Mul", embeddings, per_sample_weights_row)
            if mode == 0:
                embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
            elif mode == 1:
                embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
            else:
                embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

            embeddings = g.op("Unsqueeze", embeddings, axes_i=[0])
            list_.append(embeddings)

        output = g.op("Concat", *list_, axis_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return output, None, None, None
    else:
        return sym_help._onnx_unsupported(
            'embedding_bag with unknown shape of indices')