Ejemplo n.º 1
0
def _unify_attrs(attrs, op_proto):
    # type: (typing.Dict[str, typing.Any], OpProto)-> typing.Dict[str, typing.Any]
    dict2 = {}
    for arg_proto in op_proto.list_nontensor_arg_protos():
        val = attrs[arg_proto.primary_arg_name]
        if arg_proto.is_array and val is not None:
            val = utils.listify(val)
        dict2[arg_proto.primary_arg_name] = val
    return dict2
Ejemplo n.º 2
0
def generic_convert_reduce(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite a reduce operation in place as ``target_name``.

    The reduction axes (second input) must be constant, i.e. carry ``data``.
    They are converted with ``.tolist()`` and ``utils.listify`` and stored in
    the ``axis`` attribute. The axes input is then dropped, so only the first
    input remains. ``op.attribs`` is replaced by a dict with only ``axis`` and
    ``keepdims``; the latter is copied from the original ``keep_dims`` attribute.

    Raises:
        AssertionError: if the axes input is not constant. The message names
            the op's original name, and ``op`` is left unmodified.
    """
    # Validate and read everything before mutating ``op``. The error message
    # then names the original op rather than ``target_name``, and a failed
    # conversion cannot leave a half-rewritten op behind.
    assert op.inputs[1].data is not None, "{} is only supported with constant axes".format(
        op.name)
    axes = utils.listify(op.inputs[1].data.tolist())
    keepdims = op.attribs["keep_dims"]

    op.name = target_name
    op.inputs = (op.inputs[0], )
    op.attribs = dict(axis=axes, keepdims=keepdims)
Ejemplo n.º 3
0
    def _match(self, value, settings):
        """Try to match this operation pattern against the operation ``value``.

        The stages are checked in this order:

        1. Unless ``settings.allow_multi_consumer`` is set, no output of the
           operation may have more than one consumer.
        2. The operation name must be among ``utils.listify(self.name)``, when
           ``self.name`` is set.
        3. The input patterns, when ``self.inputs`` is set. The first
           successful alternative produced by ``self._pattern_list_list`` wins.
        4. The attribute patterns, when ``self.attribs`` is set.
        5. The output patterns, when ``self.outputs`` is set. As with inputs,
           the first successful alternative wins.

        Bindings of each successful stage are merged with ``utils.dict_union``.

        Returns ``Match(True, root=value, dict_=...)`` on success. Returns an
        argument-less ``Match()`` as soon as any stage fails.
        """
        assert self not in settings.dict_so_far, "Operation cannot be matched multiple times"

        assert isinstance(value, BaseOperation)
        operation = value

        def first_successful(alternatives, match_fn):
            # Try the alternative pattern lists in order and stop at the first
            # truthy result. With no alternatives this yields Match().
            attempt = Match()
            for patterns in alternatives:
                attempt = match_fn(operation, settings, patterns)
                if attempt:
                    break
            return attempt

        def merged(current, extra):
            # New successful match whose bindings are the union of both.
            return Match(True,
                         root=operation,
                         dict_=utils.dict_union(current.dict, extra.dict))

        if not settings.allow_multi_consumer:
            for output in operation.outputs:
                if len(output.consumers) > 1:
                    return Match()

        if self.name is not None:
            if operation.name not in utils.listify(self.name):
                return Match()

        accumulated = Match(True, root=operation, dict_={self: operation})

        if self.inputs is not None:
            sub_match = first_successful(
                self._pattern_list_list(self.inputs, operation.inputs),
                self._match_inputs)
            if not sub_match:
                return Match()
            accumulated = merged(accumulated, sub_match)

        if self.attribs is not None:
            assert isinstance(self.attribs, dict)
            sub_match = self._match_attribs(operation, settings, self.attribs)
            if not sub_match:
                return Match()
            accumulated = merged(accumulated, sub_match)

        if self.outputs is not None:
            sub_match = first_successful(
                self._pattern_list_list(self.outputs, operation.outputs),
                self._match_outputs)
            if not sub_match:
                return Match()
            accumulated = merged(accumulated, sub_match)

        return accumulated
Ejemplo n.º 4
0
def _get_arg(op_name, args, arg_proto):
    # type: (str, typing.Dict[str, typing.Any], ArgProto)->typing.Any
    found = True
    value = None
    for arg_name in arg_proto.arg_names:
        if arg_name in args:
            found = True
            if args[arg_name] is not None:
                assert value is None
                if arg_proto.is_array and args[arg_name] is not None:
                    value = utils.listify(args[arg_name])
                else:
                    value = args[arg_name]
    if found:
        return value
    if arg_proto.is_optional:
        return None
    assert False, "Arg '{}' not found for op '{}'".format(arg_proto.primary_arg_name, op_name)