Exemplo n.º 1
0
 def flatten(x):
     """Flatten ``x`` with ``utils.flatten`` if it is a list or tuple.

     Values of any other type (including dicts, strings and ``None``) are
     returned unchanged, without calling ``utils.flatten``.
     """
     if not isinstance(x, (list, tuple)):
         return x
     return utils.flatten(x)
Exemplo n.º 2
0
    def _test(self,
              fun,
              cmp=True,
              custom_tf_to_nnef_converters="",
              custom_nnef_to_tf_converters="",
              test_module="nnef_tests.conversion.tf_py_layer_test_cases",
              atol=1e-5):
        """Round-trip a TensorFlow-Python network through NNEF and compare.

        Builds the network returned by ``fun`` and saves a random checkpoint
        for it. It then converts the network to NNEF with
        ``./nnef_tools/convert.py`` and converts the NNEF back to
        TensorFlow-Python source, once per data-format option. When
        activation testing is enabled, it runs both networks on the same
        feed dict. If ``cmp`` is true, it also asserts that their outputs
        agree.

        Args:
            fun: Zero-argument function that builds the network in the default
                TensorFlow graph and returns its outputs. ``fun.__name__``
                names the output directory under ``out/`` and is looked up
                in ``test_module`` by the converter.
            cmp: Whether to compare the original and round-tripped outputs.
            custom_tf_to_nnef_converters: Iterable of converter names, joined
                with spaces and passed to ``--custom-converters`` on the
                TF->NNEF step. NOTE(review): the ``""`` default only works
                because joining an empty string gives ``""``. A non-empty
                plain string would be split into single characters.
            custom_nnef_to_tf_converters: Same, for the NNEF->TF step.
            test_module: Dotted module path containing ``fun``.
            atol: Absolute tolerance for ``np.allclose`` on non-bool outputs.

        Environment:
            NNEF_ACTIVATION_TESTING: parsed as an int, defaulting to ``1``.
                ``0`` disables running the networks, so only conversion is
                exercised.

        Generated ``*.dat`` and ``*ckpt*`` files under ``out/<name>`` are
        removed afterwards if ``self.delete_dats_and_checkpoints`` is set.
        """
        activation_testing = int(os.environ.get('NNEF_ACTIVATION_TESTING',
                                                '1'))
        print("Activation testing is", "ON" if activation_testing else "OFF")

        out_dir = os.path.join("out", fun.__name__)
        try:
            tf.reset_default_graph()
            tf.set_random_seed(0)

            network_outputs = fun()
            feed_dict = get_feed_dict()
            # Placeholder names of the original graph. The regenerated graph
            # may name its placeholders differently, so the feed dict is
            # re-keyed by position further below.
            old_names = [
                placeholder.name for placeholder in get_placeholders()
            ]
            checkpoint_path = os.path.join(out_dir, "orig_checkpoint",
                                           fun.__name__ + ".ckpt")
            # The helper may return a different (or empty) path; only a
            # truthy value is forwarded to the converter.
            checkpoint_path = save_random_checkpoint(network_outputs,
                                                     checkpoint_path,
                                                     feed_dict)

            tf.reset_default_graph()
            tf.set_random_seed(0)

            compress_nnef = False
            command = """
                ./nnef_tools/convert.py --input-format tensorflow-py \\
                                        --output-format nnef \\
                                        --input-model {module}.{network} {checkpoint} \\
                                        --output-model out/{network}/{network}.nnef{tgz} \\
                                        --custom-converters {custom} \\
                                        --permissive \\
                                        --io-transformation SMART_TF_NHWC_TO_NCHW \\
                                        {compress}
            """.format(checkpoint=checkpoint_path if checkpoint_path else "",
                       network=fun.__name__,
                       custom=" ".join(custom_tf_to_nnef_converters),
                       compress="--compress" if compress_nnef else "",
                       module=test_module,
                       tgz=".tgz" if compress_nnef else "")

            convert.convert_using_command(command)

            if activation_testing:
                # Rebuild the original network in a fresh graph and run it to
                # get the reference outputs.
                tf.reset_default_graph()
                tf.set_random_seed(0)
                network_outputs = fun()
                network_output_list = []
                utils.recursive_visit(network_outputs,
                                      lambda t: network_output_list.append(t))
                # Flatten is needed because of MaxPoolWithArgMax objects
                outputs = utils.flatten(
                    self._run_tfpy(network_output_list, feed_dict,
                                   checkpoint_path))
            else:
                outputs = None

            # NHWC is always tried; NCHW only when a CUDA GPU is available.
            prefer_nhwc_options = [True]
            if tf_has_cuda_gpu():
                prefer_nhwc_options += [False]
            for prefer_nhwc in prefer_nhwc_options:
                print("Converting to TensorFlow {}".format(
                    "NHWC" if prefer_nhwc else "NCHW"))
                data_format_str = ("nhwc" if prefer_nhwc else "nchw")
                tf_output_path = os.path.join(
                    out_dir, fun.__name__ + '_' + data_format_str + '.py')
                # NOTE(review): prefer_nhwc only changes the output file name
                # here; nothing in the command below depends on it, so both
                # iterations run the same conversion. Confirm whether a
                # data-format flag for convert.py is missing.
                command = """
                    ./nnef_tools/convert.py --input-format nnef \\
                                            --output-format tensorflow-py \\
                                            --input-model out/{network}/{network}.nnef{tgz} \\
                                            --output-model {output} \\
                                            --io-transformation SMART_NCHW_TO_TF_NHWC \\
                                            --custom-converters {custom} \\
                                            --permissive
                """.format(network=fun.__name__,
                           custom=" ".join(custom_nnef_to_tf_converters),
                           # The template already contains ".nnef"; this
                           # suffix must match the TF->NNEF step above.
                           tgz=".tgz" if compress_nnef else "",
                           output=tf_output_path)
                convert.convert_using_command(command)

                with open(tf_output_path, 'r') as f:
                    tf_src = f.read()

                # noinspection PyProtectedMember
                new_net_fun = tf_py_io._tfsource_to_function(
                    tf_src, fun.__name__)

                tf.reset_default_graph()
                tf.set_random_seed(0)

                if activation_testing:
                    network_outputs = new_net_fun()
                    network_output_list = []
                    utils.recursive_visit(
                        network_outputs,
                        lambda t: network_output_list.append(t))
                    # Re-key the original feed values by placeholder position
                    # onto the regenerated graph's placeholder names.
                    feed_dict2 = {
                        placeholder.name: feed_dict[old_names[i]]
                        for i, placeholder in enumerate(get_placeholders())
                    }
                    outputs2 = utils.flatten(
                        self._run_tfpy(
                            network_output_list, feed_dict2,
                            (tf_output_path + ".checkpoint"
                             if checkpoint_path else None)))

                    if cmp:
                        self.assertEqual(len(outputs), len(outputs2))
                        for a, b in zip(outputs, outputs2):
                            # np.bool was removed in NumPy 1.24; np.bool_
                            # works in all versions.
                            if a.dtype == np.bool_:
                                self.assertTrue(np.all(a == b))
                            else:
                                # np.max has no identity for empty arrays.
                                if a.size:
                                    print('Max diff:',
                                          np.max(np.abs(a - b)))
                                self.assertTrue(np.all(np.isfinite(a)))
                                self.assertTrue(np.all(np.isfinite(b)))
                                self.assertTrue(np.allclose(a, b, atol=atol))

        finally:
            if self.delete_dats_and_checkpoints:
                dat_files = utils.recursive_glob(out_dir, "*.dat")
                checkpoints = utils.recursive_glob(out_dir, "*ckpt*")
                for file_name in set(dat_files + checkpoints):
                    os.remove(file_name)