def _flatten_helper(g, input, start_dim, end_dim, dim):
    """Emit a reshape that merges axes ``start_dim..end_dim`` of ``input`` into one.

    The target shape is assembled in the graph from ``Shape(input)`` as
    ``shape[:start_dim] ++ [-1] ++ shape[end_dim + 1:dim]`` and concatenated
    along axis 0. The merged size is left as ``-1`` for the reshape to infer.

    Args:
        g: graph context used to emit nodes.
        input: value to reshape.
        start_dim: first merged axis. It is used directly as a slice bound,
            so presumably it is already non-negative (TODO confirm at call sites).
        end_dim: last merged axis (inclusive); same assumption as ``start_dim``.
        dim: rank of ``input``.

    Returns:
        The value produced by ``_reshape_from_tensor``.
    """
    input_size = g.op("Shape", input)
    # Leading sizes kept unchanged: shape[:start_dim].
    slices = [_slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])]
    # Exactly one -1 Constant is emitted; it stands for the merged axes.
    slices.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)))
    if end_dim < dim - 1:
        # Trailing sizes kept unchanged: shape[end_dim + 1:dim].
        slices.append(
            _slice_helper(g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim])
        )
    final_shape = g.op("Concat", *slices, axis_i=0)
    # Imported lazily as in the original code (presumably to avoid an import cycle).
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
    return _reshape_from_tensor(g, input, final_shape)
def flatten(g, input, start_dim, end_dim):
    """Export a flatten of axes ``start_dim..end_dim`` (inclusive) of ``input``.

    When the result is 2-D, the ONNX ``Flatten`` operator is emitted directly.
    Otherwise the output shape is computed from the input's static sizes and
    emitted as a Constant fed to ``_reshape_from_tensor``. That path requires
    ``input.isCompleteTensor()``.

    Args:
        g: graph context used to emit nodes.
        input: value to flatten.
        start_dim: first axis to merge. A negative value counts from the end
            when the rank is known.
        end_dim: last axis to merge (inclusive). A negative value counts from
            the end.

    Returns:
        The flattened value, or whatever ``_unimplemented`` returns when the
        output is not 2-D and the input sizes are not statically available.
    """
    dim = input.type().dim()
    # Normalize a negative start_dim up front. It can then match the Flatten
    # fast paths and be used as an index into output_dims below.
    if dim is not None and start_dim < 0:
        start_dim += dim

    # Use ONNX's Flatten operator for cases where the output shape is 2-D.
    # The `dim is not None` guards avoid computing `None - 1` when the rank
    # is unknown; such inputs fall through to the `_unimplemented` path.
    if start_dim == 1:
        if end_dim == -1 or (dim is not None and end_dim == dim - 1):
            return g.op("Flatten", input, axis_i=start_dim)
    elif start_dim == 0:
        if end_dim == -2 or (dim is not None and end_dim == dim - 2):
            return g.op("Flatten", input, axis_i=end_dim + 1)

    # Use Reshape for cases where the output shape is not 2-D.
    if not input.isCompleteTensor():
        return _unimplemented(
            "flatten",
            "input size not accessible "
            "(consider using reshape op instead of flatten op to export to ONNX)"
        )
    # If end_dim is negative, wrap it by the rank.
    if end_dim < 0:
        end_dim += dim
    input_dims = input.type().sizes()
    output_dims = []
    for i in range(dim):
        if start_dim < i <= end_dim:
            # Fold axis i into the merged axis stored at index start_dim.
            output_dims[start_dim] *= input_dims[i]
        else:
            output_dims.append(input_dims[i])
    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
    return _reshape_from_tensor(g, input, shape)
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    """Export ``tensordot`` by reducing it to a single 2-D einsum.

    Steps:
      1. ``input_a`` is permuted so its free (non-contracted) axes come first,
         then reshaped to ``(prod(free_a), prod(contracted_a))``.
      2. ``input_b`` is permuted so its contracted axes come first, then
         reshaped to ``(prod(contracted_b), prod(free_b))``.
      3. The two matrices are combined with ``einsum("ij,jk->ik")``.
      4. The result is reshaped to the free sizes of ``a`` followed by the free
         sizes of ``b``.

    Args:
        g: graph context used to emit nodes.
        input_a: left operand; its rank must be known.
        input_b: right operand; its rank must be known.
        dims_a: contracted axes of ``input_a``; negative entries are wrapped
            by the rank of ``input_a``.
        dims_b: contracted axes of ``input_b``; negative entries are wrapped
            by the rank of ``input_b``.
        out: not supported. It is only reported through ``_unimplemented``,
            whose return value is not used here, so whether export stops
            depends on that helper. NOTE(review): confirm this is intended.

    Returns:
        The reshaped einsum result.

    Raises:
        RuntimeError: if the rank of either operand is unknown.
    """
    if out is not None:
        _unimplemented("Tensordot", "Out parameter is not supported for tensordot.")

    rank_a = sym_help._get_tensor_rank(input_a)
    if rank_a is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.")
    rank_b = sym_help._get_tensor_rank(input_b)
    if rank_b is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.")

    # Wrap negative contraction axes into [0, rank).
    dims_a = [axis + rank_a if axis < 0 else axis for axis in dims_a]
    dims_b = [axis + rank_b if axis < 0 else axis for axis in dims_b]
    # Axes that are not contracted, in ascending order.
    free_a = [axis for axis in range(rank_a) if axis not in dims_a]
    free_b = [axis for axis in range(rank_b) if axis not in dims_b]

    def minus_one():
        # Fresh 1-element Constant holding -1, letting the reshape infer that size.
        return g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))

    def last_size(value):
        # 1-element slice of Shape(value) holding the size of its last axis.
        return sym_help._slice_helper(g, g.op("Shape", value), axes=[0], starts=[-1], ends=[maxsize])

    a_perm = permute(g, input_a, free_a + dims_a)
    b_perm = permute(g, input_b, dims_b + free_b)

    # a -> [free sizes, -1], whose last axis is prod(contracted_a);
    # then a -> [-1, prod(contracted_a)].
    shape_a = g.op("Shape", a_perm)
    free_sizes_a = sym_help._slice_helper(g, shape_a, axes=[0], starts=[0], ends=[len(free_a)])
    a_partial = _reshape_from_tensor(g, a_perm, [free_sizes_a, minus_one()])
    contracted_size_a = last_size(a_partial)
    a_matrix = _reshape_from_tensor(g, a_perm, [minus_one(), contracted_size_a])

    # b -> [contracted sizes, -1], whose last axis is prod(free_b);
    # then b -> [-1, prod(free_b)].
    shape_b = g.op("Shape", b_perm)
    free_sizes_b = sym_help._slice_helper(g, shape_b, axes=[0], starts=[len(dims_b)], ends=[maxsize])
    contracted_sizes_b = sym_help._slice_helper(g, shape_b, axes=[0], starts=[0], ends=[len(dims_b)])
    b_partial = _reshape_from_tensor(g, b_perm, [contracted_sizes_b, minus_one()])
    free_size_b = last_size(b_partial)
    b_matrix = _reshape_from_tensor(g, b_perm, [minus_one(), free_size_b])

    product = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", a_matrix, b_matrix))
    return _reshape_from_tensor(g, product, [free_sizes_a, free_sizes_b])