def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Build the ONNX nodes that split ``self`` along ``dim``.

    When ``sym_help._is_split_static`` accepts the arguments, the work is
    handed to ``torch.onnx.symbolic_opset9.split`` unchanged. Otherwise a
    ``SplitToSequence`` node is emitted and:

    * returned directly when ``_outputs`` is ``None``;
    * replaced by one ``Slice`` per output when ``split_size_or_sizes`` is a
      packed list whose length equals ``_outputs``;
    * indexed with one ``SequenceAt`` per output in every other case.

    Args:
        g: Graph context exposing ``op`` for node creation.
        self: Value being split.
        split_size_or_sizes: Value holding the split size or split sizes.
        dim: Axis to split along.
        _outputs: Number of outputs expected by the caller, or ``None``.

    Returns:
        A single sequence value, or a list with ``_outputs`` values.
    """
    if sym_help._is_split_static(split_size_or_sizes, _outputs):
        return torch.onnx.symbolic_opset9.split(
            g, self, split_size_or_sizes, dim, _outputs
        )

    sequence = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
    if _outputs is None:
        return sequence

    # Convert to multiple slice nodes iff number of splits and number of
    # outputs are statically known.
    sizes_are_listed = sym_help._is_packed_list(split_size_or_sizes)
    if sizes_are_listed and len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
        lengths = [
            sym_help._unsqueeze_helper(g, item, [0])
            for item in sym_help._unpack_list(split_size_or_sizes)
        ]
        begin = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
        pieces = []
        # ``lengths`` has exactly ``_outputs`` entries (checked above), so
        # walking it directly visits one entry per requested output.
        for length in lengths:
            stop = g.op("Add", begin, length)
            pieces.append(g.op("Slice", self, begin, stop, axis))
            begin = stop  # the next slice starts where this one ended
        return pieces

    picked = []
    for index in range(_outputs):
        position = g.op("Constant", value_t=torch.tensor([index], dtype=torch.long))
        picked.append(g.op("SequenceAt", sequence, position))
    return picked
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Build the ONNX nodes that split ``self`` along ``dim``.

    Dynamic case (``sym_help._is_split_static`` rejects the arguments): a
    ``SplitToSequence`` node is emitted. It is returned as-is when
    ``_outputs`` is ``None``; otherwise the result is a list of ``Slice``
    nodes (when ``split_size_or_sizes`` is a packed list with exactly
    ``_outputs`` entries) or a list of ``SequenceAt`` nodes.

    Static case: a ``Split`` node is emitted. If the constant stored on
    ``split_size_or_sizes`` has at least one dimension it is passed through
    as the sizes input; if it is a scalar, the per-chunk sizes are computed
    here from the size of ``dim`` and emitted as a ``Constant``.

    Args:
        g: Graph context exposing ``op`` for node creation.
        self: Value being split.
        split_size_or_sizes: Value holding the split size or split sizes.
        dim: Axis to split along.
        _outputs: Number of outputs expected by the caller, or ``None``.

    Returns:
        A single value or a list of values, depending on the path taken.

    Raises:
        RuntimeError: In the static scalar case, when the size of ``dim`` is
            not available and ``_outputs`` is ``None``.
    """
    if not sym_help._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of
        # outputs are statically known.
        if sym_help._is_packed_list(split_size_or_sizes) and \
                len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
            split_sizes = [sym_help._unsqueeze_helper(g, v, [0])
                           for v in sym_help._unpack_list(split_size_or_sizes)]
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            # split_sizes is a list of same length as _outputs
            for split_len in split_sizes:
                end = g.op("Add", start, split_len)
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [g.op("SequenceAt", split_out,
                     g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)))
                for i in range(_outputs)]

    split_val = split_size_or_sizes.node()['value']
    if split_val.dim() > 0:
        # A list of sizes was supplied; hand it to Split unchanged.
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size')

    # The static shape may be missing entirely (``sizes()`` yielding None) or
    # may have an unknown entry for ``dim``; indexing blindly used to crash
    # with a TypeError in those cases.
    sizes = self.type().sizes()
    size = sizes[dim] if sizes is not None else None
    if size is None:
        if _outputs is None:
            raise RuntimeError("Unknown dimension size not supported")
        # NOTE(review): assumes the dimension divides evenly into ``_outputs``
        # chunks of ``split_size`` -- confirm against callers.
        size = split_size * _outputs

    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        # The last chunk is shorter when the dimension does not divide evenly.
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Build the ONNX nodes that split ``self`` along ``dim``.

    Dynamic case (``symbolic_helper._is_split_static`` rejects the
    arguments): a ``SplitToSequence`` node is emitted. It is returned as-is
    when ``_outputs`` is ``None``; otherwise the result is a list of
    ``Slice`` nodes (when ``split_size_or_sizes`` is a packed list with
    exactly ``_outputs`` entries) or a list of ``SequenceAt`` nodes.

    Static case: a ``Split`` node is emitted, either with
    ``split_size_or_sizes`` itself (constant with at least one dimension) or
    with a ``Constant`` of chunk sizes computed from the size of ``dim``.

    Args:
        g: Graph context exposing ``op`` for node creation.
        self: Value being split.
        split_size_or_sizes: Value holding the split size or split sizes.
        dim: Axis to split along.
        _outputs: Number of outputs expected by the caller, or ``None``.

    Returns:
        A single value or a list of values, depending on the path taken.

    Raises:
        RuntimeError: In the static scalar case, when the size of ``dim`` is
            unknown and ``_outputs`` is ``None``.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        sequence = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return sequence

        # Convert to multiple slice nodes iff number of splits and number of
        # outputs are statically known.
        sizes_are_listed = symbolic_helper._is_packed_list(split_size_or_sizes)
        if (
            sizes_are_listed
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            lengths = [
                symbolic_helper._unsqueeze_helper(g, item, [0])
                for item in symbolic_helper._unpack_list(split_size_or_sizes)
            ]
            begin = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            pieces = []
            # ``lengths`` has exactly ``_outputs`` entries (checked above).
            for length in lengths:
                stop = g.op("Add", begin, length)
                pieces.append(g.op("Slice", self, begin, stop, axis))
                begin = stop
            return pieces

        picked = []
        for index in range(_outputs):
            position = g.op(
                "Constant", value_t=torch.tensor([index], dtype=torch.long)
            )
            picked.append(g.op("SequenceAt", sequence, position))
        return picked

    constant = split_size_or_sizes.node()["value"]
    if constant.dim() > 0:
        # A list of sizes was supplied; hand it to Split unchanged.
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)

    chunk = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    extent = symbolic_helper._get_tensor_dim_size(self, dim)
    if extent is None:
        if _outputs is None:
            raise RuntimeError("Unknown dimension size not supported")
        # With no known extent, size the axis from the requested output count.
        extent = chunk * _outputs

    full_chunks, remainder = divmod(extent, chunk)
    lengths = [chunk] * full_chunks
    if remainder:
        # Trailing shorter chunk when the axis does not divide evenly.
        lengths.append(remainder)
    lengths_const = g.op("Constant", value_t=torch.tensor(lengths))
    return g.op("Split", self, lengths_const, axis_i=dim, outputs=_outputs)
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Build the ONNX nodes that split ``self`` along ``dim``.

    When ``sym_help._is_split_static`` accepts the arguments, the work is
    handed to ``torch.onnx.symbolic_opset9.split`` unchanged. Otherwise a
    ``SplitToSequence`` node is emitted; it is returned directly when
    ``_outputs`` is ``None`` and unpacked with one ``SequenceAt`` node per
    output otherwise.

    Args:
        g: Graph context exposing ``op`` for node creation.
        self: Value being split.
        split_size_or_sizes: Value holding the split size or split sizes.
        dim: Axis to split along.
        _outputs: Number of outputs expected by the caller, or ``None``.

    Returns:
        A single sequence value, or a list with ``_outputs`` values.
    """
    if sym_help._is_split_static(split_size_or_sizes, _outputs):
        return torch.onnx.symbolic_opset9.split(
            g, self, split_size_or_sizes, dim, _outputs
        )

    sequence = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
    if _outputs is None:
        return sequence

    picked = []
    for index in range(_outputs):
        # Each output is read from the sequence at a constant position.
        position = g.op("Constant", value_t=torch.tensor([index], dtype=torch.long))
        picked.append(g.op("SequenceAt", sequence, position))
    return picked
def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
    """Build the ONNX nodes that split ``self`` along ``dim``.

    Four paths are taken depending on ``indices_or_sections``:

    * static constant with at least one dimension: one ``Slice`` per output,
      with boundaries gathered from the constant and the last slice running
      to the size of the axis;
    * static scalar constant: a ``Split`` node fed by a ``Constant`` of chunk
      lengths, where the first ``size % sections`` chunks are one element
      longer than the rest;
    * non-static rank-1 tensor: a ``Loop`` that inserts one ``Slice`` per
      entry into a sequence, followed by a final slice to the end of the axis;
    * anything else (scalar tensor): chunk lengths computed in-graph with
      ``Div``/``Mod``/``Tile`` and fed to ``SplitToSequence`` or ``Split``.

    Args:
        g: Graph context exposing ``op`` for node creation.
        self: Value being split.
        indices_or_sections: Value holding split indices or a section count.
        dim: Axis to split along.
        _outputs: Number of outputs expected by the caller, or ``None``.

    Returns:
        A list of values, a sequence value, or the result of a ``Split``
        node, depending on the path taken.

    Raises:
        errors.SymbolicValueError: In the static scalar case, when the size
            of ``dim`` is unknown and ``_outputs`` is ``None``.
    """
    axis = opset11.unsqueeze(
        g, g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)), 0
    )
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        constant = symbolic_helper._node_get(indices_or_sections.node(), "value")

        if constant.dim() > 0:
            # Explicit indices: slice between consecutive boundaries.
            lower = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            pieces = []
            assert _outputs is not None
            for position in range(_outputs - 1):
                index = g.op(
                    "Constant", value_t=torch.tensor([position], dtype=torch.long)
                )
                upper = g.op("Gather", indices_or_sections, index, axis_i=0)
                pieces.append(g.op("Slice", self, lower, upper, axis))
                lower = upper

            # The final piece runs from the last boundary to the axis size.
            upper = symbolic_helper._size_helper(g, self, axis)
            pieces.append(g.op("Slice", self, lower, upper, axis))
            return pieces

        sections = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        extent = symbolic_helper._get_tensor_dim_size(self, dim)
        if extent is None:
            if _outputs is None:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )
            # With no known extent, size the axis from the output count.
            extent = sections * _outputs

        base, extra = divmod(extent, sections)
        # The first ``extra`` chunks take one additional element each.
        lengths = [base + 1] * extra + [base] * (sections - extra)
        lengths_const = g.op(
            "Constant", value_t=torch.tensor(lengths, dtype=torch.long)
        )
        return g.op("Split", self, lengths_const, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
        zero_axis = g.op("Constant", value_t=torch.tensor(0))
        trip_count = symbolic_helper._size_helper(g, indices_or_sections, zero_axis)
        trip_count = opset11.unsqueeze(g, trip_count, 0)
        keep_going = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the below loop work, a zero is padded in
        # front so that it becomes the initial start of slice.
        leading_zero = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        boundaries = g.op("Concat", leading_zero, indices_or_sections, axis_i=0)

        empty_sequence = g.op("SequenceEmpty")
        loop = g.op("Loop", trip_count, keep_going, empty_sequence)

        # Loop inputs: iteration number, condition, carried sequence.
        body = utils._add_block(loop.node())
        iteration = utils._add_input_to_block(body)
        utils._add_input_to_block(body)  # condition input; not read in the body
        carried = utils._add_input_to_block(body)

        lower = body.op("Gather", boundaries, iteration, axis_i=0)
        next_iteration = body.op("Add", iteration, one)
        upper = body.op("Gather", boundaries, next_iteration, axis_i=0)

        piece = body.op("Slice", self, lower, upper, axis)
        carried = body.op("SequenceInsert", carried, piece)

        # Loop outputs: condition first, then the carried sequence.
        body_condition = body.op("Identity", keep_going)
        utils._add_output_to_block(body, body_condition)
        utils._add_output_to_block(body, carried)

        looped = loop.node().output()
        minus_one = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long))
        tail_start = g.op("Gather", boundaries, minus_one, axis_i=0)
        tail_start = opset11.unsqueeze(g, tail_start, 0)
        tail_end = symbolic_helper._size_helper(g, self, axis)

        # Append the slice from the last boundary to the end of the axis.
        tail = g.op("Slice", self, tail_start, tail_end, axis)
        return g.op("SequenceInsert", looped, tail)

    # scalar tensor
    extent = symbolic_helper._size_helper(g, self, axis)
    base = g.op("Div", extent, indices_or_sections)
    base_plus_one = g.op("Add", base, one)
    extra = g.op("Mod", extent, indices_or_sections)
    long_run = g.op("Tile", base_plus_one, extra)
    sections_1d = opset11.unsqueeze(g, indices_or_sections, 0)
    short_count = g.op("Sub", sections_1d, extra)
    short_run = g.op("Tile", base, short_count)

    lengths = g.op("Concat", long_run, short_run, axis_i=0)
    if _outputs is None:
        return g.op("SplitToSequence", self, lengths, axis_i=dim)
    return g.op("Split", self, lengths, axis_i=dim, outputs=_outputs)