Example no. 1
0
def default_xop(xop_type: str, xgraph: XGraph, node: Node,
                quant_config: NndctQuantInfo) -> None:
    """Create a generic fixed-point XIR op of type ``xop_type`` for ``node``.

    Bound parameters of the node become named inputs (keyed by the
    lower-cased parameter enum name), remaining input tensors are wrapped
    in fix (quantization) ops, and a fixed normal op is registered on
    ``xgraph`` under the node's name.

    Args:
        xop_type: XIR op type string to create.
        xgraph: Graph wrapper the new op is added to.
        node: Source graph node being converted.
        quant_config: Quantization info forwarded to the fix-op helpers.
    """
    # Note: annotated -> None (not NoReturn): the function returns normally.
    input_ops: Dict[str, List["xir.Op"]] = {}
    if node.has_bound_params():
        # Each bound parameter tensor already has an op in the graph;
        # expose it under its parameter name (e.g. "weights", "bias").
        for param_name, param_tensor in node.op.params.items():
            param = xgraph.get_op_by_name(param_tensor.name)
            input_ops[param_name.name.lower()] = [param]

    # Collect producer ops for the remaining input tensors.
    # (Loop variable renamed from `input` to avoid shadowing the builtin.)
    input_list = []
    for in_tensor in node.in_tensors:
        if node.has_bound_params() and in_tensor.is_param_tensor():
            continue  # already registered above as a named parameter input
        elif in_tensor.is_param_tensor():
            input_op = xgraph.get_op_by_name(in_tensor.name)
        else:
            input_op = xgraph.get_op_by_name(in_tensor.node.name)
        input_list.append(input_op)

    # Insert fix ops in front of the activation inputs.
    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)

    attrs = _get_xir_attr_from_node(node)
    xgraph.create_fixed_normal_op(node.name,
                                  xop_type,
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=input_ops)
Example no. 2
0
def conv_transpose_3d(xgraph: XGraph, node: Node,
                      quant_config: NndctQuantInfo) -> None:
    """Create a fixed-point ``transposed_conv3d`` XIR op for ``node``.

    The node's spatial attributes are reordered to the layout XIR expects
    before the op is created; inputs are wired up the same way as in
    :func:`default_xop`.

    Args:
        xgraph: Graph wrapper the new op is added to.
        node: Source graph node being converted.
        quant_config: Quantization info forwarded to the fix-op helpers.
    """
    attrs = _get_xir_attr_from_node(node)
    # Reverse the per-dimension attribute order (presumably source-framework
    # vs. XIR spatial-axis ordering — TODO confirm against XIR docs).
    attrs['kernel'] = attrs['kernel'][::-1]
    attrs['stride'] = attrs['stride'][::-1]
    attrs['dilation'] = attrs['dilation'][::-1]
    # Expand each reversed pad value into a (before, after) pair:
    # [p2, p1, p0] -> [p2, p2, p1, p1, p0, p0].
    attrs['pad'] = list(
        itertools.chain.from_iterable([[pad] * 2
                                       for pad in attrs['pad'][::-1]]))
    # NOTE: removed a leftover debug print of `attrs`.
    input_ops: Dict[str, List["xir.Op"]] = {}
    if node.has_bound_params():
        # Bound parameter tensors already have ops; key them by parameter name.
        for param_name, param_tensor in node.op.params.items():
            param = xgraph.get_op_by_name(param_tensor.name)
            input_ops[param_name.name.lower()] = [param]

    # Collect producer ops for the remaining input tensors.
    # (Loop variable renamed from `input` to avoid shadowing the builtin.)
    input_list = []
    for in_tensor in node.in_tensors:
        if node.has_bound_params() and in_tensor.is_param_tensor():
            continue  # already registered above as a named parameter input
        elif in_tensor.is_param_tensor():
            input_op = xgraph.get_op_by_name(in_tensor.name)
        else:
            input_op = xgraph.get_op_by_name(in_tensor.node.name)
        input_list.append(input_op)

    # Insert fix ops in front of the activation inputs.
    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)

    xgraph.create_fixed_normal_op(node.name,
                                  "transposed_conv3d",
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=input_ops)
Example no. 3
0
 def _get_init_param_str(self, node: Node) -> List[str]:
     """Return source lines that declare ``torch.nn.parameter.Parameter``
     attributes for each of *node*'s parameters not yet seen, recording
     each new attribute name in ``self._tensor_output_map``.

     Nodes with bound parameters need no explicit init lines and yield [].
     """
     init_lines: List[str] = []
     if node.has_bound_params():
         return init_lines
     for param_type, param_tensor in node.op.params.items():
         # Skip parameters whose tensors already have an output mapping.
         if param_tensor.name in self._tensor_output_map:
             continue
         param_name = param_type.value
         param_shape = tuple(param_tensor.shape)
         init_lines.append(
             f"self.{param_name} = torch.nn.parameter.Parameter(torch.Tensor{param_shape})"
         )
         self._tensor_output_map[
             param_tensor.name] = f"self.{param_name}"
     return init_lines