コード例 #1
0
    def __init__(
        self,
        name=None,  # type: typing.Union[None, str, typing.List[str]]
        inputs=None,  # type: typing.Optional[_TensorOrListOrTupleOrSetOrDict]
        outputs=None,  # type: typing.Optional[_TensorOrListOrTupleOrSetOrDict]
        attribs=None  # type: typing.Optional[typing.Dict[str, typing.Any]]
    ):
        # type: (...) -> None
        """Store the pattern's fields and claim ownership of its outputs.

        A bare ``Tensor`` given as ``inputs`` or ``outputs`` is wrapped into a
        one-element tuple; otherwise they must be None or a list, tuple, set
        or dict. ``name`` must be None, a str, a list or a tuple, and
        ``attribs`` None or a dict (checked with ``assert``).

        When ``outputs`` is non-empty, every tensor reached by
        ``utils.recursive_visit`` over it gets ``_producer_pattern`` set to
        this pattern; a tensor that already has a producer pattern trips an
        assertion.
        """
        inputs = (inputs, ) if isinstance(inputs, Tensor) else inputs
        outputs = (outputs, ) if isinstance(outputs, Tensor) else outputs

        container_types = (list, tuple, set, dict)
        assert name is None or isinstance(name, (str, list, tuple))
        assert inputs is None or isinstance(inputs, container_types)
        assert outputs is None or isinstance(outputs, container_types)
        assert attribs is None or isinstance(attribs, dict)

        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.attribs = attribs

        if not self.outputs:
            return

        def claim_tensor(tensor):
            # A tensor may be registered as the output of only one pattern.
            assert tensor._producer_pattern is None
            tensor._producer_pattern = self

        utils.recursive_visit(outputs, claim_tensor)
コード例 #2
0
def _eliminate_nesting(invocations):
    # type: (typing.List[_Invocation])->typing.List[_Invocation]
    """Keep only the invocations that end up at the top nesting level.

    First, every tensor produced by an invocation's result (tf.Variable values
    are unwrapped with ``.value()``) is mapped to the producing invocation
    with the lowest ``nesting_level``. Every invocation then receives the
    temporary attributes ``tmp_level`` (initialised to its ``nesting_level``)
    and ``tmp_args_checked`` (False). The arguments of each invocation whose
    ``tmp_level`` is 0 at the moment it is visited are passed to
    ``_check_args_for_nested_reference``, which may lower producers'
    ``tmp_level`` to 0.

    Returns the invocations with ``tmp_level == 0`` in their original order,
    after resetting their temporary attributes to None.

    NOTE(review): invocations that are filtered out keep their tmp_* values.
    """
    least_nested_producer = {}  # type: typing.Dict[tf.Tensor, _Invocation]

    for inv in invocations:

        def remember_producer(value):
            if isinstance(value, tf.Variable):
                value = value.value()
            if not isinstance(value, tf.Tensor):
                return
            known = least_nested_producer.get(value)
            if known is None or inv.nesting_level < known.nesting_level:
                least_nested_producer[value] = inv

        utils.recursive_visit(inv.result, remember_producer)

    # Bookkeeping must be initialised for every invocation before any
    # argument check runs, because checks may modify other invocations.
    for inv in invocations:
        inv.tmp_level = inv.nesting_level
        inv.tmp_args_checked = False

    for inv in invocations:
        if inv.tmp_level == 0:
            _check_args_for_nested_reference(inv, least_nested_producer)

    top_level = [inv for inv in invocations if inv.tmp_level == 0]

    for inv in top_level:
        inv.tmp_level = None
        inv.tmp_args_checked = None

    return top_level
コード例 #3
0
ファイル: nnef_io.py プロジェクト: kiritigowda/NNEF-Tools
def _recursive_check_str(data):
    if sys.version_info[0] < 3:

        def check(arg):
            # noinspection PyUnresolvedReferences
            assert not isinstance(arg, unicode), \
                "NNEF module does not accept unicode strings in python2. Use NNEFGraph with str only."

        utils.recursive_visit(data, check)
    return data
コード例 #4
0
def _check_args_for_nested_reference(invocation, least_nested_producer):
    """Promote the producers of *invocation*'s tensor arguments.

    Runs at most once per invocation (guarded by ``tmp_args_checked``). For
    every tensor argument (tf.Variable values are unwrapped with ``.value()``)
    whose producer is recorded in *least_nested_producer* with
    ``tmp_level > 0``:
      * a ``tf.constant`` producer has its ``tmp_level`` set to 0;
      * any other producer triggers ``_bubble_up`` on the producer's parent
        (its effect is defined elsewhere — presumably it promotes that parent).
    """
    if invocation.tmp_args_checked:
        return
    # Mark before visiting; presumably this also stops re-entry if
    # _bubble_up ends up checking this invocation again.
    invocation.tmp_args_checked = True

    def promote_producer(arg):
        tensor = arg.value() if isinstance(arg, tf.Variable) else arg
        if not isinstance(tensor, tf.Tensor):
            return
        producer = least_nested_producer.get(tensor)
        if producer is not None and producer.tmp_level > 0:
            if producer.function_name == "tf.constant":
                producer.tmp_level = 0
            else:
                _bubble_up(producer.parent, least_nested_producer)

    utils.recursive_visit(invocation.args, promote_producer)
コード例 #5
0
def _check_has_untraced_ops(invocations, result):
    """Report tensors that were used but not produced by any traced invocation.

    Invocations are scanned in order: each invocation's arguments must be
    tensors produced by the results of an earlier invocation, and every tensor
    in *result* must be produced by some invocation. tf.Variable values are
    unwrapped with ``.value()``. Each offending tensor is printed once.

    Returns True if at least one untraced tensor was found.
    """
    state = {"untraced": False}
    known_tensors = set()

    def unwrap(value):
        return value.value() if isinstance(value, tf.Variable) else value

    def record_result(value):
        value = unwrap(value)
        if isinstance(value, tf.Tensor):
            known_tensors.add(value)

    for invocation in invocations:

        def check_arg(value):
            value = unwrap(value)
            if isinstance(value, tf.Tensor) and value not in known_tensors:
                print("Error: Untraced tensor: {}, used near: {}".format(
                    value, _get_location_summary(invocation.stack)))
                state["untraced"] = True
                # Remember it so the same tensor is reported only once.
                known_tensors.add(value)

        # Arguments are checked before this invocation's own results count.
        utils.recursive_visit(invocation.args, check_arg)
        utils.recursive_visit(invocation.result, record_result)

    def check_output(value):
        value = unwrap(value)
        if isinstance(value, tf.Tensor) and value not in known_tensors:
            print("Error: Untraced output tensor: {}".format(value))
            state["untraced"] = True
            known_tensors.add(value)

    utils.recursive_visit(result, check_output)

    return state["untraced"]
コード例 #6
0
    def _test(self,
              fun,
              cmp=True,
              custom_tf_to_nnef_converters="",
              custom_nnef_to_tf_converters="",
              test_module="nnef_tests.conversion.tf_py_layer_test_cases",
              atol=1e-5):
        """Round-trip a TensorFlow-Python network through NNEF and compare.

        Steps:
          1. Build the network with ``fun()`` and save a random checkpoint
             under ``out/<fun.__name__>/orig_checkpoint``.
          2. Convert TF-py -> NNEF with ``nnef_tools/convert.py``.
          3. If activation testing is on (env ``NNEF_ACTIVATION_TESTING``,
             default "1"), run the original network and keep its outputs.
          4. For NHWC (and also NCHW when ``tf_has_cuda_gpu()``), convert
             NNEF -> TF-py, load the generated source, and, with activation
             testing on, run it and (if *cmp*) compare outputs with the
             originals.

        Args:
            fun: zero-argument function building the network; its
                ``__name__`` names the output directory and model files.
            cmp: compare original and round-tripped outputs when True.
            custom_tf_to_nnef_converters: items joined with spaces and passed
                as ``--custom-converters`` (NOTE: a non-empty plain string
                would be split into space-separated characters).
            custom_nnef_to_tf_converters: same, for the NNEF -> TF step.
            test_module: module prefix of ``fun`` used as the input model.
            atol: absolute tolerance for ``np.allclose``.

        In ``finally``, ``*.dat`` and ``*ckpt*`` files under the output
        directory are deleted when ``self.delete_dats_and_checkpoints`` is set.
        """
        activation_testing = int(os.environ.get('NNEF_ACTIVATION_TESTING',
                                                '1'))
        print("Activation testing is", "ON" if activation_testing else "OFF")

        out_dir = os.path.join("out", fun.__name__)
        try:
            tf.reset_default_graph()
            tf.set_random_seed(0)

            network_outputs = fun()
            feed_dict = get_feed_dict()
            # Original placeholder names; the regenerated network's
            # placeholders are matched to these by position below.
            old_names = [
                placeholder.name for placeholder in get_placeholders()
            ]
            checkpoint_path = os.path.join(out_dir, "orig_checkpoint",
                                           fun.__name__ + ".ckpt")
            checkpoint_path = save_random_checkpoint(network_outputs,
                                                     checkpoint_path,
                                                     feed_dict)

            tf.reset_default_graph()
            tf.set_random_seed(0)

            compress_nnef = False
            # Suffix appended to "<network>.nnef". Both conversion commands
            # must use the same value so the NNEF -> TF step reads exactly the
            # file the TF -> NNEF step wrote (previously the second command
            # used ".nnef.tgz", yielding "<network>.nnef.nnef.tgz").
            nnef_suffix = ".tgz" if compress_nnef else ""
            command = """
                ./nnef_tools/convert.py --input-format tensorflow-py \\
                                        --output-format nnef \\
                                        --input-model {module}.{network} {checkpoint} \\
                                        --output-model out/{network}/{network}.nnef{tgz} \\
                                        --custom-converters {custom} \\
                                        --permissive \\
                                        --io-transformation SMART_TF_NHWC_TO_NCHW \\
                                        {compress}
            """.format(checkpoint=checkpoint_path if checkpoint_path else "",
                       network=fun.__name__,
                       custom=" ".join(custom_tf_to_nnef_converters),
                       compress="--compress" if compress_nnef else "",
                       module=test_module,
                       tgz=nnef_suffix)

            convert.convert_using_command(command)

            if activation_testing:
                tf.reset_default_graph()
                tf.set_random_seed(0)
                network_outputs = fun()
                network_output_list = []
                utils.recursive_visit(network_outputs,
                                      lambda t: network_output_list.append(t))
                # Flatten is needed because of MaxPoolWithArgMax objects
                outputs = utils.flatten(
                    self._run_tfpy(network_output_list, feed_dict,
                                   checkpoint_path))
            else:
                outputs = None

            prefer_nhwc_options = [True]
            if tf_has_cuda_gpu():
                prefer_nhwc_options += [False]
            for prefer_nhwc in prefer_nhwc_options:
                print("Converting to TensorFlow {}".format(
                    "NHWC" if prefer_nhwc else "NCHW"))
                data_format_str = ("nhwc" if prefer_nhwc else "nchw")
                tf_output_path = os.path.join(
                    out_dir, fun.__name__ + '_' + data_format_str + '.py')
                command = """
                    ./nnef_tools/convert.py --input-format nnef \\
                                            --output-format tensorflow-py \\
                                            --input-model out/{network}/{network}.nnef{tgz} \\
                                            --output-model {output} \\
                                            --io-transformation SMART_NCHW_TO_TF_NHWC \\
                                            --custom-converters {custom} \\
                                            --permissive
                """.format(network=fun.__name__,
                           custom=" ".join(custom_nnef_to_tf_converters),
                           tgz=nnef_suffix,
                           output=tf_output_path)
                convert.convert_using_command(command)

                with open(tf_output_path, 'r') as f:
                    tf_src = f.read()

                # noinspection PyProtectedMember
                new_net_fun = tf_py_io._tfsource_to_function(
                    tf_src, fun.__name__)

                tf.reset_default_graph()
                tf.set_random_seed(0)

                if activation_testing:
                    network_outputs = new_net_fun()
                    network_output_list = []
                    utils.recursive_visit(
                        network_outputs,
                        lambda t: network_output_list.append(t))
                    feed_dict2 = {
                        placeholder.name: feed_dict[old_names[i]]
                        for i, placeholder in enumerate(get_placeholders())
                    }
                    outputs2 = utils.flatten(
                        self._run_tfpy(
                            network_output_list, feed_dict2,
                            (tf_output_path + ".checkpoint"
                             if checkpoint_path else None)))

                    if cmp:
                        self.assertEqual(len(outputs), len(outputs2))
                        for a, b in zip(outputs, outputs2):
                            # np.bool (alias of builtin bool) was removed in
                            # NumPy 1.24; np.bool_ works in every version.
                            if a.dtype == np.bool_:
                                self.assertTrue(np.all(a == b))
                            else:
                                abs_diff = np.abs(a - b)
                                # np.max raises ValueError on zero-size input.
                                if abs_diff.size:
                                    print('Max diff:', np.max(abs_diff))
                                self.assertTrue(np.all(np.isfinite(a)))
                                self.assertTrue(np.all(np.isfinite(b)))
                                self.assertTrue(np.allclose(a, b, atol=atol))

        finally:
            if self.delete_dats_and_checkpoints:
                dat_files = utils.recursive_glob(out_dir, "*.dat")
                checkpoints = utils.recursive_glob(out_dir, "*ckpt*")
                for file_name in set(dat_files + checkpoints):
                    os.remove(file_name)