def index_copy(g, self, dim, index, source):
    """Symbolic export of ``index_copy`` (export-type-check variant).

    If the exporter's operator export type is ``ONNX_ATEN_FALLBACK``, a single
    ATen ``index_copy`` node is emitted with ``dim`` stored as the ``dim_i``
    attribute. Otherwise ``index`` is passed through
    ``sym_help._index_fill_reshape_helper`` and the result is lowered to a
    ``scatter`` of ``source`` into ``self`` along ``dim``.

    NOTE(review): SOURCE contains a second ``index_copy`` definition; if both
    live in the same module, the later one shadows this one — confirm which
    is intended to be live.

    Args:
        g: graph context used to emit nodes (``g.at``).
        self: the tensor being written into.
        dim: dimension argument, parsed as an int constant via ``"i"``.
        index: index argument forwarded to the ATen op / reshape helper.
        source: values to copy in.

    Returns:
        The value produced by the emitted ATen node or by ``scatter``.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    # The helper also returns an expanded index *shape*; index_copy scatters
    # ``source`` as-is, so only the expanded index itself is needed.
    _, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, expanded_index, source)
def index_copy(g, self, dim, index, source):
    """Symbolic export of ``index_copy``.

    When ``sym_help.is_caffe2_aten_fallback()`` is true (presumably an
    ATen-fallback export targeting Caffe2 — confirm against sym_help), a single
    ATen ``index_copy`` node is emitted with ``dim`` stored as the ``dim_i``
    attribute. Otherwise ``index`` is passed through
    ``sym_help._index_fill_reshape_helper`` and the result is lowered to a
    ``scatter`` of ``source`` into ``self`` along ``dim``.

    Args:
        g: graph context used to emit nodes (``g.at``).
        self: the tensor being written into.
        dim: dimension argument, parsed as an int constant via ``"i"``.
        index: index argument forwarded to the ATen op / reshape helper.
        source: values to copy in.

    Returns:
        The value produced by the emitted ATen node or by ``scatter``.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    # The helper also returns an expanded index *shape*; index_copy scatters
    # ``source`` as-is, so only the expanded index itself is needed.
    _, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, expanded_index, source)
def index_fill(g, self, dim, index, value):
    """Symbolic export of ``index_fill`` (export-type-check variant).

    With the ``ONNX_ATEN_FALLBACK`` export type, a raw ``ATen`` node carrying
    ``operator_s="index_fill"`` and the parsed ``dim_i`` attribute is emitted.
    Otherwise the fill value is normalized through ``_maybe_get_scalar`` and
    ``_if_scalar_type_as`` (presumably matching ``self``'s scalar type —
    confirm), expanded to the shape returned by
    ``_index_fill_reshape_helper``, and scattered into ``self`` along ``dim``.

    NOTE(review): SOURCE contains a second ``index_fill`` definition; if both
    live in the same module, the later one shadows this one.

    Returns:
        The value produced by the ATen node or by ``scatter``.
    """
    dim_value = sym_help._parse_arg(dim, 'i')
    aten_fallback = (
        sym_help._operator_export_type
        == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    )
    if aten_fallback:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")
    index_shape, scatter_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    fill = sym_help._if_scalar_type_as(g, sym_help._maybe_get_scalar(value), self)
    # expand(...) is evaluated before scatter(...) runs, matching the original order.
    return scatter(g, self, dim, scatter_index, expand(g, fill, index_shape, None))
def index_fill(g, self, dim, index, value):
    """Symbolic export of ``index_fill``.

    Unless ``sym_help.is_caffe2_aten_fallback()`` is true, the fill value is
    normalized through ``_maybe_get_scalar`` and ``_if_scalar_type_as``
    (presumably matching ``self``'s scalar type — confirm), expanded to the
    shape returned by ``_index_fill_reshape_helper``, and scattered into
    ``self`` along ``dim``. In the fallback case, an ATen ``index_fill`` node
    using the ``int_Scalar`` overload is emitted with ``dim_i`` set to the
    parsed dimension.

    Returns:
        The value produced by ``scatter`` or by the ATen node.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if not sym_help.is_caffe2_aten_fallback():
        target_shape, scatter_index = sym_help._index_fill_reshape_helper(
            g, self, dim, index
        )
        typed_fill = sym_help._if_scalar_type_as(
            g, sym_help._maybe_get_scalar(value), self
        )
        broadcast_fill = expand(g, typed_fill, target_shape, None)
        return scatter(g, self, dim, scatter_index, broadcast_fill)
    return g.at(
        "index_fill",
        self,
        index,
        value,
        overload_name="int_Scalar",
        dim_i=dim_value,
    )