Esempio n. 1
0
    def _run_returnn_standalone_net_dict(self):
        """
        Rebuild the network from the stored net dict in plain RETURNN and re-check its output.

        Constructs a ``TFNetwork`` from ``self._returnn_net_dict``, loads its
        parameters from ``self._tf_checkpoint_save_path``, feeds the default
        input (feed dict built by ``self._make_tf_feed_dict``) and compares the
        values of the default output layer against ``self._out_returnn_np``.

        :raises AssertionError: if the fetched output is not a numpy array,
            or if it is not close to the stored reference output
        """
        print(">>> Constructing RETURNN model, load TF checkpoint, run...")
        with tf.compat.v1.Session() as session:
            from returnn.config import Config
            from returnn.tf.network import TFNetwork

            config_opts = {
                "extern_data": {"data": self._returnn_in_data_dict},
                "debug_print_layer_output_template": True,
            }
            net = TFNetwork(config=Config(config_opts), name="root")
            net.construct_from_dict(self._returnn_net_dict)
            net.load_params_from_file(
                filename=self._tf_checkpoint_save_path,
                session=session,
            )

            in_data = net.extern_data.get_default_input_data()
            out_data = net.get_default_output_layer().output
            feed = self._make_tf_feed_dict(in_data)
            # The size placeholders are fetched along with the values,
            # but only the values are compared below.
            fetches = (out_data.placeholder, out_data.size_placeholder)
            out_np, _out_sizes = session.run(fetches, feed_dict=feed)

            assert isinstance(out_np, numpy.ndarray)
            print("Output shape:", out_np.shape)
            numpy.testing.assert_allclose(self._out_returnn_np, out_np)
            print(">>>> Looks good!")
            print()
Esempio n. 2
0
    def _run_returnn_standalone_python(self):
        """
        Derive the RETURNN net dict from the Python model function, then run it in plain RETURNN.

        ``self._model_func`` is wrapped in a local ``torch_returnn`` module whose
        ``as_returnn_net_dict`` produces the net dict for
        ``self._returnn_in_data_dict``. A ``TFNetwork`` is then constructed from
        that dict, its parameters are loaded from ``self._tf_checkpoint_save_path``,
        and the values of the default output layer are compared against
        ``self._out_returnn_np``.

        :raises AssertionError: if the fetched output is not a numpy array,
            or if it is not close to the stored reference output
        """
        print(
            ">>> Constructing RETURNN model via Python code, load TF checkpoint, run..."
        )
        with tf.compat.v1.Session() as session:
            # A naming instance with default settings is expected to be sufficient here.
            with Naming.make_instance():
                wrapped_func = self._model_func

                # The model function gets wrapped into a module.
                # The assumption is that this wrapper is flattened away in the namespace,
                # so that all named modules end up with the same names.
                class DummyModule(torch_returnn.nn.Module):
                    def get_returnn_name(self) -> str:
                        # Empty name, which also avoids that it becomes a prefix anywhere.
                        return ""

                    def forward(self, *inputs):
                        return wrapped_func(
                            wrapped_import_torch_returnn, *inputs)

                net_dict = DummyModule().as_returnn_net_dict(
                    self._returnn_in_data_dict)

            from returnn.config import Config
            from returnn.tf.network import TFNetwork

            config_opts = {
                "extern_data": {"data": self._returnn_in_data_dict},
                "debug_print_layer_output_template": True,
            }
            net = TFNetwork(config=Config(config_opts), name="root")
            net.construct_from_dict(net_dict)
            net.load_params_from_file(
                filename=self._tf_checkpoint_save_path,
                session=session,
            )

            in_data = net.extern_data.get_default_input_data()
            out_data = net.get_default_output_layer().output
            feed = self._make_tf_feed_dict(in_data)
            # The size placeholders are fetched along with the values,
            # but only the values are compared below.
            fetches = (out_data.placeholder, out_data.size_placeholder)
            out_np, _out_sizes = session.run(fetches, feed_dict=feed)

            assert isinstance(out_np, numpy.ndarray)
            print("Output shape:", out_np.shape)
            numpy.testing.assert_allclose(self._out_returnn_np, out_np)
            print(">>>> Looks good!")
            print()