Example #1
def rna_alignment_out(sources, **kwargs):
  from returnn.tf.util.data import Data

  log_probs = sources[0].output
  targets = sources[1].output
  encoder = sources[2].output
  enc_lens = encoder.get_sequence_lengths()
  return Data(name="rna_alignment", sparse=True, dim=eval("targetb_num_labels"), size_placeholder={0: enc_lens})
 def _make_tf_feed_dict(self, input: Data):
     assert input.batch_ndim == len(self._inputs_np.shape)
     assert all(input.batch_shape[i] in {None, self._inputs_np.shape[i]}
                for i in range(input.batch_ndim))
     n_batch = self._inputs_np.shape[input.batch_dim_axis]
     d = {input.placeholder: self._inputs_np}
     for i, size in input.size_placeholder.items():
         d[size] = [self._inputs_np.shape[input.get_batch_axis(i)]
                    ] * n_batch  # not so relevant
     return d
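A minimal usage sketch, taken from the drop-in run further below (x is the registered input Data, y the output Data):

feed_dict = self._make_tf_feed_dict(x)
y_np, y_size = session.run((y.placeholder, y.size_placeholder), feed_dict=feed_dict)
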
 def _prepare_module_call_returnn_inputs(self, call: _call.CallEntry):
   """
   It means this module has no forward, i.e. it is wrapped as RETURNN layer.
   We might need to make some inputs available, which are not available yet,
   e.g. constants, params, etc.
   """
   if not self.wrap_to_returnn_enabled:
     return
   call_parent_namespace = call.namespace.parent or self.root_namespace
   for x in call.inputs_flat:
     if x is None:
       continue
     assert isinstance(x, _tensor.TensorEntry)
     names = [name_ for name_ in x.names if name_.parent is call_parent_namespace]
     if names:
       continue
     if x.is_param:
       assert x.returnn_data
       assert x.returnn_data.placeholder is None
       from pytorch_to_returnn.torch.nn.modules import Variable
       param_name = x.get_canonical_name()
       if x.returnn_data.name == "_unnamed_param":
         x.returnn_data.name = f"param:{param_name}"
       parent_mod = x.get_canonical_parent_module()
       prefix = (parent_mod.get_canonical_name() + "_") if parent_mod else ""
       mod = Variable(param=x.tensor())
       self.modules[mod].canonical_name = prefix + param_name
       res = mod()
       res_tensor = self.tensors[res]
       assert isinstance(res_tensor, _tensor.TensorEntry)
       assert len(res_tensor.names) == 1
       assert res_tensor.returnn_data.placeholder is not None
       x.returnn_data.placeholder = res_tensor.returnn_data.placeholder
     elif not x.output_from_calls or x.is_const:
       # Assume this is a constant.
       const_name = x.get_canonical_name(fallback="unnamed_const")
       tensor = x.tensor()
       if not x.returnn_data:
         x.returnn_data = Data(
           name=f"const:{const_name}", shape=tensor.shape, dtype=tensor.dtype.name,
           batch_dim_axis=None, time_dim_axis=None)
         x.returnn_axis_from_torch_axis = {i: i for i in range(len(tensor.shape))}
       parent_mod = x.get_canonical_parent_module()
       prefix = (parent_mod.get_canonical_name() + "_") if parent_mod else ""
       from pytorch_to_returnn.torch.nn.modules import Constant
       mod = Constant(value=tensor)
       self.modules[mod].canonical_name = prefix + const_name
       res = mod()
       res_tensor = self.tensors[res]
       assert isinstance(res_tensor, _tensor.TensorEntry)
       assert res_tensor.returnn_data.placeholder is not None
       x.returnn_data.placeholder = res_tensor.returnn_data.placeholder
       x.is_const = True
     else:
       raise Exception(f"Cannot handle tensor {x}, via {x.output_from_calls} ...")
 def _get_shape_meta(data: Data) -> List[str]:
     _res = []
     for i in range(data.batch_ndim):
         if i == data.batch_dim_axis:
             _res.append("B")
         elif i == data.feature_dim_axis:
             _res.append("F")
         elif i in data.get_spatial_batch_axes():
             _res.append(
                 f"spatial:{data.get_spatial_batch_axes().index(i)}")
         else:
             raise Exception(f"not expected {data}, axis {i}")
     return _res
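A minimal sketch of the result, assuming RETURNN's default axis layout for a [batch, time, feature] Data:

data = Data(name="x", shape=(None, 40))  # batch_dim_axis=0, dynamic time axis, feature dim 40
assert _get_shape_meta(data) == ["B", "spatial:0", "F"]
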
Example #5
 def __init__(self, *args, **kwargs):
     super(Parameter, self).__init__(*args, **kwargs)
     from returnn.tf.util.data import Data
     naming = Naming.get_instance()
     tensor_entry = naming.register_tensor(self)
     tensor_entry.is_param = True
     tensor_entry.returnn_data = Data(name="_unnamed_param",
                                      shape=self.shape,
                                      dtype=self.dtype.name,
                                      batch_dim_axis=None,
                                      time_dim_axis=None)
     tensor_entry.returnn_axis_from_torch_axis = {
         i: i
         for i in range(len(self.shape))
     }
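For illustration (a sketch, not executed here): a weight parameter of shape (128, 64) with float32 dtype would be registered with

Data(name="_unnamed_param", shape=(128, 64), dtype="float32", batch_dim_axis=None, time_dim_axis=None)
# and returnn_axis_from_torch_axis == {0: 0, 1: 1}
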
Example #6
def test_base_get_output_shape_from_returnn_conv2d_dynamic():
  with Naming.make_instance() as naming:
    assert isinstance(naming, Naming)
    x = torch.Tensor(64, 1, 11, 13)
    x_ = naming.register_tensor(x)
    x_.returnn_data = Data(name="x", shape=(1, None, None), feature_dim_axis=1)
    x_.returnn_axis_from_torch_axis = {0: 0, 1: 1, 2: 2, 3: 3}

    net = TFNetwork(extern_data=ExternData())
    # E.g. conv layer, with padding "same".
    layer = InternalLayer(name="layer", network=net, out_type={"shape": (None, None, 32)})

    torch_shape, returnn_axis_from_torch_axis = torch.nn.Module._base_get_output_shape_from_returnn(
      inputs_flat=[x], layer=layer)
    assert returnn_axis_from_torch_axis == {0: 0, 1: 3, 2: 1, 3: 2}
    assert torch_shape == (64, 32, 11, 13)
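To spell out the expected mapping (an illustration, not part of the test): the RETURNN output is [B, spatial:0, spatial:1, F] with feature dim 32, and padding "same" keeps the dynamic spatial dims at 11 and 13, so the Torch shape follows from the axis map:

returnn_shape = (64, 11, 13, 32)     # B, spatial:0, spatial:1, F
axis_map = {0: 0, 1: 3, 2: 1, 3: 2}  # torch axis -> returnn axis
torch_shape = tuple(returnn_shape[axis_map[i]] for i in range(4))
assert torch_shape == (64, 32, 11, 13)
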
Example #7
def test_base_get_output_shape_from_returnn_2d_reorder_dynamic():
  with Naming.make_instance() as naming:
    assert isinstance(naming, Naming)
    x = torch.Tensor(64, 1, 11, 13)
    x_ = naming.register_tensor(x)
    x_.returnn_data = Data(name="x", shape=(1, None, None), feature_dim_axis=1, auto_create_placeholders=True)
    x_.returnn_axis_from_torch_axis = {0: 0, 1: 1, 2: 2, 3: 3}
    y_data = x_.returnn_data.copy_move_axis(2, 3)
    assert y_data.get_dim_tag(3) == x_.returnn_data.get_dim_tag(2)

    net = TFNetwork(extern_data=ExternData())
    # E.g. softmax_over_spatial with axis="stag:time1"
    layer = InternalLayer(name="layer", network=net, output=y_data)

    # We expect from all Torch modules, that they don't reorder the spatial axes.
    # (If they do, they explicitly would overwrite the output shape logic.)
    torch_shape, returnn_axis_from_torch_axis = torch.nn.Module._base_get_output_shape_from_returnn(
      inputs_flat=[x], layer=layer)
    assert returnn_axis_from_torch_axis == {0: 0, 1: 1, 2: 3, 3: 2}
    assert torch_shape == (64, 1, 11, 13)
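The same arithmetic for the reordered case (an illustration, not part of the test): copy_move_axis(2, 3) swaps the two dynamic RETURNN axes, and since Torch modules are expected not to reorder spatial axes, the Torch shape stays the same and only the axis map changes:

returnn_shape = (64, 1, 13, 11)      # the two dynamic axes swapped by copy_move_axis(2, 3)
axis_map = {0: 0, 1: 1, 2: 3, 3: 2}  # torch axis -> returnn axis
torch_shape = tuple(returnn_shape[axis_map[i]] for i in range(4))
assert torch_shape == (64, 1, 11, 13)
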
 def register_input(self, tensor: _types.Tensor, returnn_data: Data) -> Data:
   entry = self.register_tensor(tensor)
   entry.is_input = True
   entry.is_const = False
   self.inputs.append(tensor)
   assert tensor.dim() == returnn_data.batch_ndim
   assert all([dim in {tensor.shape[i], None} for i, dim in enumerate(returnn_data.batch_shape)])
   entry.returnn_data = Data(
     name=returnn_data.name, auto_create_placeholders=True,
     sparse=returnn_data.sparse,
     dim=returnn_data.dim,
     shape=returnn_data.shape,
     batch_dim_axis=returnn_data.batch_dim_axis,
     time_dim_axis=returnn_data.time_dim_axis,
     feature_dim_axis=returnn_data.feature_dim_axis_or_unspecified,
     available_for_inference=True)
   entry.returnn_axis_from_torch_axis = {i: i for i in range(returnn_data.batch_ndim)}
   self.root_namespace.register_input(tensor=entry)
   assert entry.returnn_data
   return entry.returnn_data
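A minimal usage sketch, mirroring the drop-in run further below (the concrete Data kwargs are an assumption):

in_returnn = torch_returnn.from_numpy(inputs_np)
data = naming.register_input(in_returnn, Data("data", shape=(None, 40)))
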
 def _returnn_dummy_call(self, *returnn_inputs: Dict[str, Any]) -> Naming:
     from pytorch_to_returnn.torch import from_numpy
     naming = Naming.get_instance()
     returnn_datas = []
     for i, kwargs in enumerate(returnn_inputs):
         kwargs = kwargs.copy()
         if "name" not in kwargs:
             kwargs["name"] = "data" if i == 0 else f"data:{i}"
         x = Data(**kwargs)
         returnn_datas.append(x)
     dummy_inputs_np = [
         self._make_returnn_dummy_input(x) for x in returnn_datas
     ]
     dummy_inputs_torch = [from_numpy(x) for x in dummy_inputs_np]
     for i in range(len(returnn_inputs)):
         naming.register_input(tensor=dummy_inputs_torch[i],
                               returnn_data=returnn_datas[i])
     out = self(*dummy_inputs_torch)
     assert isinstance(out, Tensor)
     naming.register_output(out)
     return naming
def _get_network_align(epoch0: int):
    net_dict = _get_network(full_sum_alignment=True,
                            target="bpe" if _task == "train" else "targetb")
    net_dict["#trainable"] = False  # disable training
    net_dict["#finish_all_data"] = True  # in case of multi-GPU training or so
    subnet = net_dict["output"]["unit"]
    subnet["fullsum_alignment"] = {
        "class":
        "eval",
        "from": ["output_log_prob", "base:data:" + _target, "base:encoder"],
        "eval":
        rna_fullsum_alignment,
        "out_type":
        lambda sources, **kwargs: Data(
            name="rna_alignment_output",
            sparse=True,
            dim=_targetb_num_labels,
            size_placeholder={0: sources[2].output.size_placeholder[0]}),
        "is_output_layer":
        True
    }
    align_dir = os.path.dirname(model)
    subnet["_align_dump"] = {
        "class":
        "hdf_dump",
        "from":
        "fullsum_alignment",
        "is_output_layer":
        True,
        "dump_per_run":
        True,
        "extend_existing_file":
        epoch0 % EpochSplit > 0,
        "filename": (lambda **opts: "%s/align.{dataset_name}.hdf".format(
            **opts) % align_dir),
    }
    return net_dict
Example #11
def rna_loss_out(sources, **kwargs):
  from returnn.tf.util.data  import Data
  return Data(name="rna_loss", shape=())
def _get_network(target: str,
                 full_sum_loss: bool = False,
                 full_sum_alignment: bool = False,
                 ce_loss: bool = False,
                 pretrain_frac: float = 1,
                 grow_encoder: bool = True):
    full_sum = full_sum_loss or full_sum_alignment
    net_dict = {"#config": {}}
    if pretrain_frac < _pretrain_warmup_lr_frac:
        start_lr = learning_rate / 10.
        net_dict["#config"]["learning_rate"] = start_lr + (
            1 / _pretrain_warmup_lr_frac) * pretrain_frac * learning_rate
    elif pretrain_frac < 1:  # constant for the rest of pretraining
        net_dict["#config"]["learning_rate"] = learning_rate

    EncKeyTotalDim = 200
    AttentionDropout = 0.1
    EncValueTotalDim = 2048
    LstmDim = 1024
    AttNumHeads = 1
    EncKeyPerHeadDim = EncKeyTotalDim // AttNumHeads
    l2 = 0.0001
    net_dict.update({
        "source": {
            "class": "eval",
            "eval": transform
        },
        "source0": {
            "class": "split_dims",
            "axis": "F",
            "dims": (-1, 1),
            "from": "source"
        },  # (T,40,1)

        # Lingvo: ep.conv_filter_shapes = [(3, 3, 1, 32), (3, 3, 32, 32)],  ep.conv_filter_strides = [(2, 2), (2, 2)]
        "conv0": {
            "class": "conv",
            "from": "source0",
            "padding": "same",
            "filter_size": (3, 3),
            "n_out": 32,
            "activation": None,
            "with_bias": True
        },  # (T,40,32)
        "conv0p": {
            "class": "pool",
            "mode": "max",
            "padding": "same",
            "pool_size": (1, 2),
            "from": "conv0"
        },  # (T,20,32)
        "conv1": {
            "class": "conv",
            "from": "conv0p",
            "padding": "same",
            "filter_size": (3, 3),
            "n_out": 32,
            "activation": None,
            "with_bias": True
        },  # (T,20,32)
        "conv1p": {
            "class": "pool",
            "mode": "max",
            "padding": "same",
            "pool_size": (1, 2),
            "from": "conv1"
        },  # (T,10,32)
        "conv_merged": {
            "class": "merge_dims",
            "from": "conv1p",
            "axes": "static"
        },  # (T,320)

        # Encoder LSTMs added below, resulting in "encoder0".
        "encoder": {
            "class": "copy",
            "from": "encoder0"
        },
        "enc_ctx0": {
            "class": "linear",
            "from": "encoder",
            "activation": None,
            "with_bias": False,
            "n_out": EncKeyTotalDim
        },
        "enc_ctx_win": {
            "class": "window",
            "from": "enc_ctx0",
            "window_size": 5
        },  # [B,T,W,D]
        "enc_val": {
            "class": "copy",
            "from": "encoder"
        },
        "enc_val_win": {
            "class": "window",
            "from": "enc_val",
            "window_size": 5
        },  # [B,T,W,D]
        "enc_seq_len": {
            "class": "length",
            "from": "encoder",
            "sparse": False
        },

        # for task "search" / search_output_layer
        "output_wo_b0": {
            "class": "masked_computation",
            "unit": {
                "class": "copy"
            },
            "from": "output",
            "mask": "output/output_emit"
        },
        "output_wo_b": {
            "class": "reinterpret_data",
            "from": "output_wo_b0",
            "set_sparse_dim": _target_num_labels
        },
        "decision": {
            "class": "decide",
            "from": "output_wo_b",
            "loss": "edit_distance",
            "target": _target,
            'only_on_search': True
        },
        "_target_masked": {
            "class": "masked_computation",
            "mask": "output/output_emit",
            "from": "output",
            "unit": {
                "class": "copy"
            }
        },
        "3_target_masked": {
            "class":
            "reinterpret_data",
            "from":
            "_target_masked",
            "set_sparse_dim":
            _target_num_labels,  # we masked blank away
            "enforce_batch_major":
            True,  # ctc not implemented otherwise...
            "register_as_extern_data":
            "targetb_masked" if _task == "train" else None
        },
    })

    # Add encoder BLSTM stack.
    start_num_lstm_layers = 2
    final_num_lstm_layers = 6
    start_dim_factor = 0.5
    if grow_encoder:
        num_lstm_layers = start_num_lstm_layers + int(
            (final_num_lstm_layers - start_num_lstm_layers) * pretrain_frac)
        grow_frac = 1.0 - float(final_num_lstm_layers - num_lstm_layers) / (
            final_num_lstm_layers - start_num_lstm_layers)
        dim_frac = start_dim_factor + (1.0 - start_dim_factor) * grow_frac
    else:
        num_lstm_layers = final_num_lstm_layers
        dim_frac = 1.
    time_reduction = [3, 2] if num_lstm_layers >= 3 else [6]
    src = "conv_merged"
    if num_lstm_layers >= 1:
        net_dict.update({
            "lstm0_fw": {
                "class": "rec",
                "unit": "nativelstm2",
                "n_out": int(LstmDim * dim_frac),
                "L2": l2,
                "direction": 1,
                "from": src,
                "trainable": True
            },
            "lstm0_bw": {
                "class": "rec",
                "unit": "nativelstm2",
                "n_out": int(LstmDim * dim_frac),
                "L2": l2,
                "direction": -1,
                "from": src,
                "trainable": True
            }
        })
        src = ["lstm0_fw", "lstm0_bw"]
    for i in range(1, num_lstm_layers):
        red = time_reduction[i - 1] if (i - 1) < len(time_reduction) else 1
        net_dict.update({
            "lstm%i_pool" % (i - 1): {
                "class": "pool",
                "mode": "max",
                "padding": "same",
                "pool_size": (red, ),
                "from": src
            }
        })
        src = "lstm%i_pool" % (i - 1)
        net_dict.update({
            "lstm%i_fw" % i: {
                "class": "rec",
                "unit": "nativelstm2",
                "n_out": int(LstmDim * dim_frac),
                "L2": l2,
                "direction": 1,
                "from": src,
                "dropout": 0.3 * dim_frac,
                "trainable": True
            },
            "lstm%i_bw" % i: {
                "class": "rec",
                "unit": "nativelstm2",
                "n_out": int(LstmDim * dim_frac),
                "L2": l2,
                "direction": -1,
                "from": src,
                "dropout": 0.3 * dim_frac,
                "trainable": True
            }
        })
        src = ["lstm%i_fw" % i, "lstm%i_bw" % i]
    net_dict["encoder0"] = {"class": "copy", "from": src}
    net_dict["lm_input0"] = {"class": "copy", "from": "data:%s" % target}
    net_dict["lm_input1"] = {
        "class": "prefix_in_time",
        "from": "lm_input0",
        "prefix": 0
    }
    net_dict["lm_input"] = {"class": "copy", "from": "lm_input1"}

    def get_output_dict(train, search, target, beam_size=beam_size):
        return {
            "class": "rec",
            "from": "encoder",
            "include_eos": True,
            "back_prop": (_task == "train") and train,
            "unit": {
                "am0": {
                    "class": "gather_nd",
                    "from": "base:encoder",
                    "position": "prev:t"
                },  # [B,D]
                "am": {
                    "class": "copy",
                    "from": "data:source" if _task == "train" else "am0"
                },
                "prev_out_non_blank": {
                    "class": "reinterpret_data",
                    "from": "prev:output_",
                    "set_sparse_dim": _target_num_labels
                },
                "lm_masked": {
                    "class": "masked_computation",
                    "mask": "prev:output_emit",
                    "from": "prev_out_non_blank",  # in decoding
                    "masked_from": "base:lm_input" if _task == "train" else
                    None,  # enables optimization if used
                    "unit": {
                        "class": "subnetwork",
                        "from": "data",
                        "subnetwork": {
                            "input_embed": {
                                "class": "linear",
                                "activation": None,
                                "with_bias": False,
                                "from": "data",
                                "n_out": 256
                            },
                            "embed_dropout": {
                                "class": "dropout",
                                "from": "input_embed",
                                "dropout": 0.2
                            },
                            # "lstm0": {"class": "rec", "unit": "nativelstm2", "n_out": LstmDim, "from": ["embed_dropout"], "L2": l2},
                            "lstm0_zoneout": {
                                "class": "rnn_cell",
                                "unit": "ZoneoutLSTM",
                                "unit_opts": {
                                    "zoneout_factor_cell": 0.15,
                                    "zoneout_factor_output": 0.05
                                },
                                "from": ["embed_dropout"],
                                "n_out": 500
                            },
                            "output": {
                                "class": "copy",
                                "from": "lstm0_zoneout"
                            }
                        }
                    }
                },
                "readout_in": {
                    "class": "linear",
                    "from": ["am", "lm_masked"],
                    "activation": None,
                    "n_out": 1000,
                    "L2": l2,
                    "dropout": 0.2,
                    "out_type": {
                        "batch_dim_axis":
                        2 if _task == "train" else 0,
                        "shape": (None, None, 1000) if _task == "train" else
                        (1000, ),
                        "time_dim_axis":
                        0 if _task == "train" else None
                    }
                },  # (T, U+1, B, 1000
                "readout": {
                    "class": "reduce_out",
                    "mode": "max",
                    "num_pieces": 2,
                    "from": "readout_in"
                },
                "label_log_prob": {
                    "class": "linear",
                    "from": "readout",
                    "activation": "log_softmax",
                    "dropout": 0.3,
                    "n_out": _target_num_labels
                },  # (B, T, U+1, 1030)
                "emit_prob0": {
                    "class": "linear",
                    "from": "readout",
                    "activation": None,
                    "n_out": 1,
                    "is_output_layer": True
                },  # (B, T, U+1, 1)
                "emit_log_prob": {
                    "class": "activation",
                    "from": "emit_prob0",
                    "activation": "log_sigmoid"
                },  # (B, T, U+1, 1)
                "blank_log_prob": {
                    "class": "eval",
                    "from": "emit_prob0",
                    "eval": "tf.compat.v1.log_sigmoid(-source(0))"
                },  # (B, T, U+1, 1)
                "label_emit_log_prob": {
                    "class": "combine",
                    "kind": "add",
                    "from": ["label_log_prob", "emit_log_prob"]
                },  # (B, T, U+1, 1), scaling factor in log-space
                "output_log_prob": {
                    "class": "copy",
                    "from": ["label_emit_log_prob", "blank_log_prob"]
                },  # (B, T, U+1, 1031)
                "output": {
                    "class":
                    'choice',
                    'target':
                    target,
                    'beam_size':
                    beam_size,
                    'from':
                    "output_log_prob",
                    "input_type":
                    "log_prob",
                    "initial_output":
                    0,
                    "length_normalization":
                    False,
                    "cheating":
                    "exclusive" if _task == "train" else None,
                    "explicit_search_sources": ["prev:out_str", "prev:output"]
                    if _task == "search" else None,
                    "custom_score_combine":
                    targetb_recomb_recog if _task == "search" else None
                },
                # switchout only applicable to viterbi training, added below.
                "output_": {
                    "class": "copy",
                    "from": "output",
                    "initial_output": 0
                },

                # "alignment_length0": {"class": "prefix_in_time", "from": "base:lm_input0", "repeat": "base:enc_seq_len", "prefix": 0, "register_as_extern_data": "alignment"},

                # "fullsum_alignment0": {
                # "class": "eval",
                # "from": ["alignment_length0", "output_log_prob", "base:data:" + _target, "base:encoder"],
                # "eval": rnnt_alignment,
                # "size_target": "alignment",
                # "out_type": lambda sources, **kwargs: Data(name="rnnt_alignment_output", sparse=True, dim=_targetb_num_labels,
                # size_placeholder={}),
                # "is_output_layer": True,
                # },
                "out_str": {
                    "class": "eval",
                    "from": ["prev:out_str", "output_emit", "output"],
                    "initial_output": None,
                    "out_type": {
                        "shape": (),
                        "dtype": "string"
                    },
                    "eval": out_str
                },
                "output_is_not_blank": {
                    "class": "compare",
                    "from": "output_",
                    "value": _targetb_blank_idx,
                    "kind": "not_equal",
                    "initial_output": True
                },

                # initial state=True so that we are consistent to the training and the initial state is correctly set.
                "output_emit": {
                    "class": "copy",
                    "from": "output_is_not_blank",
                    "is_output_layer": True,
                    "initial_output": True
                },
                "const0": {
                    "class": "constant",
                    "value": 0,
                    "collocate_with": ["du", "dt", "t", "u"],
                    "dtype": "int32"
                },
                "const1": {
                    "class": "constant",
                    "value": 1,
                    "collocate_with": ["du", "dt", "t", "u"],
                    "dtype": "int32"
                },

                # pos in target, [B]
                "du": {
                    "class": "switch",
                    "condition": "output_emit",
                    "true_from": "const1",
                    "false_from": "const0"
                },
                "u": {
                    "class": "combine",
                    "from": ["prev:u", "du"],
                    "kind": "add",
                    "initial_output": 0
                },

                # pos in input, [B]
                # output label: stay in t, otherwise advance t (encoder)
                "dt": {
                    "class": "switch",
                    "condition": "output_is_not_blank",
                    "true_from": "const0",
                    "false_from": "const1"
                },
                "t": {
                    "class": "combine",
                    "from": ["dt", "prev:t"],
                    "kind": "add",
                    "initial_output": 0
                },

                # stop at U+T
                # in recog: stop when all input has been consumed
                # in train: defined by target.
                "end": {
                    "class": "compare",
                    "from": ["t", "base:enc_seq_len"],
                    "kind": "greater"
                },
            },
            "max_seq_len": "max_len_from('base:encoder') * 3",
        }

    net_dict["output"] = get_output_dict(train=(_task == "train"),
                                         search=(_task != "train"),
                                         target=target,
                                         beam_size=beam_size)

    subnet = net_dict["output"]["unit"]

    if ce_loss:  # Viterbi training, uses a more powerful state-layer
        subnet["output_prob"] = {
            "class": "activation",
            "from": "output_log_prob",
            "activation": "exp",
            "target": target,
            "loss": "ce",
            "loss_opts": {
                "focal_loss_factor": 2.0
            }
        }
        if _task == "train":  # SwitchOut in training
            subnet["output_"] = {
                "class": "eval",
                "from": "output",
                "eval": switchout_target,
                "initial_output": 0
            }
    if full_sum:
        # Fullsum loss requires way more memory
        del net_dict["_target_masked"]
        del net_dict["3_target_masked"]
        # Dropout regularization
        net_dict["enc_ctx0"]["dropout"] = 0.2
        net_dict["enc_ctx0"]["L2"] = l2
        subnet["output_prob"] = {
            "class": "eval",
            "from":
            ["output_log_prob", "base:data:" + _target, "base:encoder"],
            "eval": rnnt_loss,
            "out_type":
            lambda sources, **kwargs: Data(name="rnnt_loss", shape=()),
            "loss": "as_is",
        }
    return net_dict
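A hypothetical call, only to illustrate the pretraining-related arguments (in the real setup this is presumably driven by the config's pretrain construction logic):

net_dict = _get_network(target="bpe", ce_loss=True, pretrain_frac=0.25, grow_encoder=True)
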
    def _run_torch_returnn_drop_in(self):
        print(">>> Running with wrapped Torch import, wrapping replacement for PyTorch...")
        torch.manual_seed(42)
        with tf.compat.v1.Session() as session:
            with Naming.make_instance(
                    wrap_to_returnn_enabled=True,
                    keep_orig_module_io_tensors=True,  # it's only symbolic anyway in TF
                    import_params_from_torch_namespace=self._torch_namespace) as naming:
                assert isinstance(naming, Naming)
                in_returnn = torch_returnn.from_numpy(self._inputs_np)
                assert isinstance(in_returnn, torch_returnn.Tensor)
                x = naming.register_input(
                    in_returnn, Data("data", **self._returnn_in_data_dict))
                out_returnn = self._model_func(wrapped_import_torch_returnn,
                                               in_returnn)
                assert isinstance(out_returnn, torch_returnn.Tensor)
                out_returnn_ = naming.register_output(out_returnn)
                y, returnn_axis_from_torch_axis = out_returnn_.returnn_data, out_returnn_.returnn_axis_from_torch_axis
                assert isinstance(y, Data)
                print("RETURNN output:", y, "axis map RETURNN<-Torch",
                      returnn_axis_from_torch_axis)
                print(">>>> Module naming hierarchy:")
                naming.root_namespace.dump()
                print(">>>> RETURNN net dict:")
                self._returnn_net_dict = naming.root_namespace.dump_as_returnn_net_dict()
                pprint(self._returnn_net_dict)
                print(">>>> Root module calls:")
                pprint(dict(naming.get_root_module_calls()))
                torch_mods_with_params = naming.get_modules_with_params_by_abs_name()
                print(">>>> Modules with params:")
                pprint(dict(torch_mods_with_params))

            feed_dict = self._make_tf_feed_dict(x)
            y_, y_size = session.run((y.placeholder, y.size_placeholder),
                                     feed_dict=feed_dict)
            assert isinstance(y_, numpy.ndarray)
            self._out_returnn_np = y_
            print("Output shape:", y_.shape)
            print("Output seq lens:", y_size)
            y_torch = y_.transpose(
                *[returnn_axis_from_torch_axis[i] for i in range(y_.ndim)])
            print("Output shape (converted to Torch):", y_torch.shape)
            if self._out_ref_np is not None:
                numpy.testing.assert_allclose(
                    self._out_ref_np, y_torch,
                    **naming.validate_allclose_kwargs)
                print(">>>> Looks good!")

            if self.export_tf_checkpoint_save_path or self.verify_returnn_standalone_model:
                returnn_net = naming.root_namespace.returnn_ctx.network
                returnn_net.print_network_info(name="RETURNN network")
                if self.export_tf_checkpoint_save_path:
                    self._tf_checkpoint_save_path = self.export_tf_checkpoint_save_path
                else:
                    tmp_dir = tempfile.mkdtemp("tmp-returnn-tf-checkpoint")
                    self._tf_checkpoint_save_path = tmp_dir + "/model"
                print(f"Saving TF checkpoint to {self._tf_checkpoint_save_path!r}...")
                returnn_net.global_train_step.load(0, session=session)
                returnn_net.save_params_to_file(
                    filename=self._tf_checkpoint_save_path, session=session)
                print()