示例#1
0
def index_copy(g, self, dim, index, source):
    """Build the export graph nodes for ``index_copy``.

    When the operator export type is ``ONNX_ATEN_FALLBACK`` the call is
    forwarded unchanged through ``g.at("index_copy", ...)``. Otherwise the
    index is reshaped by ``sym_help._index_fill_reshape_helper`` and the
    result is produced by ``scatter``.

    Args:
        g: Graph context; must provide ``at``.
        self: Value that ``source`` is copied into (presumably the
            destination tensor -- confirm against callers).
        dim: Dimension argument; parsed with ``sym_help._parse_arg(dim, "i")``.
        index: Index value handed to the reshape helper / ATen op.
        source: Value passed to ``scatter`` as the data to write.

    Returns:
        Whatever ``g.at`` or ``scatter`` returns for the chosen path.
    """
    # Parsed before branching, so a failure to parse ``dim`` surfaces on
    # both paths even though only the fallback path consumes the result.
    axis = sym_help._parse_arg(dim, "i")

    wants_aten_fallback = (
        sym_help._operator_export_type
        == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    )
    if wants_aten_fallback:
        return g.at("index_copy", self, index, source, dim_i=axis)

    # The helper returns a (shape, index) pair; only the index is needed here.
    _, scatter_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, scatter_index, source)
示例#2
0
def index_copy(g, self, dim, index, source):
    """Build the export graph nodes for ``index_copy``.

    If ``sym_help.is_caffe2_aten_fallback()`` is true the call is forwarded
    through ``g.at("index_copy", ...)``. Otherwise the index is reshaped by
    ``sym_help._index_fill_reshape_helper`` and the result comes from
    ``scatter``.

    Args:
        g: Graph context; must provide ``at``.
        self: Value that ``source`` is copied into (presumably the
            destination tensor -- confirm against callers).
        dim: Dimension argument; parsed with ``sym_help._parse_arg(dim, "i")``.
        index: Index value handed to the reshape helper / ATen op.
        source: Value passed to ``scatter`` as the data to write.

    Returns:
        Whatever ``scatter`` or ``g.at`` returns for the chosen path.
    """
    # Parsed before branching, so a failure to parse ``dim`` surfaces on
    # both paths even though only the fallback path consumes the result.
    axis = sym_help._parse_arg(dim, "i")

    if not sym_help.is_caffe2_aten_fallback():
        # The helper returns a (shape, index) pair; only the index is used.
        _, scatter_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
        return scatter(g, self, dim, scatter_index, source)

    return g.at("index_copy", self, index, source, dim_i=axis)
示例#3
0
def index_fill(g, self, dim, index, value):
    """Build the export graph nodes for ``index_fill``.

    When the operator export type is ``ONNX_ATEN_FALLBACK`` an ``ATen`` node
    with ``operator_s="index_fill"`` is emitted. Otherwise ``value`` is
    converted via the scalar helpers, expanded to the shape returned by
    ``sym_help._index_fill_reshape_helper``, and written with ``scatter``.

    Args:
        g: Graph context; must provide ``op``.
        self: Value being filled (presumably the destination tensor --
            confirm against callers).
        dim: Dimension argument; parsed with ``sym_help._parse_arg(dim, 'i')``.
        index: Index value handed to the reshape helper / ATen op.
        value: Fill value; passed through ``_maybe_get_scalar`` and
            ``_if_scalar_type_as`` on the non-fallback path.

    Returns:
        Whatever ``g.op`` or ``scatter`` returns for the chosen path.
    """
    # Parsed before branching, so a failure to parse ``dim`` surfaces on
    # both paths even though only the fallback path consumes the result.
    axis = sym_help._parse_arg(dim, 'i')

    wants_aten_fallback = (
        sym_help._operator_export_type
        == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    )
    if wants_aten_fallback:
        return g.op("ATen", self, index, value, dim_i=axis, operator_s="index_fill")

    fill_shape, fill_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    scalar = sym_help._maybe_get_scalar(value)
    typed_scalar = sym_help._if_scalar_type_as(g, scalar, self)
    # Expand the fill value to the helper-provided shape before scattering.
    fill_values = expand(g, typed_scalar, fill_shape, None)
    return scatter(g, self, dim, fill_index, fill_values)
示例#4
0
def index_fill(g, self, dim, index, value):
    """Build the export graph nodes for ``index_fill``.

    If ``sym_help.is_caffe2_aten_fallback()`` is true the call is forwarded
    through ``g.at("index_fill", ...)`` with ``overload_name="int_Scalar"``.
    Otherwise ``value`` is converted via the scalar helpers, expanded to the
    shape returned by ``sym_help._index_fill_reshape_helper``, and written
    with ``scatter``.

    Args:
        g: Graph context; must provide ``at``.
        self: Value being filled (presumably the destination tensor --
            confirm against callers).
        dim: Dimension argument; parsed with ``sym_help._parse_arg(dim, "i")``.
        index: Index value handed to the reshape helper / ATen op.
        value: Fill value; passed through ``_maybe_get_scalar`` and
            ``_if_scalar_type_as`` on the non-fallback path.

    Returns:
        Whatever ``scatter`` or ``g.at`` returns for the chosen path.
    """
    # Parsed before branching, so a failure to parse ``dim`` surfaces on
    # both paths even though only the fallback path consumes the result.
    axis = sym_help._parse_arg(dim, "i")

    if not sym_help.is_caffe2_aten_fallback():
        fill_shape, fill_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
        scalar = sym_help._maybe_get_scalar(value)
        typed_scalar = sym_help._if_scalar_type_as(g, scalar, self)
        # Expand the fill value to the helper-provided shape before scattering.
        fill_values = expand(g, typed_scalar, fill_shape, None)
        return scatter(g, self, dim, fill_index, fill_values)

    return g.at(
        "index_fill", self, index, value, overload_name="int_Scalar", dim_i=axis
    )