def frobenius_norm(g, self, dim=None, keepdim=False):
    """Export the Frobenius norm of ``self``.

    When ``dim`` is omitted (``None``) or is a constant empty list, the norm is taken over
    every element with a single ``ReduceL2`` (``keepdims=0``). Otherwise it is computed as
    ``Sqrt(ReduceSum(self * self))`` over ``dim``.

    Args:
        g: graph being built; nodes are created with ``g.op``.
        self: input value.
        dim: axes to reduce: a graph value, a constant int list, or None.
        keepdim: forwarded as ``keepdims_i`` to the ReduceSum helper.

    Returns:
        The value holding the norm.
    """
    # With dim=None the constant lookup below presumably returns None unchanged and
    # len(None) would raise TypeError, so the declared default could never be used.
    # Treat it like an empty dim list: reduce over all axes.
    if dim is None:
        return g.op("ReduceL2", self, keepdims_i=0)
    dim_val = sym_help._maybe_get_const(dim, "is")
    if not sym_help._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset,
                  padding_idx):
    """Export aten::embedding_bag for opset 10 by unrolling one Slice/Gather/Reduce per bag.

    The number of bags must be statically known (the length of ``offsets`` along dim 0);
    otherwise the export is reported as unsupported.

    Args:
        g: graph being built.
        embedding_matrix: the embedding weight value.
        indices: flat index value holding the indices of all bags.
        offsets: start position of each bag in ``indices``. If ``include_last_offset`` is
            true, its last entry is the end boundary of the last bag.
        scale_grad_by_freq: unsupported when exporting in training mode.
        mode: 0 -> sum, 1 -> mean, anything else -> max.
        sparse: unused here.
        per_sample_weights: optional value; when not None, rows are scaled before reduction.
        include_last_offset: see ``offsets``.
        padding_idx: any non-negative value raises RuntimeError.

    Returns:
        ``(output, None, None, None)``; see the comment at the return statement.
    """
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported("embedding_bag with scale_grad_by_freq for training mode")
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")
    # Function-local imports, kept as in the original (presumably to avoid an import cycle).
    from torch.onnx.symbolic_opset9 import select
    import warnings
    warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
                  "Please use opset 11 or higher to export model for dynamic input shape.")

    offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return sym_help._onnx_unsupported("embedding_bag with unknown shape of offsets for opset 10 is not supported. "
                                          "please use opset 11 or higher.")

    if include_last_offset:
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a sentinel end so the last bag's Slice runs to the end of `indices`.
        offset_len = offsets_dim_0
        offsets_extended = g.op("Concat", offsets, g.op("Constant", value_t=torch.tensor([maxsize])), axis_i=0)

    # The Slice axis is the same for every bag, so build the constant once.
    axes_ = g.op("Constant", value_t=torch.tensor([0]))
    bag_embeddings = []
    for i in range(offset_len):
        start_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), [0])
        end_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), [0])
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if not sym_help._is_none(per_sample_weights):
            per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_)
            # Add a trailing axis so each weight scales a whole embedding row.
            per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1])
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)
        if mode == 0:
            embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0)
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        # Re-add a leading axis so the per-bag results can be concatenated along it.
        bag_embeddings.append(sym_help._unsqueeze_helper(g, embeddings, [0]))

    output = g.op("Concat", *bag_embeddings, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
    """Export linalg_vector_norm; ``ord == 0`` is lowered here, other orders go to ``lvn``.

    For ``ord == 0`` the result is the number of non-zero elements of ``self`` (over ``dim``,
    or over all elements when ``dim`` is None). It is built from Equal/Not/Cast/ReduceSum.

    Args:
        g: graph being built.
        self: input value.
        ord: norm order.
        dim: axes to reduce, or None for all elements.
        keepdim: forwarded as ``keepdims_i``; forced to 0 when ``dim`` is None.
        dtype: only forwarded to ``lvn`` (presumably the opset 9 implementation — confirm).

    Returns:
        The value holding the norm.
    """
    if ord != 0:
        return lvn(g, self, ord, dim, keepdim, dtype)

    # Count in the input's dtype rather than int64, matching torch.linalg.vector_norm's
    # floating result. Fall back to the previous "Long" cast when the dtype is unknown.
    onnx_dtype = sym_help.cast_pytorch_to_onnx.get(
        self.type().scalarType(), sym_help.cast_pytorch_to_onnx["Long"]
    )
    if dim is None:
        self = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
        # A full reduction must produce a scalar. keepdims_i=None would presumably drop the
        # attribute, and ONNX ReduceSum then defaults to keepdims=1 (shape [1]); use 0.
        keepdim = 0
    cond_op = g.op("Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))))
    cond_op = g.op("Cast", cond_op, to_i=onnx_dtype)
    return sym_help._reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
    """Export linalg_vector_norm, handling the ``ord == 0`` (non-zero count) case locally.

    Any other order is delegated unchanged to ``opset9.linalg_vector_norm``. For
    ``ord == 0``, elements different from zero are flagged, cast back to the input's
    scalar type, and summed over ``dim``. When ``dim`` is None the input is first
    flattened and the sum is taken without keeping dims.
    """
    if ord != 0:
        return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)

    if dim is None:
        flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
        self = symbolic_helper._reshape_helper(g, self, flat_shape)
        keepdim = 0

    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_zero = g.op("Equal", self, zero)
    is_nonzero = g.op("Not", is_zero)
    target_type = symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()]
    nonzero_flags = g.op("Cast", is_nonzero, to_i=target_type)
    return symbolic_helper._reducesum_helper(
        g, nonzero_flags, axes_i=dim, keepdims_i=keepdim
    )
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset,
                  padding_idx):
    """Export aten::embedding_bag (opset 11+) as an ONNX Loop with one iteration per bag.

    Each iteration gathers bag ``i``'s bounds from the offset vector and slices ``indices``
    (and ``per_sample_weights`` when present) to those bounds. It gathers the rows from
    ``embedding_matrix`` and reduces them over axis 0 (mode 0 sum, 1 mean, otherwise max).
    The per-bag results are the Loop's scan output, so no sequence has to be built.

    Returns:
        ``(output, None, None, None)``; see the comment at the return statement.
    """
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError('embedding_bag with padding_idx')

    onnx_bool = 9  # ONNX BOOL type id for the Loop condition
    keep_going = g.op("Cast", g.op("Constant", value_t=torch.tensor(1)), to_i=onnx_bool)
    slice_axes = g.op("Constant", value_t=torch.tensor([0]))

    num_indices = sym_help._size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0)))
    num_indices = sym_help._unsqueeze_helper(g, num_indices, [0])
    if not include_last_offset:
        # Append len(indices) so the final bag gets an explicit end boundary.
        offsets = g.op("Concat", offsets, num_indices, axis_i=0)

    # offsets[:-] are bag starts and offsets[1:] are bag ends; there is one bag per end.
    bag_starts = sym_help._slice_helper(g, offsets, axes=[0], starts=[0], ends=[maxsize], steps=[1])
    bag_ends = sym_help._slice_helper(g, offsets, axes=[0], starts=[1], ends=[maxsize], steps=[1])
    num_bags = sym_help._size_helper(g, bag_ends, g.op("Constant", value_t=torch.tensor(0)))

    loop = g.op("Loop", num_bags, keep_going)
    body = _add_block(loop.node())
    # Body inputs are (iteration number, condition); only the iteration number is read.
    bag_idx = _add_input_to_block(body)
    _add_input_to_block(body)

    start = body.op("Gather", bag_starts, bag_idx, axis_i=0)
    end = body.op("Gather", bag_ends, bag_idx, axis_i=0)
    start = sym_help._unsqueeze_helper(body, start, [0])
    end = sym_help._unsqueeze_helper(body, end, [0])

    bag_indices = body.op("Slice", indices, start, end, slice_axes)
    bag = body.op("Gather", embedding_matrix, bag_indices, axis_i=0)
    if not sym_help._is_none(per_sample_weights):
        bag_weights = body.op("Slice", per_sample_weights, start, end, slice_axes)
        bag_weights = sym_help._unsqueeze_helper(body, bag_weights, [1])
        bag = body.op("Mul", bag, bag_weights)

    if mode == 0:
        bag = sym_help._reducesum_helper(body, bag, axes_i=[0], keepdims_i=0)
    else:
        reduce_kind = "ReduceMean" if mode == 1 else "ReduceMax"
        bag = body.op(reduce_kind, bag, axes_i=[0], keepdims_i=0)

    _add_output_to_block(body, body.op("Cast", keep_going, to_i=onnx_bool))
    _add_output_to_block(body, bag)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
def einsum(g, equation, tensor_list):
    """Export torch.einsum.

    For exactly two operands with an explicit "->" output, no ellipsis, and no label
    repeated within an operand, the equation is lowered to ReduceSum / Transpose /
    Reshape / MatMul (or Mul when nothing is contracted). Every other form is emitted
    as a single ONNX ``Einsum`` node.

    Args:
        g: graph being built; nodes are created with ``g.op``.
        equation: einsum equation string, e.g. "ks,ksm->sm".
        tensor_list: list value holding the operands (unpacked with ``_unpack_list``).

    Returns:
        The value holding the einsum result.
    """
    tensors = sym_help._unpack_list(tensor_list)
    num_ops = len(tensors)
    assert num_ops > 0

    # Fall back to the ONNX Einsum op for: an operand count other than 2, an implicit
    # output (no "->"), or an ellipsis ('...'), whose operand sizes are not easy to get.
    if num_ops != 2 or equation.find("->") == -1 or "." in equation:
        return g.op("Einsum", *tensors, equation_s=equation)

    # Take "ks,ksm->sm" as example. After processing the inputs,
    # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m].
    lhs_labels, rhs_labels, result_labels = parse_equation(equation)

    # Doesn't support repeated label in operand for now as it needs to take extra diagonal.
    if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(
            set(rhs_labels)):
        return g.op("Einsum", *tensors, equation_s=equation)

    # Add contraction labels (labels not present in output).
    # After processing contraction labels, contraction_labels = [k],
    # label_perm_map = {(s, 0), (m, 1), (k, 2)}, out_size = 2, perm_size = 3.
    out_size = len(result_labels)
    label_perm_map = dict([(label, idx) for idx, label in enumerate(result_labels)])
    perm_size = out_size
    contraction_labels = []
    lhs_reduce_sum_axes = []
    rhs_reduce_sum_axes = []
    for label in lhs_labels + rhs_labels:
        if label not in label_perm_map:
            # Non-output label in both operands: contracted by the MatMul below.
            if label in lhs_labels and label in rhs_labels:
                label_perm_map[label] = perm_size
                contraction_labels.append(label)
                perm_size += 1
            # Non-output label in only one operand: summed away up front.
            elif label in lhs_labels:
                lhs_reduce_sum_axes.append(lhs_labels.index(label))
            else:
                rhs_reduce_sum_axes.append(rhs_labels.index(label))

    lhs_tensor = tensors[0]
    rhs_tensor = tensors[1]

    # If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum on those axes, update lhs_labels/rhs_labels,
    # and use the output as lhs_tensor/rhs_tensor from here on.
    if lhs_reduce_sum_axes:
        lhs_tensor = sym_help._reducesum_helper(g, lhs_tensor, lhs_reduce_sum_axes, keepdims_i=False)
        lhs_labels = [
            lhs_labels[axis] for axis in range(len(lhs_labels)) if axis not in lhs_reduce_sum_axes
        ]

    if rhs_reduce_sum_axes:
        rhs_tensor = sym_help._reducesum_helper(g, rhs_tensor, rhs_reduce_sum_axes, keepdims_i=False)
        rhs_labels = [
            rhs_labels[axis] for axis in range(len(rhs_labels)) if axis not in rhs_reduce_sum_axes
        ]

    # Need to unsqueeze and permute the inputs to order of output with contraction labels.
    # lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2].
    # rhs_perm = [1,2,0], rhs_unsqueeze_axes = [].
    lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(
        lhs_labels, label_perm_map)
    rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(
        rhs_labels, label_perm_map)

    # If there are no contraction labels, unsqueeze and permute the inputs and Mul them to get the final result.
    if not contraction_labels:
        lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor, lhs_unsqueeze_axes, lhs_perm)
        rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor, rhs_unsqueeze_axes, rhs_perm)
        return g.op("Mul", lhs_tensor, rhs_tensor)

    # If contraction_labels is not empty, need a BatchedMatMul.
    # Batched labels are those in all inputs and output. Below axes are based on output.
    # batched_labels = [s], batched_axes = [0] for the example.
    # Matmul output labels are those in one of inputs and output.
    # matmul_output_labels = [m], matmul_output_axes = [1] for the example.
    # contraction_labels = [k], contraction_axes = [2] for the example.
    batched_axes = []
    matmul_output_axes = []
    contraction_axes = [axis for axis in range(out_size, perm_size)]
    for axis in range(out_size):
        label = result_labels[axis]
        if label in lhs_labels and label in rhs_labels:
            batched_axes.append(axis)
        else:
            matmul_output_axes.append(axis)

    # Based on above unsqueeze and permute on inputs, need to permute again.
    # For lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2],
    # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example.
    # For rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1].
    # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example.
    lhs_perm = combine_unsqueeze_and_permute_for_matmul(
        lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes)
    rhs_perm = combine_unsqueeze_and_permute_for_matmul(
        rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes)

    # Need to Reshape two input tensors before the BatchedMatMul and Reshape result to output shape.
    # Reshape lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)].
    # Reshape rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)]
    # Convert all axes based on inputs.
    # lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example.
    lhs_contraction_axes = [
        lhs_labels.index(label) for label in contraction_labels
    ]
    rhs_contraction_axes = [
        rhs_labels.index(label) for label in contraction_labels
    ]
    lhs_matmul_output_axes = [
        lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in lhs_labels
    ]
    rhs_matmul_output_axes = [
        rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in rhs_labels
    ]

    # Caches of input shape tensors to avoid generating duplicated graph.
    # Each helper call below takes the current cache and returns it (possibly filled in),
    # so these must be threaded through in order.
    lhs_shape_tensor = None
    rhs_shape_tensor = None

    # contraction_numel_tensor should be tensor([size(k)]) for the example, but since length is 1, it's None here.
    contraction_numel_tensor = None
    if len(lhs_contraction_axes) > 1:
        _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True)

    # Prepare some shape tensors for Reshape if needed.
    # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor are None for the example.
    lhs_matmul_output_shape_tensor = None
    lhs_matmul_output_numel_tensor = None
    if len(lhs_matmul_output_axes) > 1:
        lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True)

    # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor are None for the example.
    rhs_matmul_output_shape_tensor = None
    rhs_matmul_output_numel_tensor = None
    if len(rhs_matmul_output_axes) > 1:
        rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes(
            g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True)

    new_lhs_tensor = lhs_tensor
    # Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes does not have exactly 1 axis,
    # otherwise permute it directly.
    # Need to Reshape the lhs_tensor for the example, the new shape is [size(s), 1, size(k)].
    if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1:
        new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor(
            g,
            lhs_tensor,
            True,
            len(lhs_labels),
            lhs_perm,
            lhs_matmul_output_axes,
            lhs_contraction_axes,
            len(batched_axes),
            lhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            lhs_shape_tensor,
        )
    else:
        if need_permute(lhs_perm):
            new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm)

    # Need to Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes does not have exactly 1 axis,
    # otherwise permute it directly.
    # rhs_tensor's new shape should be [size(s), size(k), size(m)], but doesn't need to Reshape for the example.
    new_rhs_tensor = rhs_tensor
    if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1:
        new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor(
            g,
            rhs_tensor,
            False,
            len(rhs_labels),
            rhs_perm,
            rhs_matmul_output_axes,
            rhs_contraction_axes,
            len(batched_axes),
            rhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            rhs_shape_tensor,
        )
    else:
        if need_permute(rhs_perm):
            new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm)

    # Perform final BatchedMatMul. Output is shape [size(s), 1, size(m)] for the example.
    result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor)

    # Need to Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes does not have exactly 1 axis.
    # Need to Reshape the result for the example, the new shape is [size(s), size(m)].
    if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
        # Build the Reshape target piecewise. A [0] entry keeps the MatMul result's size at
        # that position, and a [-1] entry is inferred. At most one [-1] is emitted:
        # has_neg_one_dim tracks whether one was used, and otherwise the last [0] becomes [-1].
        shape_tensors = [
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        ] * len(batched_axes)
        last_zero_dim = len(shape_tensors) - 1
        has_neg_one_dim = False
        if lhs_matmul_output_axes:
            if len(lhs_matmul_output_axes) == 1:
                shape_tensors.append(
                    g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
                last_zero_dim = len(shape_tensors) - 1
            else:
                shape_tensors.append(lhs_matmul_output_shape_tensor)
        if rhs_matmul_output_axes:
            if len(rhs_matmul_output_axes) == 1:
                shape_tensors.append(
                    g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
                has_neg_one_dim = True
            else:
                shape_tensors.append(rhs_matmul_output_shape_tensor)
        if not has_neg_one_dim and last_zero_dim >= 0:
            shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor(
                [-1], dtype=torch.int64))
        result = reshape_tensor(g, result, shape_tensors)

    # Now output axes are ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmul_output_axes];
    # if this is not the same as the output, one permute is needed.
    labels = ([result_labels[axis] for axis in batched_axes] + [lhs_labels[axis] for axis in lhs_matmul_output_axes] + [rhs_labels[axis] for axis in rhs_matmul_output_axes])
    assert len(labels) == out_size
    output_perm = [labels.index(label) for label in result_labels]
    assert all(axis in output_perm for axis in range(out_size))
    if need_permute(output_perm):
        result = g.op("Transpose", result, perm_i=output_perm)
    return result
def diagonal(g, self, offset, dim1, dim2):
    """Export aten::diagonal via an EyeLike mask, a ReduceSum and a GatherND window.

    ``dim1`` and ``dim2`` are moved to the last two axes and multiplied by an EyeLike mask
    with diagonal ``offset``. The last axis is then summed away, and the valid diagonal
    window is gathered. An ONNX ``If`` returns zeros of the window shape instead when the
    diagonal is empty (``diag_size == 0``).

    Args:
        g: graph being built.
        self: input value; its rank must be known.
        offset: diagonal offset (int); > 0 above the main diagonal, < 0 below.
        dim1: first dim of the 2-D planes (int; negative values count from the end).
        dim2: second dim of the 2-D planes (int; negative values count from the end).

    Returns:
        The output of the ``If`` node, or the ``_unimplemented`` result when the input
        rank is unknown.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Normalize negative dims (e.g. the common dim1=-2, dim2=-1). They are removed from
    # range(rank) below, where a negative value is not found and raises ValueError.
    if dim1 < 0:
        dim1 += rank
    if dim2 < 0:
        dim2 += rank

    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask: (dim1_size, dim2_size) with ones on the requested diagonal.
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # dim1 and dim2 appended as the last two dimensions so they line up with the mask.
    axes = list(range(rank))
    axes.remove(dim1)
    axes.remove(dim2)
    self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])

    # Multiply input and mask to calculate values along diagonal
    # The mask consists of one values where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims.
    # If offset is greater than or equal to zero, set offset to zero as this aids in
    # the calculation of the selection window.
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select.
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected.
    # So in this example, it is [1, 2].
    # The window is cumsum(ones(diag_size)) + (|offset| - 1), i.e. |offset| .. |offset| + diag_size - 1
    # (0 .. diag_size - 1 when offset >= 0, since offset was reset to 0 above).
    # NOTE(review): dtype code 4 is presumably the int64 (Long) ScalarType index — confirm.
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    # Leading (batch) dims of the reduced result, followed by the diagonal length.
    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in range(rank - 2)
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where offset value is greater than number of rows/columns
    # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor.
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )

    if_op = g.op("If", overrun_cond)
    if_node = if_op.node()

    if_block = utils._add_block(if_node)
    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_block, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun_ = if_block.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    utils._add_output_to_block(if_block, final_non_overrun_)

    else_block = utils._add_block(if_node)
    # NOTE(review): dtype code 6 is presumably Float; for non-float inputs this branch's
    # type may not match the then-branch. Confirm whether it should follow self's dtype.
    final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
    utils._add_output_to_block(else_block, final_overrun_)
    return if_op