Example #1
0
def convert_split(op):
    # type: (TFOperation)->None
    """Rewrite a generic split operation as TFLite SPLIT or SPLIT_V.

    Uniform splits (a scalar count, or a size list whose entries are all
    equal) become SPLIT with the axis as a scalar tensor input; non-uniform
    size lists become SPLIT_V with an explicit sizes tensor.
    """
    splits = op.attribs['num_or_size_splits']
    is_sized = isinstance(splits, (list, tuple))

    if is_sized and len(utils.unique(splits)) != 1:
        # Unequal sizes require SPLIT_V, which takes the sizes as a tensor.
        sizes = list(splits)
        op.name = "SPLIT_V"
        op.inputs = (
            op.input,
            TFTensor(graph=op.graph,
                     shape=[len(sizes)],
                     dtype='INT32',
                     data=sizes),
            TFTensor(graph=op.graph,
                     shape=[],
                     dtype='INT32',
                     data=[op.attribs["axis"]]),
        )
        op.attribs = dict(num_splits=len(sizes))
    else:
        # Scalar count or uniform size list maps to plain SPLIT.
        op.name = "SPLIT"
        count = len(splits) if is_sized else splits
        op.inputs = (
            TFTensor(graph=op.graph,
                     shape=[],
                     dtype='INT32',
                     data=[op.attribs["axis"]]),
            op.input,
        )
        op.attribs = dict(num_splits=count)
Example #2
0
    def combine_configs(configs):
        """Merge several NNEFParserConfig objects into one combined config.

        Fragments are concatenated, shape overrides are merged (later
        configs win on key collisions), and lowered-op lists are
        concatenated with duplicates removed.
        """
        assert all(isinstance(c, NNEFParserConfig) for c in configs)

        merged_shapes = {}
        merged_lowered = []
        for c in configs:
            merged_shapes.update(c.shapes)
            merged_lowered.extend(c.lowered)

        joined_fragments = '\n\n'.join(c.fragments for c in configs)
        return NNEFParserConfig(fragments=joined_fragments,
                                shapes=merged_shapes,
                                lowered=utils.unique(merged_lowered))
Example #3
0
    def net_fun_with_gradients():
        """Run net_fun() and return its outputs plus gradients as one
        OrderedDict mapping identifier names to tensors.

        NOTE(review): closes over net_fun, get_placeholders, to_id and the
        module-level helpers; presumably runs inside an active TF 1.x
        graph context — confirm against the caller.
        """
        outputs_ = _eliminate_named_tuples(net_fun())
        outputs_dict = OrderedDict()

        def visit(output_, path):
            # type: (typing.Any, str)->None
            # Flatten the (possibly nested) output structure into
            # outputs_dict, deriving a valid identifier for each tensor
            # from its structural path.
            if isinstance(output_, tf.Variable):
                output_ = output_.value()  # use the variable's value tensor
            if isinstance(output_, tf.Tensor):
                path = path[1:]  # drop the leading path separator character
                if not path:
                    path = "output"
                elif path[0].isdigit():
                    # Identifiers must not start with a digit (list index).
                    path = "output" + path
                assert utils.is_identifier(path), \
                    "Bad name_override '{}' for tensor {}. " \
                    "Please use valid identifiers as keys in the dict(s) " \
                    "returned by your network function.".format(path, output_.name)
            outputs_dict[path] = output_

        _recursive_visit_with_path(outputs_, visit)

        # Differentiate w.r.t. both placeholders and all global variables.
        inputs = get_placeholders() + tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES)

        grad_ys = None

        # We can test with other grad_ys too
        # grad_ys = [tf.constant(value=2.0, dtype=tf.float32, shape=output_.shape) for output_ in outputs]

        # Only numeric (float/int/uint) outputs are used as gradient targets.
        ys = [
            y for y in six.itervalues(outputs_dict)
            if y.dtype.name.startswith("float") or y.dtype.name.startswith(
                "int") or y.dtype.name.startswith("uint")
        ]

        # Drop gradients that coincide with an output tensor so the final
        # dict does not contain the same tensor under two names.
        gradients = [
            gradient
            for gradient in tf.gradients(ys=ys, xs=inputs, grad_ys=grad_ys)
            if gradient not in six.itervalues(outputs_dict)
        ]

        items = [(name_, output_)
                 for name_, output_ in six.iteritems(outputs_dict)]

        # Name each gradient after its input; input_.name[:-2] strips the
        # trailing ':0' output-index suffix — TODO confirm inputs always
        # carry exactly that suffix. Pairs with a None member are skipped.
        items += [("grad_{}".format(to_id(input_.name[:-2])), gradient)
                  for input_, gradient in zip(inputs, gradients)
                  if None not in [input_, gradient]]

        # Deduplicate by tensor value while preserving insertion order.
        return OrderedDict(utils.unique(items, key=lambda item: item[1]))
Example #4
0
    def combine_configs(configs):
        """Merge several NNEFParserConfig objects into a single config.

        Sources are concatenated, shape overrides are merged (later
        configs win on key collisions), and expand lists are concatenated
        with duplicates removed.
        """
        assert all(isinstance(c, NNEFParserConfig) for c in configs)

        merged_shapes = {}
        merged_expand = []
        for c in configs:
            merged_shapes.update(c._shapes)
            merged_expand.extend(c._expand)

        joined_source = '\n\n'.join(c._source for c in configs)
        return NNEFParserConfig(source=joined_source,
                                shapes=merged_shapes,
                                expand=utils.unique(merged_expand))