Example #2
def init_returnn(config_fn, args):
  """
  :param str config_fn:
  :param args: argparse namespace
  """
  rnn.init_better_exchook()
  config_updates = {
    "log": [],
    "task": "eval",
    "need_data": False,  # we will load it explicitly
    "device": args.device if args.device else None}
  if args.epoch:
    config_updates["load_epoch"] = args.epoch
  if args.do_search:
    config_updates.update({
      "task": "search",
      "search_do_eval": False,
      "beam_size": args.beam_size,
      "max_seq_length": 0,
      })

  rnn.init(
    config_filename=config_fn,
    config_updates=config_updates, extra_greeting="RETURNN get-attention-weights starting up.")
  global config
  config = rnn.config
def init(configFilename, commandLineOptions):
    rnn.init(configFilename=configFilename,
             commandLineOptions=commandLineOptions,
             config_updates={"log": None},
             extra_greeting="CRNN dump-forward starting up.")
    rnn.engine.init_train_from_config(config=rnn.config,
                                      train_data=rnn.train_data)
def main():
  args = arg_parser.parse_args()
  return_code = 0
  try:
    if args.cwd:
      os.chdir(args.cwd)
    init(
      extra_greeting="Delete old models.",
      config_filename=args.config or None,
      config_updates={
        "use_tensorflow": True,
        "need_data": False,
        "device": "cpu"})
    from rnn import engine, config
    if args.model:
      config.set("model", args.model)
    if args.scores:
      config.set("learning_rate_file", args.scores)
    if args.dry_run:
      config.set("dry_run", True)
    engine.cleanup_old_models(ask_for_confirmation=True)

  except KeyboardInterrupt:
    return_code = 1
    print("KeyboardInterrupt", file=getattr(log, "v3", sys.stderr))
    if getattr(log, "verbose", [False] * 6)[5]:
      sys.excepthook(*sys.exc_info())
  finalize()
  if return_code:
    sys.exit(return_code)
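# Invocation sketch (hypothetical; the flag names are inferred from the attribute
# accesses above, and the arg_parser definition itself is not part of this snippet):
#   python cleanup_old_models.py --config returnn.config --dry_run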
Example #5
def init_returnn(config_fn, cmd_line_opts, args):
    """
  :param str config_fn:
  :param list[str] cmd_line_opts:
  :param args: arg_parse object
  """
    rnn.initBetterExchook()
    config_updates = {"log": [], "task": "eval", "need_data": False}
    if args.epoch:
        config_updates["load_epoch"] = args.epoch
    if args.do_search:
        config_updates.update({
            "task": "search",
            "search_do_eval": False,
            "beam_size": args.beam_size,
            "max_seq_length": 0,
        })

    rnn.init(configFilename=config_fn,
             commandLineOptions=cmd_line_opts,
             config_updates=config_updates,
             extra_greeting="RETURNN get-attention-weights starting up.")
    global config
    config = rnn.config
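# Usage sketch (hypothetical driver, not from the original file; args must be a
# namespace providing .epoch, .do_search and .beam_size, e.g. from argparse):
#   init_returnn(config_fn="returnn.config", cmd_line_opts=sys.argv[2:], args=args)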
Example #6
def init(configFilename, commandLineOptions, args):
    rnn.initBetterExchook()
    config_updates = {
        "log": None,
        "task": "eval",
        "eval": "config:get_dataset(%r)" % args.data,
        "train": None,
        "dev": None,
        "need_data": True,
    }
    if args.epoch:
        config_updates["load_epoch"] = args.epoch
    if args.do_search:
        config_updates.update({
            "task": "search",
            "search_data": "config:get_dataset(%r)" % args.data,
            "search_do_eval": False,
            "beam_size": int(args.beam_size),
            "max_seq_length": 0,
        })

    rnn.init(configFilename=configFilename,
             commandLineOptions=commandLineOptions,
             config_updates=config_updates,
             extra_greeting="CRNN dump-forward starting up.")
    rnn.engine.init_train_from_config(config=rnn.config, train_data=None)

    if rnn.engine.pretrain:
        new_network_desc = rnn.engine.pretrain.get_network_json_for_epoch(
            rnn.engine.epoch)
        rnn.engine.maybe_init_new_network(new_network_desc)
    global config
    config = rnn.config
    config.set("log", [])
    rnn.initLog()
    print("CRNN get-attention-weights starting up.", file=log.v3)
def main():
  print("#####################################################")
  print("Loading t2t model + scoring")
  t2t_sess, t2t_tvars, t2t_inputs_ph, t2t_targets_ph, t2t_losses = t2t_score_file(FLAGS_score_file)


  print("#####################################################")
  print("Loading returnn config")

  rnn.init(
    config_updates={
      "optimize_move_layers_out": True,
      "use_tensorflow": True,
      "num_outputs": num_outputs,
      "num_inputs": num_inputs,
      "task": "nop", "log": None, "device": "cpu",
      "network": returnn_network,
      "debug_print_layer_output_template": True,
      "debug_add_check_numerics_on_output": False},
    extra_greeting="Import t2t model.")
  assert Util.BackendEngine.is_tensorflow_selected()
  config = rnn.config

  rnn.engine.init_train_from_config(config=config)
  network = rnn.engine.network
  assert isinstance(network, TFNetwork)


  print("t2t network model params:")
  t2t_params = {} # type: dict[str,tf.Variable]
  t2t_total_num_params = 0
  for v in t2t_tvars:
    key = v.name[:-2]
    t2t_params[key] = v
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    t2t_total_num_params += numpy.prod(v.shape.as_list())
  print("t2t total num params: %i" % t2t_total_num_params)




  print("Our network model params:")
  our_params = {}  # type: dict[str,tf.Variable]
  our_total_num_params = 0
  for v in network.get_params_list():
    key = v.name[:-2]
    our_params[key] = v
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    our_total_num_params += numpy.prod(v.shape.as_list())
  print("Our total num params: %i" % our_total_num_params)



  print("Loading t2t params into our network:")
  for ret_var_name, ret_var in our_params.items():
    if ret_var_name in ret_to_t2t:
      t2t_var_names = ret_to_t2t[ret_var_name]
      # in RETURNN, the QKV params are concatenated into one tensor
      if isinstance(t2t_var_names, tuple):
        # params_np = numpy.concatenate([t2t_params[var_name].eval(t2t_sess) for var_name in t2t_var_names], axis=1) # Not enough...
        # More complex stacking necessary: head1 Q, head1 K, head1 V,   head2 Q, head2 K, head2 V, ...
        q_np = t2t_params[t2t_var_names[0]].eval(t2t_sess)
        k_np = t2t_params[t2t_var_names[1]].eval(t2t_sess)
        v_np = t2t_params[t2t_var_names[2]].eval(t2t_sess)
        qkv_dim_total = 2 * EncKeyTotalDim + EncValueTotalDim
        params_np = numpy.empty((EncValueTotalDim, qkv_dim_total), dtype=q_np.dtype)
        qkv_dim_per_head = qkv_dim_total // AttNumHeads
        for i in range(EncKeyPerHeadDim):
          params_np[:, i::qkv_dim_per_head] = q_np[:, i::EncKeyPerHeadDim]
        for i in range(EncKeyPerHeadDim):
          params_np[:, i + EncKeyPerHeadDim::qkv_dim_per_head] = k_np[:, i::EncKeyPerHeadDim]
        for i in range(EncValuePerHeadDim):
          params_np[:, i + 2 * EncKeyPerHeadDim::qkv_dim_per_head] = v_np[:, i::EncValuePerHeadDim]
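        # Resulting column layout per head h: [Q_h | K_h | V_h], i.e. the strided
        # writes above interleave the per-head blocks (head1 Q, head1 K, head1 V,
        # head2 Q, ...) into the single concatenated QKV tensor.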
      else:
        t2t_var = t2t_params[t2t_var_names]
        params_np = t2t_var.eval(t2t_sess)
      if ret_var_name in ['output/rec/output_prob/W']: # ToDo: Something else to transpose?
        params_np = params_np.transpose()
      if ret_var_name in ["source_embed_raw/W", 'output/rec/target_embed_raw/W']:
        #params_np = params_np * (EncValueTotalDim**0.5) # ToDo: Only because of weight-tying?
        print("loading %s with * (EncValueTotalDim**0.5) or doing it in config" % ret_var.name)
      ret_var.load(params_np, rnn.engine.tf_session)
      print("loaded %s" % ret_var.name)
    else:
      print("skpipped over %s" % ret_var.name)


  ret_ph_train = rnn.engine.tf_session.graph.get_tensor_by_name("global_tensor_train_flag/train_flag:0")
  ret_ph_data = rnn.engine.tf_session.graph.get_tensor_by_name("extern_data/placeholders/data/data:0")
  ret_ph_data_dim = rnn.engine.tf_session.graph.get_tensor_by_name("extern_data/placeholders/data/data_dim0_size:0")
  ret_ph_classes = rnn.engine.tf_session.graph.get_tensor_by_name("extern_data/placeholders/classes/classes:0")
  ret_ph_classes_dim = rnn.engine.tf_session.graph.get_tensor_by_name("extern_data/placeholders/classes/classes_dim0_size:0")
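  # These are RETURNN's extern-data placeholders: batch-major label matrices plus
  # the per-sequence lengths along dim 0, fed together as in the feed dicts below.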

  #ret_feed = {ret_ph_train: True, ret_ph_data: [[11, 78, 42, 670, 2415, 2, 134, 2, 61, 522, 2, 847, 2, 3353, 15, 33, 2534, 1], [3,6]], ret_ph_data_dim: [18, 2],
  #            ret_ph_classes: [[4, 60, 18, 46, 26, 2937, 520, 2, 1317, 2, 10, 642, 4, 639, 1], [2,5]], ret_ph_classes_dim:[14, 2]}

  #src = [[78, 1,0], [2, 134, 1]]; src_lens = [2,3]; trg = [[4, 60, 1], [639, 1, 0]]; trg_lens = [3,2]; ret_feed = {ret_ph_train: False, ret_ph_data: src, ret_ph_data_dim: src_lens, ret_ph_classes: trg, ret_ph_classes_dim: trg_lens}; t2t_feed = {t2t_inputs_ph: src, t2t_targets_ph: trg}

  src = [[2, 134, 1]]
  src_lens = [3]
  trg = [[4, 60, 1]]
  trg_lens = [3]
  ret_feed = {
    ret_ph_train: False, ret_ph_data: src, ret_ph_data_dim: src_lens,
    ret_ph_classes: trg, ret_ph_classes_dim: trg_lens}
  t2t_feed = {t2t_inputs_ph: src, t2t_targets_ph: trg}


  compare_acts(network, t2t_sess, ret_feed, t2t_feed, act_ret_to_t2t)


  # filtered = [op for op in t2t_sess.graph.get_operations() if '/encoder/layer_0/self_attention' in op.name and op.type == 'MatMul']
  # filtered = [op for op in rnn.engine.tf_session.graph.get_operations() if 'enc_1_self_att_/' in op.name and op.type == 'MatMul']
  # for op in filtered: print(op.name)



  ipdb.set_trace()
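# Self-contained toy check of the QKV interleaving in main() above (all names
# below are local stand-ins for the EncKey*/EncValue*/AttNumHeads globals; this
# sketch is illustrative, not part of the original import script):
import numpy

num_heads, key_per_head, val_per_head, model_dim = 2, 3, 3, 4
q = numpy.arange(model_dim * num_heads * key_per_head).reshape(model_dim, -1)
k = q + 100
v = q + 200
qkv_per_head = 2 * key_per_head + val_per_head
out = numpy.empty((model_dim, num_heads * qkv_per_head), dtype=q.dtype)
for i in range(key_per_head):
  out[:, i::qkv_per_head] = q[:, i::key_per_head]
  out[:, i + key_per_head::qkv_per_head] = k[:, i::key_per_head]
for i in range(val_per_head):
  out[:, i + 2 * key_per_head::qkv_per_head] = v[:, i::val_per_head]
# Per head h, the resulting block is [Q_h | K_h | V_h]:
ref = numpy.concatenate(
  [numpy.concatenate(
    [q[:, h * key_per_head:(h + 1) * key_per_head],
     k[:, h * key_per_head:(h + 1) * key_per_head],
     v[:, h * val_per_head:(h + 1) * val_per_head]], axis=1)
   for h in range(num_heads)], axis=1)
assert (out == ref).all()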
Example #9
def init(configFilename, commandLineOptions):
  rnn.init(
    config_filename=configFilename, command_line_options=commandLineOptions,
    config_updates={"log": None},
    extra_greeting="CRNN dump-forward starting up.")
  rnn.engine.init_train_from_config(config=rnn.config, train_data=rnn.train_data)
def main():
  rnn.init(
    command_line_options=sys.argv[1:],
    config_updates={
      "task": "nop", "log": None, "device": "cpu",
      "allow_random_model_init": True,
      "debug_add_check_numerics_on_output": False},
    extra_greeting="Import Blocks MT model.")
  assert Util.BackendEngine.is_tensorflow_selected()
  config = rnn.config

  # Load Blocks MT model params.
  if not config.has("blocks_mt_model"):
    print("Please provide the option blocks_mt_model.")
    sys.exit(1)
  blocks_mt_model_fn = config.value("blocks_mt_model", "")
  assert blocks_mt_model_fn
  assert os.path.exists(blocks_mt_model_fn)
  if os.path.isdir(blocks_mt_model_fn):
    blocks_mt_model_fn += "/params.npz"
    assert os.path.exists(blocks_mt_model_fn)

  dry_run = config.bool("dry_run", False)
  if dry_run:
    our_model_fn = None
    print("Dry-run, will not save model.")
  else:
    our_model_fn = config.value('model', "returnn-model") + ".imported"
    print("Will save Returnn model as %s." % our_model_fn)
    assert os.path.exists(os.path.dirname(our_model_fn) or "."), "model-dir does not exist"
    assert not os.path.exists(our_model_fn + Util.get_model_filename_postfix()), "model-file already exists"

  blocks_mt_model = numpy.load(blocks_mt_model_fn)
  assert isinstance(blocks_mt_model, numpy.lib.npyio.NpzFile), "did not expect type %r in file %r" % (
    type(blocks_mt_model), blocks_mt_model_fn)
  print("Params found in Blocks model:")
  blocks_params = {}  # type: dict[str,numpy.ndarray]
  blocks_params_hierarchy = {}  # type: dict[str,dict[str]]
  blocks_total_num_params = 0
  for key in sorted(blocks_mt_model.keys()):
    value = blocks_mt_model[key]
    key = key.replace("-", "/")
    assert key[0] == "/"
    key = key[1:]
    blocks_params[key] = value
    print("  %s: %s, %s" % (key, value.shape, value.dtype))
    blocks_total_num_params += numpy.prod(value.shape)
    d = blocks_params_hierarchy
    for part in key.split("/"):
      d = d.setdefault(part, {})
  print("Blocks total num params: %i" % blocks_total_num_params)

  # Init our network structure.
  from TFNetworkRecLayer import _SubnetworkRecCell
  _SubnetworkRecCell._debug_out = []  # enable for debugging intermediate values below
  ChoiceLayer._debug_out = []  # also for debug outputs of search
  rnn.engine.use_search_flag = True  # construct the net as in search
  rnn.engine.init_network_from_config()
  print("Our network model params:")
  our_params = {}  # type: dict[str,tf.Variable]
  our_total_num_params = 0
  for v in rnn.engine.network.get_params_list():
    key = v.name[:-2]
    our_params[key] = v
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    our_total_num_params += numpy.prod(v.shape.as_list())
  print("Our total num params: %i" % our_total_num_params)

  # Now matching...
  blocks_used_params = set()  # type: set[str]
  our_loaded_params = set()  # type: set[str]

  def import_var(our_var, blocks_param):
    """
    :param tf.Variable our_var:
    :param str|numpy.ndarray blocks_param:
    """
    assert isinstance(our_var, tf.Variable)
    if isinstance(blocks_param, str):
      blocks_param = load_blocks_var(blocks_param)
    assert isinstance(blocks_param, numpy.ndarray)
    assert_equal(tuple(our_var.shape.as_list()), blocks_param.shape)
    our_loaded_params.add(our_var.name[:-2])
    our_var.load(blocks_param, session=rnn.engine.tf_session)

  def load_blocks_var(blocks_param_name):
    """
    :param str blocks_param_name:
    :rtype: numpy.ndarray
    """
    assert isinstance(blocks_param_name, str)
    assert blocks_param_name in blocks_params
    blocks_used_params.add(blocks_param_name)
    return blocks_params[blocks_param_name]

  enc_name = "bidirectionalencoder"
  enc_embed_name = "EncoderLookUp0.W"
  assert enc_name in blocks_params_hierarchy
  assert enc_embed_name in blocks_params_hierarchy[enc_name]  # input embedding
  num_encoder_layers = max([
    int(re.match(".*([0-9]+)", s).group(1))
    for s in blocks_params_hierarchy[enc_name]
    if s.startswith("EncoderBidirectionalLSTM")])
  blocks_input_dim, blocks_input_embed_dim = blocks_params["%s/%s" % (enc_name, enc_embed_name)].shape
  print("Blocks input dim: %i, embed dim: %i" % (blocks_input_dim, blocks_input_embed_dim))
  print("Blocks num encoder layers: %i" % num_encoder_layers)
  expected_enc_entries = (
    ["EncoderLookUp0.W"] +
    ["EncoderBidirectionalLSTM%i" % i for i in range(1, num_encoder_layers + 1)])
  assert_equal(set(expected_enc_entries), set(blocks_params_hierarchy[enc_name].keys()))

  our_input_layer = find_our_input_embed_layer()
  assert our_input_layer.input_data.dim == blocks_input_dim
  assert our_input_layer.output.dim == blocks_input_embed_dim
  assert not our_input_layer.with_bias
  import_var(our_input_layer.params["W"], "%s/%s" % (enc_name, enc_embed_name))

  dec_name = "decoder/sequencegenerator"
  dec_hierarchy_base = get_in_hierarchy(dec_name, blocks_params_hierarchy)
  assert_equal(set(dec_hierarchy_base.keys()), {"att_trans", "readout"})
  dec_embed_name = "readout/lookupfeedbackwmt15/lookuptable.W"
  get_in_hierarchy(dec_embed_name, dec_hierarchy_base)  # check

  for i in range(num_encoder_layers):
    # Assume standard LSTMCell.
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    # lstm_matrix = self._linear1([inputs, m_prev])
    # i, j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=4, axis=1)
    # bias (4*in), kernel (in+out,4*out), w_(f|i|o)_diag (out)
    # prefix: rec/rnn/lstm_cell
    # Blocks: gate-in, gate-forget, next-in, gate-out
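    # So lstm_vec_blocks_to_tf() below presumably permutes the four gate blocks
    # from the Blocks order (in, forget, new-in, out) to the TF LSTMCell order
    # (in, new-in, forget, out) along the last axis.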
    for direction in ("fwd", "bwd"):
      our_layer = get_network().layers["lstm%i_%s" % (i, direction[:2])]
      blocks_prefix = "bidirectionalencoder/EncoderBidirectionalLSTM%i" % (i + 1,)
      # (in,out*4), (out*4,)
      W_in, b = [load_blocks_var(
        "%s/%s_fork/fork_inputs.%s" % (blocks_prefix, {"bwd": "back", "fwd": "fwd"}[direction], p))
        for p in ("W", "b")]
      blocks_dir = {"fwd": "forward", "bwd": "backward"}[direction]
      W_re = load_blocks_var(
        "%s/bidirectionalseparateparameters/%s.W_state" % (blocks_prefix, blocks_dir))
      W = numpy.concatenate([W_in, W_re], axis=0)
      b = lstm_vec_blocks_to_tf(b)
      W = lstm_vec_blocks_to_tf(W)
      import_var(our_layer.params["rnn/lstm_cell/bias"], b)
      import_var(our_layer.params["rnn/lstm_cell/kernel"], W)
      import_var(
        our_layer.params["initial_c"],
        "%s/bidirectionalseparateparameters/%s.initial_cells" % (blocks_prefix, blocks_dir))
      import_var(
        our_layer.params["initial_h"],
        "%s/bidirectionalseparateparameters/%s.initial_state" % (blocks_prefix, blocks_dir))
      for s1, s2 in [("W_cell_to_in", "w_i_diag"), ("W_cell_to_forget", "w_f_diag"), ("W_cell_to_out", "w_o_diag")]:
        import_var(
          our_layer.params["rnn/lstm_cell/%s" % s2],
          "%s/bidirectionalseparateparameters/%s.%s" % (blocks_prefix, blocks_dir, s1))
  import_var(get_network().layers["enc_ctx"].params["W"], "decoder/sequencegenerator/att_trans/attention/encoder_state_transformer.W")
  import_var(get_network().layers["enc_ctx"].params["b"], "decoder/sequencegenerator/att_trans/attention/encoder_state_transformer.b")
  import_var(our_params["output/rec/s/initial_c"], "decoder/sequencegenerator/att_trans/lstm_decoder.initial_cells")
  import_var(our_params["output/rec/s/initial_h"], "decoder/sequencegenerator/att_trans/lstm_decoder.initial_state")
  import_var(our_params["output/rec/weight_feedback/W"], "decoder/sequencegenerator/att_trans/attention/sum_alignment_transformer.W")
  import_var(our_params["output/rec/target_embed/W"], "decoder/sequencegenerator/readout/lookupfeedbackwmt15/lookuptable.W")
  import_var(our_params["fertility/W"], "decoder/sequencegenerator/att_trans/attention/fertility_transformer.W")
  import_var(our_params["output/rec/energy/W"], "decoder/sequencegenerator/att_trans/attention/energy_comp/linear.W")
  prev_s_trans_W_states = load_blocks_var("decoder/sequencegenerator/att_trans/attention/state_trans/transform_states.W")
  prev_s_trans_W_cells = load_blocks_var("decoder/sequencegenerator/att_trans/attention/state_trans/transform_cells.W")
  prev_s_trans_W = numpy.concatenate([prev_s_trans_W_cells, prev_s_trans_W_states], axis=0)
  import_var(our_params["output/rec/prev_s_transformed/W"], prev_s_trans_W)
  import_var(our_params["output/rec/s/rec/lstm_cell/bias"], numpy.zeros(our_params["output/rec/s/rec/lstm_cell/bias"].shape))
  dec_lstm_kernel_in_feedback = load_blocks_var("decoder/sequencegenerator/att_trans/feedback_to_decoder/fork_inputs.W")
  dec_lstm_kernel_in_ctx = load_blocks_var("decoder/sequencegenerator/att_trans/context_to_decoder/fork_inputs.W")
  dec_lstm_kernel_re = load_blocks_var("decoder/sequencegenerator/att_trans/lstm_decoder.W_state")
  dec_lstm_kernel = numpy.concatenate([dec_lstm_kernel_in_feedback, dec_lstm_kernel_in_ctx, dec_lstm_kernel_re], axis=0)
  dec_lstm_kernel = lstm_vec_blocks_to_tf(dec_lstm_kernel)
  import_var(our_params["output/rec/s/rec/lstm_cell/kernel"], dec_lstm_kernel)
  for s1, s2 in [("W_cell_to_in", "w_i_diag"), ("W_cell_to_forget", "w_f_diag"), ("W_cell_to_out", "w_o_diag")]:
    import_var(our_params["output/rec/s/rec/lstm_cell/%s" % s2], "decoder/sequencegenerator/att_trans/lstm_decoder.%s" % s1)
  readout_in_W_states = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_states.W")
  readout_in_W_feedback = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_feedback.W")
  readout_in_W_att = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_weighted_averages.W")
  readout_in_W = numpy.concatenate([readout_in_W_states, readout_in_W_feedback, readout_in_W_att], axis=0)
  import_var(our_params["output/rec/readout_in/W"], readout_in_W)
  import_var(our_params["output/rec/readout_in/b"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b")
  import_var(our_params["output/rec/output_prob/W"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.W")
  import_var(our_params["output/rec/output_prob/b"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.b")

  print("Not initialized own params:")
  count = 0
  for key, v in sorted(our_params.items()):
    if key in our_loaded_params:
      continue
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    count += 1
  if not count:
    print("  None.")
  print("Not used Blocks params:")
  count = 0
  for key, value in sorted(blocks_params.items()):
    if key in blocks_used_params:
      continue
    print("  %s: %s, %s" % (key, value.shape, value.dtype))
    count += 1
  if not count:
    print("  None.")
  print("Done.")

  blocks_debug_dump_output = config.value("blocks_debug_dump_output", None)
  if blocks_debug_dump_output:
    print("Will read Blocks debug dump output from %r and compare with Returnn outputs." % blocks_debug_dump_output)
    blocks_initial_outputs = numpy.load("%s/initial_states_data.0.npz" % blocks_debug_dump_output)
    blocks_search_log = pickle.load(open("%s/search.log.pkl" % blocks_debug_dump_output, "rb"), encoding="bytes")
    blocks_search_log = {d[b"step"]: d for d in blocks_search_log}
    input_seq = blocks_initial_outputs["input"]
    beam_size, seq_len = input_seq.shape
    input_seq = input_seq[0]  # all the same, select beam 0
    assert isinstance(input_seq, numpy.ndarray)
    print("Debug input seq: %s" % input_seq.tolist())
    from GeneratingDataset import StaticDataset
    dataset = StaticDataset(
      data=[{"data": input_seq}],
      output_dim={"data": get_network().extern_data.get_default_input_data().get_kwargs()})
    dataset.init_seq_order(epoch=0)
    extract_output_dict = {
      "enc_src_emb": get_network().layers["source_embed"].output.get_placeholder_as_batch_major(),
      "encoder": get_network().layers["encoder"].output.get_placeholder_as_batch_major(),
      "enc_ctx": get_network().layers["enc_ctx"].output.get_placeholder_as_batch_major(),
      "output": get_network().layers["output"].output.get_placeholder_as_batch_major()
    }
    from TFNetworkLayer import concat_sources
    for i in range(num_encoder_layers):
      extract_output_dict["enc_layer_%i" % i] = concat_sources(
        [get_network().layers["lstm%i_fw" % i], get_network().layers["lstm%i_bw" % i]]
      ).get_placeholder_as_batch_major()
    extract_output_dict["enc_layer_0_fwd"] = get_network().layers["lstm0_fw"].output.get_placeholder_as_batch_major()
    our_output = rnn.engine.run_single(
      dataset=dataset, seq_idx=0, output_dict=extract_output_dict)
    blocks_out = blocks_initial_outputs["bidirectionalencoder_EncoderLookUp0__EncoderLookUp0_apply_output"]
    our_out = our_output["enc_src_emb"]
    print("our enc emb shape:", our_out.shape)
    print("Blocks enc emb shape:", blocks_out.shape)
    assert our_out.shape[:2] == (1, seq_len)
    assert blocks_out.shape[:2] == (seq_len, beam_size)
    assert our_out.shape[2] == blocks_out.shape[2]
    assert_almost_equal(our_out[0], blocks_out[:, 0], decimal=5)
    blocks_lstm0_out_ref = calc_lstm(blocks_out[:, 0], blocks_params)
    blocks_lstm0_out = blocks_initial_outputs["bidirectionalencoder_EncoderBidirectionalLSTM1_bidirectionalseparateparameters_forward__forward_apply_states"]
    our_lstm0_out = our_output["enc_layer_0_fwd"]
    assert blocks_lstm0_out.shape == (seq_len, beam_size) + blocks_lstm0_out_ref.shape
    assert our_lstm0_out.shape == (1, seq_len) + blocks_lstm0_out_ref.shape
    assert_almost_equal(blocks_lstm0_out[0, 0], blocks_lstm0_out_ref, decimal=6)
    print("Blocks LSTM0 frame 0 matched to ref calc.")
    assert_almost_equal(our_lstm0_out[0, 0], blocks_lstm0_out_ref, decimal=6)
    print("Our LSTM0 frame 0 matched to ref calc.")
    for i in range(num_encoder_layers):
      blocks_out = blocks_initial_outputs[
        "bidirectionalencoder_EncoderBidirectionalLSTM%i_bidirectionalseparateparameters__bidirectionalseparateparameters_apply_output_0" % (i + 1,)]
      our_out = our_output["enc_layer_%i" % i]
      print("our enc layer %i shape:" % i, our_out.shape)
      print("Blocks enc layer %i shape:" % i, blocks_out.shape)
      assert our_out.shape[:2] == (1, seq_len)
      assert blocks_out.shape[:2] == (seq_len, beam_size)
      assert our_out.shape[2] == blocks_out.shape[2]
      assert_almost_equal(our_out[0], blocks_out[:, 0], decimal=6)
    print("our encoder shape:", our_output["encoder"].shape)
    blocks_encoder_out = blocks_initial_outputs["bidirectionalencoder__bidirectionalencoder_apply_representation"]
    print("Blocks encoder shape:", blocks_encoder_out.shape)
    assert our_output["encoder"].shape[:2] == (1, seq_len)
    assert blocks_encoder_out.shape[:2] == (seq_len, beam_size)
    assert our_output["encoder"].shape[2] == blocks_encoder_out.shape[2]
    assert_almost_equal(our_output["encoder"][0], blocks_encoder_out[:, 0], decimal=6)
    blocks_first_frame_outputs = numpy.load("%s/next_states.0.npz" % blocks_debug_dump_output)
    blocks_enc_ctx_out = blocks_first_frame_outputs["decoder_sequencegenerator_att_trans_attention__attention_preprocess_preprocessed_attended"]
    our_enc_ctx_out = our_output["enc_ctx"]
    print("Blocks enc ctx shape:", blocks_enc_ctx_out.shape)
    assert blocks_enc_ctx_out.shape[:2] == (seq_len, beam_size)
    assert our_enc_ctx_out.shape[:2] == (1, seq_len)
    assert blocks_enc_ctx_out.shape[2:] == our_enc_ctx_out.shape[2:]
    assert_almost_equal(blocks_enc_ctx_out[:, 0], our_enc_ctx_out[0], decimal=5)
    fertility = numpy.dot(blocks_encoder_out[:, 0], blocks_params["decoder/sequencegenerator/att_trans/attention/fertility_transformer.W"])
    fertility = sigmoid(fertility)
    assert fertility.shape == (seq_len, 1)
    fertility = fertility[:, 0]
    assert fertility.shape == (seq_len,)
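    # fertility bounds how much attention mass each source position may accumulate;
    # it rescales the per-step weights below (see accumulated_weights: weights / (2.0 * fertility)).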
    our_dec_outputs = {v["step"]: v for v in _SubnetworkRecCell._debug_out}
    assert our_dec_outputs
    print("our dec frame keys:", sorted(our_dec_outputs[0].keys()))
    our_dec_search_outputs = {v["step"]: v for v in ChoiceLayer._debug_out}
    assert our_dec_search_outputs
    print("our dec search frame keys:", sorted(our_dec_search_outputs[0].keys()))
    print("Blocks search frame keys:", sorted(blocks_search_log[0].keys()))
    dec_lookup = blocks_params["decoder/sequencegenerator/readout/lookupfeedbackwmt15/lookuptable.W"]
    last_lstm_state = blocks_params["decoder/sequencegenerator/att_trans/lstm_decoder.initial_state"]
    last_lstm_cells = blocks_params["decoder/sequencegenerator/att_trans/lstm_decoder.initial_cells"]
    last_accumulated_weights = numpy.zeros((seq_len,), dtype="float32")
    last_output = 0
    dec_seq_len = 0
    for dec_step in range(100):
      blocks_frame_state_outputs_fn = "%s/next_states.%i.npz" % (blocks_debug_dump_output, dec_step)
      blocks_frame_probs_outputs_fn = "%s/logprobs.%i.npz" % (blocks_debug_dump_output, dec_step)
      if dec_step > 3:
        if not os.path.exists(blocks_frame_state_outputs_fn) or not os.path.exists(blocks_frame_probs_outputs_fn):
          print("Seq not ended yet but frame not found for step %i." % dec_step)
          break
      blocks_frame_state_outputs = numpy.load(blocks_frame_state_outputs_fn)
      blocks_frame_probs_outputs = numpy.load(blocks_frame_probs_outputs_fn)
      blocks_search_frame = blocks_search_log[dec_step]
      our_dec_frame_outputs = our_dec_outputs[dec_step]
      assert our_dec_frame_outputs["step"] == dec_step
      assert our_dec_frame_outputs[":i.output"].tolist() == [dec_step]
      our_dec_search_frame_outputs = our_dec_search_outputs[dec_step]

      blocks_last_lstm_state = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_states"]
      blocks_last_lstm_cells = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_cells"]
      assert blocks_last_lstm_state.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(blocks_last_lstm_state[0], last_lstm_state, decimal=5)
      assert_almost_equal(blocks_last_lstm_cells[0], last_lstm_cells, decimal=5)
      our_last_lstm_cells = our_dec_frame_outputs["prev:s.extra.state"][0]
      our_last_lstm_state = our_dec_frame_outputs["prev:s.extra.state"][1]
      assert our_last_lstm_state.shape == our_last_lstm_cells.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(our_last_lstm_state[0], last_lstm_state, decimal=5)
      assert_almost_equal(our_last_lstm_cells[0], last_lstm_cells, decimal=5)
      our_last_s = our_dec_frame_outputs["prev:s.output"]
      assert our_last_s.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(our_last_s[0], last_lstm_state, decimal=5)

      blocks_last_accum_weights = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_accumulated_weights"]
      assert blocks_last_accum_weights.shape == (beam_size, seq_len)
      assert_almost_equal(blocks_last_accum_weights[0], last_accumulated_weights, decimal=5)
      our_last_accum_weights = our_dec_frame_outputs["prev:accum_att_weights.output"]
      assert our_last_accum_weights.shape == (beam_size, seq_len if dec_step > 0 else 1, 1)
      if dec_step > 0:
        assert_almost_equal(our_last_accum_weights[0, :, 0], last_accumulated_weights, decimal=4)
      else:
        assert_almost_equal(our_last_accum_weights[0, 0, 0], last_accumulated_weights.sum(), decimal=4)

      energy_sum = numpy.copy(blocks_enc_ctx_out[:, 0])  # (T,enc-ctx-dim)
      weight_feedback = numpy.dot(last_accumulated_weights[:, None], blocks_params["decoder/sequencegenerator/att_trans/attention/sum_alignment_transformer.W"])
      energy_sum += weight_feedback
      transformed_states = numpy.dot(last_lstm_state[None, :], blocks_params["decoder/sequencegenerator/att_trans/attention/state_trans/transform_states.W"])
      transformed_cells = numpy.dot(last_lstm_cells[None, :], blocks_params["decoder/sequencegenerator/att_trans/attention/state_trans/transform_cells.W"])
      energy_sum += transformed_states + transformed_cells
      assert energy_sum.shape == (seq_len, blocks_enc_ctx_out.shape[-1])
      blocks_energy_sum_tanh = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention_energy_comp_tanh__tanh_apply_output"]
      assert blocks_energy_sum_tanh.shape == (seq_len, beam_size, energy_sum.shape[-1])
      assert_almost_equal(blocks_energy_sum_tanh[:, 0], numpy.tanh(energy_sum), decimal=5)
      assert_equal(our_dec_frame_outputs["weight_feedback.output"].shape, (beam_size, seq_len if dec_step > 0 else 1, blocks_enc_ctx_out.shape[-1]))
      assert_equal(our_dec_frame_outputs["prev_s_transformed.output"].shape, (beam_size, blocks_enc_ctx_out.shape[-1]))
      our_energy_sum = our_dec_frame_outputs["energy_in.output"]
      assert our_energy_sum.shape == (beam_size, seq_len, blocks_enc_ctx_out.shape[-1])
      assert_almost_equal(our_energy_sum[0], energy_sum, decimal=4)
      blocks_energy = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention_energy_comp__energy_comp_apply_output"]
      assert blocks_energy.shape == (seq_len, beam_size, 1)
      energy = numpy.dot(numpy.tanh(energy_sum), blocks_params["decoder/sequencegenerator/att_trans/attention/energy_comp/linear.W"])
      assert energy.shape == (seq_len, 1)
      assert_almost_equal(blocks_energy[:, 0], energy, decimal=4)
      our_energy = our_dec_frame_outputs["energy.output"]
      assert our_energy.shape == (beam_size, seq_len, 1)
      assert_almost_equal(our_energy[0], energy, decimal=4)
      weights = softmax(energy[:, 0])
      assert weights.shape == (seq_len,)
      our_weights = our_dec_frame_outputs["att_weights.output"]
      assert our_weights.shape == (beam_size, seq_len, 1)
      assert_almost_equal(our_weights[0, :, 0], weights, decimal=4)
      accumulated_weights = last_accumulated_weights + weights / (2.0 * fertility)
      assert accumulated_weights.shape == (seq_len,)
      #blocks_accumulated_weights = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_take_glimpses_accumulated_weights"]
      #assert blocks_accumulated_weights.shape == (beam_size, seq_len)
      #assert_almost_equal(blocks_accumulated_weights[0], accumulated_weights, decimal=5)
      blocks_weights = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_compute_weights_output_0"]
      assert blocks_weights.shape == (seq_len, beam_size)
      assert_almost_equal(weights, blocks_weights[:, 0], decimal=4)
      our_accum_weights = our_dec_frame_outputs["accum_att_weights.output"]
      assert our_accum_weights.shape == (beam_size, seq_len, 1)
      weighted_avg = (weights[:, None] * blocks_encoder_out[:, 0]).sum(axis=0)  # att in our
      assert weighted_avg.shape == (blocks_encoder_out.shape[-1],)
      blocks_weighted_avg = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_compute_weighted_averages_output_0"]
      assert blocks_weighted_avg.shape == (beam_size, blocks_encoder_out.shape[-1])
      assert_almost_equal(blocks_weighted_avg[0], weighted_avg, decimal=4)
      our_att = our_dec_frame_outputs["att.output"]
      assert our_att.shape == (beam_size, blocks_encoder_out.shape[-1])
      assert_almost_equal(our_att[0], weighted_avg, decimal=4)

      blocks_last_output = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_outputs"]
      assert blocks_last_output.shape == (beam_size,)
      assert max(blocks_last_output[0], 0) == last_output
      last_target_embed = dec_lookup[last_output]
      if dec_step == 0:
        last_target_embed = numpy.zeros_like(last_target_embed)
      our_last_target_embed = our_dec_frame_outputs["prev:target_embed.output"]
      assert our_last_target_embed.shape == (beam_size, dec_lookup.shape[-1])
      assert_almost_equal(our_last_target_embed[0], last_target_embed, decimal=4)

      readout_in_state = numpy.dot(last_lstm_state, blocks_params["decoder/sequencegenerator/readout/merge/transform_states.W"])
      blocks_trans_state = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_states"]
      assert blocks_trans_state.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(blocks_trans_state[0], readout_in_state, decimal=4)
      readout_in_feedback = numpy.dot(last_target_embed, blocks_params["decoder/sequencegenerator/readout/merge/transform_feedback.W"])
      blocks_trans_feedback = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_feedback"]
      assert blocks_trans_feedback.shape == (beam_size, readout_in_feedback.shape[0])
      assert_almost_equal(blocks_trans_feedback[0], readout_in_feedback, decimal=4)
      readout_in_weighted_avg = numpy.dot(weighted_avg, blocks_params["decoder/sequencegenerator/readout/merge/transform_weighted_averages.W"])
      blocks_trans_weighted_avg = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_weighted_averages"]
      assert blocks_trans_weighted_avg.shape == (beam_size, readout_in_weighted_avg.shape[0])
      assert_almost_equal(blocks_trans_weighted_avg[0], readout_in_weighted_avg, decimal=4)
      readout_in = readout_in_state + readout_in_feedback + readout_in_weighted_avg
      blocks_readout_in = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_output"]
      assert blocks_readout_in.shape == (beam_size, readout_in.shape[0])
      assert_almost_equal(blocks_readout_in[0], readout_in, decimal=4)
      readout_in += blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b"]
      assert readout_in.shape == (blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b"].shape[0],)
      our_readout_in = our_dec_frame_outputs["readout_in.output"]
      assert our_readout_in.shape == (beam_size, readout_in.shape[0])
      assert_almost_equal(our_readout_in[0], readout_in, decimal=4)
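      # Maxout with pool size 2: group adjacent readout units into pairs, keep the max.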
      readout = readout_in.reshape((readout_in.shape[0] // 2, 2)).max(axis=1)
      our_readout = our_dec_frame_outputs["readout.output"]
      assert our_readout.shape == (beam_size, readout.shape[0])
      assert_almost_equal(our_readout[0], readout, decimal=4)
      prob_logits = numpy.dot(readout, blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.W"]) + \
        blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.b"]
      assert prob_logits.ndim == 1
      blocks_prob_logits = blocks_frame_probs_outputs["decoder_sequencegenerator_readout__readout_readout_output_0"]
      assert blocks_prob_logits.shape == (beam_size, prob_logits.shape[0])
      assert_almost_equal(blocks_prob_logits[0], prob_logits, decimal=4)
      output_prob = softmax(prob_logits)
      log_output_prob = log_softmax(prob_logits)
      assert_almost_equal(numpy.log(output_prob), log_output_prob, decimal=4)
      our_output_prob = our_dec_frame_outputs["output_prob.output"]
      assert our_output_prob.shape == (beam_size, output_prob.shape[0])
      assert_almost_equal(our_output_prob[0], output_prob, decimal=4)
      blocks_nlog_prob = blocks_frame_probs_outputs["logprobs"]
      assert blocks_nlog_prob.shape == (beam_size, output_prob.shape[0])
      assert_almost_equal(blocks_nlog_prob[0], -log_output_prob, decimal=4)
      assert_almost_equal(our_dec_search_frame_outputs["scores_in_orig"][0], output_prob, decimal=4)
      assert_almost_equal(blocks_search_frame[b'logprobs'][0], -log_output_prob, decimal=4)
      #for b in range(beam_size):
      #  assert_almost_equal(-numpy.log(our_output_prob[b]), blocks_frame_probs_outputs["logprobs"][b], decimal=4)
      ref_output = numpy.argmax(output_prob)
      # Note: Don't take the readout.emit outputs. They are randomly sampled.
      blocks_dec_output = blocks_search_frame[b'outputs']
      assert blocks_dec_output.shape == (beam_size,)
      our_dec_output = our_dec_frame_outputs["output.output"]
      assert our_dec_output.shape == (beam_size,)
      print("Frame %i: Ref best greedy output symbol: %i" % (dec_step, int(ref_output)))
      print("Blocks labels:", blocks_dec_output.tolist())
      print("Our labels:", our_dec_output.tolist())
      # Well, the following two could be not true if all the other beams have much better scores,
      # but this is unlikely.
      assert ref_output in blocks_dec_output
      assert ref_output in our_dec_output
      if dec_step == 0:
        # This assumes that the results are ordered by score which might not be true (see tf.nn.top_k).
        assert blocks_dec_output[0] == our_dec_output[0] == ref_output
      # We assume that the best is the same. Note that this also might not be true if there are two equally best scores.
      # It also assumes that it's ordered by the score which also might not be true (see tf.nn.top_k).
      # For the same reason, the remaining list and entries might also not perfectly match.
      assert our_dec_output[0] == blocks_dec_output[0]
      # Just follow the first beam.
      ref_output = blocks_dec_output[0]
      assert our_dec_search_frame_outputs["src_beam_idxs"].shape == (1, beam_size)
      assert our_dec_search_frame_outputs["scores"].shape == (1, beam_size)
      print("Blocks src_beam_idxs:", blocks_search_frame[b'indexes'].tolist())
      print("Our src_beam_idxs:", our_dec_search_frame_outputs["src_beam_idxs"][0].tolist())
      print("Blocks scores:", blocks_search_frame[b'chosen_costs'].tolist())
      print("Our scores:", our_dec_search_frame_outputs["scores"][0].tolist())
      if list(our_dec_search_frame_outputs["src_beam_idxs"][0]) != list(blocks_search_frame[b'indexes']):
        print("Warning, beams do not match.")
        print("Blocks scores base:", blocks_search_frame[b'scores_base'].flatten().tolist())
        print("Our scores base:", our_dec_search_frame_outputs["scores_base"].flatten().tolist())
        #print("Blocks score in orig top k:", sorted(blocks_search_frame[b'logprobs'].flatten())[:beam_size])
        #print("Our score in orig top k:", sorted(-numpy.log(our_dec_search_frame_outputs["scores_in_orig"].flatten()))[:beam_size])
        print("Blocks score in top k:", sorted((blocks_search_frame[b'logprobs'] * blocks_search_log[dec_step - 1][b'mask'][:, None]).flatten())[:beam_size])
        print("Our score in top k:", sorted(-our_dec_search_frame_outputs["scores_in"].flatten())[:beam_size])
        blocks_scores_combined = blocks_search_frame[b'next_costs']
        our_scores_combined = our_dec_search_frame_outputs["scores_combined"]
        print("Blocks scores combined top k:", sorted(blocks_scores_combined.flatten())[:beam_size])
        print("Our neg scores combined top k:", sorted(-our_scores_combined.flatten())[:beam_size])
        #raise Exception("beams mismatch")
      assert our_dec_search_frame_outputs["src_beam_idxs"][0][0] == blocks_search_frame[b'indexes'][0]
      beam_idx = our_dec_search_frame_outputs["src_beam_idxs"][0][0]
      if beam_idx != 0:
        print("Selecting different beam: %i." % beam_idx)
        # Just overwrite the needed states by Blocks outputs.
        accumulated_weights = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_attention__attention_take_glimpses_accumulated_weights"][0]
        weighted_avg = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_weighted_averages"][0]
        last_lstm_state = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_states"][0]
        last_lstm_cells = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_cells"][0]

      # From now on, use blocks_frame_state_outputs instead of blocks_frame_probs_outputs because
      # it will have the beam reordered.
      blocks_target_emb = blocks_frame_state_outputs["decoder_sequencegenerator_fork__fork_apply_feedback_decoder_input"]
      assert blocks_target_emb.shape == (beam_size, dec_lookup.shape[1])
      target_embed = dec_lookup[ref_output]
      assert target_embed.shape == (dec_lookup.shape[1],)
      assert_almost_equal(blocks_target_emb[0], target_embed)

      feedback_to_decoder = numpy.dot(target_embed, blocks_params["decoder/sequencegenerator/att_trans/feedback_to_decoder/fork_inputs.W"])
      context_to_decoder = numpy.dot(weighted_avg, blocks_params["decoder/sequencegenerator/att_trans/context_to_decoder/fork_inputs.W"])
      lstm_z = feedback_to_decoder + context_to_decoder
      assert lstm_z.shape == feedback_to_decoder.shape == context_to_decoder.shape == (last_lstm_state.shape[-1] * 4,)
      blocks_feedback_to_decoder = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_feedback_to_decoder__feedback_to_decoder_apply_inputs"]
      blocks_context_to_decoder = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_context_to_decoder__context_to_decoder_apply_inputs"]
      assert blocks_feedback_to_decoder.shape == blocks_context_to_decoder.shape == (beam_size, last_lstm_state.shape[-1] * 4)
      assert_almost_equal(blocks_feedback_to_decoder[0], feedback_to_decoder, decimal=4)
      assert_almost_equal(blocks_context_to_decoder[0], context_to_decoder, decimal=4)
      lstm_state, lstm_cells = calc_raw_lstm(
        lstm_z, blocks_params=blocks_params,
        prefix="decoder/sequencegenerator/att_trans/lstm_decoder.",
        last_state=last_lstm_state, last_cell=last_lstm_cells)
      assert lstm_state.shape == last_lstm_state.shape == lstm_cells.shape == last_lstm_cells.shape
      blocks_lstm_state = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_lstm_decoder__lstm_decoder_apply_states"]
      blocks_lstm_cells = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_lstm_decoder__lstm_decoder_apply_cells"]
      assert blocks_lstm_state.shape == blocks_lstm_cells.shape == (beam_size, last_lstm_state.shape[-1])
      assert_almost_equal(blocks_lstm_state[0], lstm_state, decimal=4)
      assert_almost_equal(blocks_lstm_cells[0], lstm_cells, decimal=4)
      our_lstm_cells = our_dec_frame_outputs["s.extra.state"][0]
      our_lstm_state = our_dec_frame_outputs["s.extra.state"][1]
      assert our_lstm_state.shape == our_lstm_cells.shape == (beam_size, lstm_state.shape[0])
      assert_almost_equal(our_lstm_state[0], lstm_state, decimal=4)
      assert_almost_equal(our_lstm_cells[0], lstm_cells, decimal=4)
      our_s = our_dec_frame_outputs["s.output"]
      assert our_s.shape == (beam_size, lstm_state.shape[0])
      assert_almost_equal(our_s[0], lstm_state, decimal=4)

      last_accumulated_weights = accumulated_weights
      last_lstm_state = lstm_state
      last_lstm_cells = lstm_cells
      last_output = ref_output
      if last_output == 0:
        print("Sequence finished, seq len %i." % dec_step)
        dec_seq_len = dec_step
        break
    assert dec_seq_len > 0
    print("All outputs seem to match.")
  else:
    print("blocks_debug_dump_output not specified. It will not compare the model outputs." % blocks_debug_dump_output)

  if dry_run:
    print("Dry-run, not saving model.")
  else:
    rnn.engine.save_model(our_model_fn)
  print("Finished importing.")
def main():
    argparser = ArgumentParser(description=__doc__,
                               formatter_class=RawTextHelpFormatter)
    argparser.add_argument("--model",
                           required=True,
                           help="or config, or setup")
    argparser.add_argument("--epoch", required=True, type=int)
    argparser.add_argument("--prior",
                           help="none, fixed, softmax (default: none)")
    argparser.add_argument("--prior_scale", type=float, default=1.0)
    argparser.add_argument("--am_scale", type=float, default=1.0)
    argparser.add_argument("--tdp_scale", type=float, default=1.0)
    args = argparser.parse_args()

    cfg_fn = args.model
    if "/" not in cfg_fn:
        cfg_fn = "config-train/%s.config" % cfg_fn
    assert os.path.exists(cfg_fn)
    setup_name = os.path.splitext(os.path.basename(cfg_fn))[0]
    setup_dir = "data-train/%s" % setup_name
    assert os.path.exists(setup_dir)
    Globals.setup_name = setup_name
    Globals.setup_dir = setup_dir
    Globals.epoch = args.epoch

    config_update["epoch"] = args.epoch
    config_update["load_epoch"] = args.epoch
    config_update["model"] = "%s/net-model/network" % setup_dir

    import rnn
    rnn.init(configFilename=cfg_fn,
             config_updates=config_update,
             extra_greeting="calc full sum score.")
    Globals.engine = rnn.engine
    Globals.config = rnn.config
    Globals.dataset = rnn.dev_data

    assert Globals.engine and Globals.config and Globals.dataset
    # This will init the network, load the params, etc.
    Globals.engine.init_train_from_config(config=Globals.config,
                                          dev_data=Globals.dataset)

    # Do not modify the network here. Not needed.
    softmax_prior = get_softmax_prior()

    prior = args.prior or "none"
    if prior == "none":
        prior_filename = None
    elif prior == "softmax":
        prior_filename = softmax_prior
    elif prior == "fixed":
        prior_filename = "dependencies/prior-fixed-f32.xml"
    else:
        raise Exception("invalid prior %r" % prior)
    print("using prior:", prior)
    if prior_filename:
        assert os.path.exists(prior_filename)
        check_valid_prior(prior_filename)

    print("Do the stuff...")
    print("Reinit dataset.")
    Globals.dataset.init_seq_order(epoch=args.epoch)

    network_update["out_fullsum_scores"]["eval_locals"][
        "am_scale"] = args.am_scale
    network_update["out_fullsum_scores"]["eval_locals"][
        "prior_scale"] = args.prior_scale
    network_update["out_fullsum_bw"]["tdp_scale"] = args.tdp_scale
    if prior_filename:
        network_update["out_fullsum_prior"][
            "init"] = "load_txt_file(%r)" % prior_filename
    else:
        network_update["out_fullsum_prior"]["init"] = 0
    from copy import deepcopy
    Globals.config.typed_dict["network"] = deepcopy(
        Globals.config.typed_dict["network"])
    Globals.config.typed_dict["network"].update(network_update)
    # Reinit the network, and copy over params.
    from Pretrain import pretrainFromConfig
    pretrain = pretrainFromConfig(
        Globals.config)  # reinit Pretrain topologies if used
    if pretrain:
        new_network_desc = pretrain.get_network_json_for_epoch(Globals.epoch)
    else:
        new_network_desc = Globals.config.typed_dict["network"]
    assert "output_fullsum" in new_network_desc
    print("Init new network.")
    Globals.engine.maybe_init_new_network(new_network_desc)

    print("Calc scores.")
    calc_fullsum_scores(meta=dict(prior=prior,
                                  prior_scale=args.prior_scale,
                                  am_scale=args.am_scale,
                                  tdp_scale=args.tdp_scale))

    rnn.finalize()
    print("Bye.")
Exemple #12
0
def main():
  rnn.init(
    commandLineOptions=sys.argv[1:],
    config_updates={
      "task": "nop", "log": None, "device": "cpu",
      "allow_random_model_init": True,
      "debug_add_check_numerics_on_output": False},
    extra_greeting="Import Blocks MT model.")
  assert Util.BackendEngine.is_tensorflow_selected()
  config = rnn.config

  # Load Blocks MT model params.
  if not config.has("blocks_mt_model"):
    print("Please provide the option blocks_mt_model.")
    sys.exit(1)
  blocks_mt_model_fn = config.value("blocks_mt_model", "")
  assert blocks_mt_model_fn
  assert os.path.exists(blocks_mt_model_fn)
  if os.path.isdir(blocks_mt_model_fn):
    blocks_mt_model_fn += "/params.npz"
    assert os.path.exists(blocks_mt_model_fn)

  dry_run = config.bool("dry_run", False)
  if dry_run:
    our_model_fn = None
    print("Dry-run, will not save model.")
  else:
    our_model_fn = config.value('model', "returnn-model") + ".imported"
    print("Will save Returnn model as %s." % our_model_fn)
    assert os.path.exists(os.path.dirname(our_model_fn) or "."), "model-dir does not exist"
    assert not os.path.exists(our_model_fn + Util.get_model_filename_postfix()), "model-file already exists"

  blocks_mt_model = numpy.load(blocks_mt_model_fn)
  assert isinstance(blocks_mt_model, numpy.lib.npyio.NpzFile), "did not expect type %r in file %r" % (
    type(blocks_mt_model), blocks_mt_model_fn)
  print("Params found in Blocks model:")
  blocks_params = {}  # type: dict[str,numpy.ndarray]
  blocks_params_hierarchy = {}  # type: dict[str,dict[str]]
  blocks_total_num_params = 0
  for key in sorted(blocks_mt_model.keys()):
    value = blocks_mt_model[key]
    key = key.replace("-", "/")
    assert key[0] == "/"
    key = key[1:]
    blocks_params[key] = value
    print("  %s: %s, %s" % (key, value.shape, value.dtype))
    blocks_total_num_params += numpy.prod(value.shape)
    d = blocks_params_hierarchy
    for part in key.split("/"):
      d = d.setdefault(part, {})
  print("Blocks total num params: %i" % blocks_total_num_params)

  # Init our network structure.
  from TFNetworkRecLayer import _SubnetworkRecCell
  _SubnetworkRecCell._debug_out = []  # enable for debugging intermediate values below
  ChoiceLayer._debug_out = []  # also for debug outputs of search
  rnn.engine.use_search_flag = True  # construct the net as in search
  rnn.engine.init_network_from_config()
  print("Our network model params:")
  our_params = {}  # type: dict[str,tf.Variable]
  our_total_num_params = 0
  for v in rnn.engine.network.get_params_list():
    key = v.name[:-2]
    our_params[key] = v
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    our_total_num_params += numpy.prod(v.shape.as_list())
  print("Our total num params: %i" % our_total_num_params)

  # Now matching...
  blocks_used_params = set()  # type: set[str]
  our_loaded_params = set()  # type: set[str]

  def import_var(our_var, blocks_param):
    """
    :param tf.Variable our_var:
    :param str|numpy.ndarray blocks_param:
    """
    assert isinstance(our_var, tf.Variable)
    if isinstance(blocks_param, str):
      blocks_param = load_blocks_var(blocks_param)
    assert isinstance(blocks_param, numpy.ndarray)
    assert_equal(tuple(our_var.shape.as_list()), blocks_param.shape)
    our_loaded_params.add(our_var.name[:-2])
    our_var.load(blocks_param, session=rnn.engine.tf_session)

  def load_blocks_var(blocks_param_name):
    """
    :param str blocks_param_name:
    :rtype: numpy.ndarray
    """
    assert isinstance(blocks_param_name, str)
    assert blocks_param_name in blocks_params
    blocks_used_params.add(blocks_param_name)
    return blocks_params[blocks_param_name]

  enc_name = "bidirectionalencoder"
  enc_embed_name = "EncoderLookUp0.W"
  assert enc_name in blocks_params_hierarchy
  assert enc_embed_name in blocks_params_hierarchy[enc_name]  # input embedding
  num_encoder_layers = max([
    int(re.match(".*([0-9]+)", s).group(1))
    for s in blocks_params_hierarchy[enc_name]
    if s.startswith("EncoderBidirectionalLSTM")])
  blocks_input_dim, blocks_input_embed_dim = blocks_params["%s/%s" % (enc_name, enc_embed_name)].shape
  print("Blocks input dim: %i, embed dim: %i" % (blocks_input_dim, blocks_input_embed_dim))
  print("Blocks num encoder layers: %i" % num_encoder_layers)
  expected_enc_entries = (
    ["EncoderLookUp0.W"] +
    ["EncoderBidirectionalLSTM%i" % i for i in range(1, num_encoder_layers + 1)])
  assert_equal(set(expected_enc_entries), set(blocks_params_hierarchy[enc_name].keys()))

  our_input_layer = find_our_input_embed_layer()
  assert our_input_layer.input_data.dim == blocks_input_dim
  assert our_input_layer.output.dim == blocks_input_embed_dim
  assert not our_input_layer.with_bias
  import_var(our_input_layer.params["W"], "%s/%s" % (enc_name, enc_embed_name))

  dec_name = "decoder/sequencegenerator"
  dec_hierarchy_base = get_in_hierarchy(dec_name, blocks_params_hierarchy)
  assert_equal(set(dec_hierarchy_base.keys()), {"att_trans", "readout"})
  dec_embed_name = "readout/lookupfeedbackwmt15/lookuptable.W"
  get_in_hierarchy(dec_embed_name, dec_hierarchy_base)  # check

  for i in range(num_encoder_layers):
    # Assume standard LSTMCell.
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    # lstm_matrix = self._linear1([inputs, m_prev])
    # i, j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=4, axis=1)
    # bias (4*out,), kernel (in+out, 4*out), w_(f|i|o)_diag (out)
    # prefix: rec/rnn/lstm_cell
    # Blocks: gate-in, gate-forget, next-in, gate-out
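    # A minimal sketch (an assumption, not the actual definition) of what
    # lstm_vec_blocks_to_tf, defined elsewhere in this script, is expected to
    # do given the two gate orders above: split the last axis into the four
    # Blocks gate blocks and permute them into the TF order, e.g.:
    #   def lstm_vec_blocks_to_tf(x):
    #     i_, f_, j_, o_ = numpy.split(x, 4, axis=-1)  # Blocks: i, f, j, o
    #     return numpy.concatenate([i_, j_, f_, o_], axis=-1)  # TF: i, j, f, o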
    for direction in ("fwd", "bwd"):
      our_layer = get_network().layers["lstm%i_%s" % (i, direction[:2])]
      blocks_prefix = "bidirectionalencoder/EncoderBidirectionalLSTM%i" % (i + 1,)
      # (in,out*4), (out*4,)
      W_in, b = [load_blocks_var(
        "%s/%s_fork/fork_inputs.%s" % (blocks_prefix, {"bwd": "back", "fwd": "fwd"}[direction], p))
        for p in ("W", "b")]
      blocks_dir = {"fwd": "forward", "bwd": "backward"}[direction]
      W_re = load_blocks_var(
        "%s/bidirectionalseparateparameters/%s.W_state" % (blocks_prefix, blocks_dir))
      W = numpy.concatenate([W_in, W_re], axis=0)
      b = lstm_vec_blocks_to_tf(b)
      W = lstm_vec_blocks_to_tf(W)
      import_var(our_layer.params["rnn/lstm_cell/bias"], b)
      import_var(our_layer.params["rnn/lstm_cell/kernel"], W)
      import_var(
        our_layer.params["initial_c"],
        "%s/bidirectionalseparateparameters/%s.initial_cells" % (blocks_prefix, blocks_dir))
      import_var(
        our_layer.params["initial_h"],
        "%s/bidirectionalseparateparameters/%s.initial_state" % (blocks_prefix, blocks_dir))
      for s1, s2 in [("W_cell_to_in", "w_i_diag"), ("W_cell_to_forget", "w_f_diag"), ("W_cell_to_out", "w_o_diag")]:
        import_var(
          our_layer.params["rnn/lstm_cell/%s" % s2],
          "%s/bidirectionalseparateparameters/%s.%s" % (blocks_prefix, blocks_dir, s1))
  import_var(get_network().layers["enc_ctx"].params["W"], "decoder/sequencegenerator/att_trans/attention/encoder_state_transformer.W")
  import_var(get_network().layers["enc_ctx"].params["b"], "decoder/sequencegenerator/att_trans/attention/encoder_state_transformer.b")
  import_var(our_params["output/rec/s/initial_c"], "decoder/sequencegenerator/att_trans/lstm_decoder.initial_cells")
  import_var(our_params["output/rec/s/initial_h"], "decoder/sequencegenerator/att_trans/lstm_decoder.initial_state")
  import_var(our_params["output/rec/weight_feedback/W"], "decoder/sequencegenerator/att_trans/attention/sum_alignment_transformer.W")
  import_var(our_params["output/rec/target_embed/W"], "decoder/sequencegenerator/readout/lookupfeedbackwmt15/lookuptable.W")
  import_var(our_params["fertility/W"], "decoder/sequencegenerator/att_trans/attention/fertility_transformer.W")
  import_var(our_params["output/rec/energy/W"], "decoder/sequencegenerator/att_trans/attention/energy_comp/linear.W")
  prev_s_trans_W_states = load_blocks_var("decoder/sequencegenerator/att_trans/attention/state_trans/transform_states.W")
  prev_s_trans_W_cells = load_blocks_var("decoder/sequencegenerator/att_trans/attention/state_trans/transform_cells.W")
  prev_s_trans_W = numpy.concatenate([prev_s_trans_W_cells, prev_s_trans_W_states], axis=0)
  import_var(our_params["output/rec/prev_s_transformed/W"], prev_s_trans_W)
  import_var(our_params["output/rec/s/rec/lstm_cell/bias"], numpy.zeros(our_params["output/rec/s/rec/lstm_cell/bias"].shape))
  dec_lstm_kernel_in_feedback = load_blocks_var("decoder/sequencegenerator/att_trans/feedback_to_decoder/fork_inputs.W")
  dec_lstm_kernel_in_ctx = load_blocks_var("decoder/sequencegenerator/att_trans/context_to_decoder/fork_inputs.W")
  dec_lstm_kernel_re = load_blocks_var("decoder/sequencegenerator/att_trans/lstm_decoder.W_state")
  dec_lstm_kernel = numpy.concatenate([dec_lstm_kernel_in_feedback, dec_lstm_kernel_in_ctx, dec_lstm_kernel_re], axis=0)
  dec_lstm_kernel = lstm_vec_blocks_to_tf(dec_lstm_kernel)
  import_var(our_params["output/rec/s/rec/lstm_cell/kernel"], dec_lstm_kernel)
  for s1, s2 in [("W_cell_to_in", "w_i_diag"), ("W_cell_to_forget", "w_f_diag"), ("W_cell_to_out", "w_o_diag")]:
    import_var(our_params["output/rec/s/rec/lstm_cell/%s" % s2], "decoder/sequencegenerator/att_trans/lstm_decoder.%s" % s1)
  readout_in_W_states = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_states.W")
  readout_in_W_feedback = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_feedback.W")
  readout_in_W_att = load_blocks_var("decoder/sequencegenerator/readout/merge/transform_weighted_averages.W")
  readout_in_W = numpy.concatenate([readout_in_W_states, readout_in_W_feedback, readout_in_W_att], axis=0)
  import_var(our_params["output/rec/readout_in/W"], readout_in_W)
  import_var(our_params["output/rec/readout_in/b"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b")
  import_var(our_params["output/rec/output_prob/W"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.W")
  import_var(our_params["output/rec/output_prob/b"], "decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.b")

  print("Not initialized own params:")
  count = 0
  for key, v in sorted(our_params.items()):
    if key in our_loaded_params:
      continue
    print("  %s: %s, %s" % (key, v.shape, v.dtype.base_dtype.name))
    count += 1
  if not count:
    print("  None.")
  print("Not used Blocks params:")
  count = 0
  for key, value in sorted(blocks_params.items()):
    if key in blocks_used_params:
      continue
    print("  %s: %s, %s" % (key, value.shape, value.dtype))
    count += 1
  if not count:
    print("  None.")
  print("Done.")

  blocks_debug_dump_output = config.value("blocks_debug_dump_output", None)
  if blocks_debug_dump_output:
    print("Will read Blocks debug dump output from %r and compare with Returnn outputs." % blocks_debug_dump_output)
    blocks_initial_outputs = numpy.load("%s/initial_states_data.0.npz" % blocks_debug_dump_output)
    blocks_search_log = pickle.load(open("%s/search.log.pkl" % blocks_debug_dump_output, "rb"), encoding="bytes")
    blocks_search_log = {d[b"step"]: d for d in blocks_search_log}
    input_seq = blocks_initial_outputs["input"]
    beam_size, seq_len = input_seq.shape
    input_seq = input_seq[0]  # all the same, select beam 0
    assert isinstance(input_seq, numpy.ndarray)
    print("Debug input seq: %s" % input_seq.tolist())
    from GeneratingDataset import StaticDataset
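    # Wrap the single debug input sequence into a one-sequence StaticDataset so
    # the engine can run a normal forward pass over it.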
    dataset = StaticDataset(
      data=[{"data": input_seq}],
      output_dim={"data": get_network().extern_data.get_default_input_data().get_kwargs()})
    dataset.init_seq_order(epoch=0)
    extract_output_dict = {
      "enc_src_emb": get_network().layers["source_embed"].output.get_placeholder_as_batch_major(),
      "encoder": get_network().layers["encoder"].output.get_placeholder_as_batch_major(),
      "enc_ctx": get_network().layers["enc_ctx"].output.get_placeholder_as_batch_major(),
      "output": get_network().layers["output"].output.get_placeholder_as_batch_major()
    }
    from TFNetworkLayer import concat_sources
    for i in range(num_encoder_layers):
      extract_output_dict["enc_layer_%i" % i] = concat_sources(
        [get_network().layers["lstm%i_fw" % i], get_network().layers["lstm%i_bw" % i]]
      ).get_placeholder_as_batch_major()
    extract_output_dict["enc_layer_0_fwd"] = get_network().layers["lstm0_fw"].output.get_placeholder_as_batch_major()
    our_output = rnn.engine.run_single(
      dataset=dataset, seq_idx=0, output_dict=extract_output_dict)
    blocks_out = blocks_initial_outputs["bidirectionalencoder_EncoderLookUp0__EncoderLookUp0_apply_output"]
    our_out = our_output["enc_src_emb"]
    print("our enc emb shape:", our_out.shape)
    print("Blocks enc emb shape:", blocks_out.shape)
    assert our_out.shape[:2] == (1, seq_len)
    assert blocks_out.shape[:2] == (seq_len, beam_size)
    assert our_out.shape[2] == blocks_out.shape[2]
    assert_almost_equal(our_out[0], blocks_out[:, 0], decimal=5)
    blocks_lstm0_out_ref = calc_lstm(blocks_out[:, 0], blocks_params)
    blocks_lstm0_out = blocks_initial_outputs["bidirectionalencoder_EncoderBidirectionalLSTM1_bidirectionalseparateparameters_forward__forward_apply_states"]
    our_lstm0_out = our_output["enc_layer_0_fwd"]
    assert blocks_lstm0_out.shape == (seq_len, beam_size) + blocks_lstm0_out_ref.shape
    assert our_lstm0_out.shape == (1, seq_len) + blocks_lstm0_out_ref.shape
    assert_almost_equal(blocks_lstm0_out[0, 0], blocks_lstm0_out_ref, decimal=6)
    print("Blocks LSTM0 frame 0 matched to ref calc.")
    assert_almost_equal(our_lstm0_out[0, 0], blocks_lstm0_out_ref, decimal=6)
    print("Our LSTM0 frame 0 matched to ref calc.")
    for i in range(num_encoder_layers):
      blocks_out = blocks_initial_outputs[
        "bidirectionalencoder_EncoderBidirectionalLSTM%i_bidirectionalseparateparameters__bidirectionalseparateparameters_apply_output_0" % (i + 1,)]
      our_out = our_output["enc_layer_%i" % i]
      print("our enc layer %i shape:" % i, our_out.shape)
      print("Blocks enc layer %i shape:" % i, blocks_out.shape)
      assert our_out.shape[:2] == (1, seq_len)
      assert blocks_out.shape[:2] == (seq_len, beam_size)
      assert our_out.shape[2] == blocks_out.shape[2]
      assert_almost_equal(our_out[0], blocks_out[:, 0], decimal=6)
    print("our encoder shape:", our_output["encoder"].shape)
    blocks_encoder_out = blocks_initial_outputs["bidirectionalencoder__bidirectionalencoder_apply_representation"]
    print("Blocks encoder shape:", blocks_encoder_out.shape)
    assert our_output["encoder"].shape[:2] == (1, seq_len)
    assert blocks_encoder_out.shape[:2] == (seq_len, beam_size)
    assert our_output["encoder"].shape[2] == blocks_encoder_out.shape[2]
    assert_almost_equal(our_output["encoder"][0], blocks_encoder_out[:, 0], decimal=6)
    blocks_first_frame_outputs = numpy.load("%s/next_states.0.npz" % blocks_debug_dump_output)
    blocks_enc_ctx_out = blocks_first_frame_outputs["decoder_sequencegenerator_att_trans_attention__attention_preprocess_preprocessed_attended"]
    our_enc_ctx_out = our_output["enc_ctx"]
    print("Blocks enc ctx shape:", blocks_enc_ctx_out.shape)
    assert blocks_enc_ctx_out.shape[:2] == (seq_len, beam_size)
    assert our_enc_ctx_out.shape[:2] == (1, seq_len)
    assert blocks_enc_ctx_out.shape[2:] == our_enc_ctx_out.shape[2:]
    assert_almost_equal(blocks_enc_ctx_out[:, 0], our_enc_ctx_out[0], decimal=5)
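    # sigmoid (used next) and softmax/log_softmax (used further below) are
    # assumed to be the usual NumPy helpers defined elsewhere in this script,
    # e.g. (sketch):
    #   def sigmoid(x): return 1.0 / (1.0 + numpy.exp(-x))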
    fertility = numpy.dot(blocks_encoder_out[:, 0], blocks_params["decoder/sequencegenerator/att_trans/attention/fertility_transformer.W"])
    fertility = sigmoid(fertility)
    assert fertility.shape == (seq_len, 1)
    fertility = fertility[:, 0]
    assert fertility.shape == (seq_len,)
    our_dec_outputs = {v["step"]: v for v in _SubnetworkRecCell._debug_out}
    assert our_dec_outputs
    print("our dec frame keys:", sorted(our_dec_outputs[0].keys()))
    our_dec_search_outputs = {v["step"]: v for v in ChoiceLayer._debug_out}
    assert our_dec_search_outputs
    print("our dec search frame keys:", sorted(our_dec_search_outputs[0].keys()))
    print("Blocks search frame keys:", sorted(blocks_search_log[0].keys()))
    dec_lookup = blocks_params["decoder/sequencegenerator/readout/lookupfeedbackwmt15/lookuptable.W"]
    last_lstm_state = blocks_params["decoder/sequencegenerator/att_trans/lstm_decoder.initial_state"]
    last_lstm_cells = blocks_params["decoder/sequencegenerator/att_trans/lstm_decoder.initial_cells"]
    last_accumulated_weights = numpy.zeros((seq_len,), dtype="float32")
    last_output = 0
    dec_seq_len = 0
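    # Replay the decoding loop step by step: recompute attention, decoder LSTM
    # and readout for each frame in plain NumPy from the Blocks params, then
    # check the Blocks debug dumps and our recurrent-layer debug outputs
    # against this reference, following the first beam.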
    for dec_step in range(100):
      blocks_frame_state_outputs_fn = "%s/next_states.%i.npz" % (blocks_debug_dump_output, dec_step)
      blocks_frame_probs_outputs_fn = "%s/logprobs.%i.npz" % (blocks_debug_dump_output, dec_step)
      if dec_step > 3:
        if not os.path.exists(blocks_frame_state_outputs_fn) or not os.path.exists(blocks_frame_probs_outputs_fn):
          print("Seq not ended yet but frame not found for step %i." % dec_step)
          break
      blocks_frame_state_outputs = numpy.load(blocks_frame_state_outputs_fn)
      blocks_frame_probs_outputs = numpy.load(blocks_frame_probs_outputs_fn)
      blocks_search_frame = blocks_search_log[dec_step]
      our_dec_frame_outputs = our_dec_outputs[dec_step]
      assert our_dec_frame_outputs["step"] == dec_step
      assert our_dec_frame_outputs[":i.output"].tolist() == [dec_step]
      our_dec_search_frame_outputs = our_dec_search_outputs[dec_step]

      blocks_last_lstm_state = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_states"]
      blocks_last_lstm_cells = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_cells"]
      assert blocks_last_lstm_state.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(blocks_last_lstm_state[0], last_lstm_state, decimal=5)
      assert_almost_equal(blocks_last_lstm_cells[0], last_lstm_cells, decimal=5)
      our_last_lstm_cells = our_dec_frame_outputs["prev:s.extra.state"][0]
      our_last_lstm_state = our_dec_frame_outputs["prev:s.extra.state"][1]
      assert our_last_lstm_state.shape == our_last_lstm_cells.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(our_last_lstm_state[0], last_lstm_state, decimal=5)
      assert_almost_equal(our_last_lstm_cells[0], last_lstm_cells, decimal=5)
      our_last_s = our_dec_frame_outputs["prev:s.output"]
      assert our_last_s.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(our_last_s[0], last_lstm_state, decimal=5)

      blocks_last_accum_weights = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_accumulated_weights"]
      assert blocks_last_accum_weights.shape == (beam_size, seq_len)
      assert_almost_equal(blocks_last_accum_weights[0], last_accumulated_weights, decimal=5)
      our_last_accum_weights = our_dec_frame_outputs["prev:accum_att_weights.output"]
      assert our_last_accum_weights.shape == (beam_size, seq_len if dec_step > 0 else 1, 1)
      if dec_step > 0:
        assert_almost_equal(our_last_accum_weights[0, :, 0], last_accumulated_weights, decimal=4)
      else:
        assert_almost_equal(our_last_accum_weights[0, 0, 0], last_accumulated_weights.sum(), decimal=4)

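      # NumPy reference for the additive (MLP/tanh) attention energies: encoder
      # context plus transformed accumulated weights, decoder state and cells,
      # then tanh and a linear projection to one energy per source position.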
      energy_sum = numpy.copy(blocks_enc_ctx_out[:, 0])  # (T,enc-ctx-dim)
      weight_feedback = numpy.dot(last_accumulated_weights[:, None], blocks_params["decoder/sequencegenerator/att_trans/attention/sum_alignment_transformer.W"])
      energy_sum += weight_feedback
      transformed_states = numpy.dot(last_lstm_state[None, :], blocks_params["decoder/sequencegenerator/att_trans/attention/state_trans/transform_states.W"])
      transformed_cells = numpy.dot(last_lstm_cells[None, :], blocks_params["decoder/sequencegenerator/att_trans/attention/state_trans/transform_cells.W"])
      energy_sum += transformed_states + transformed_cells
      assert energy_sum.shape == (seq_len, blocks_enc_ctx_out.shape[-1])
      blocks_energy_sum_tanh = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention_energy_comp_tanh__tanh_apply_output"]
      assert blocks_energy_sum_tanh.shape == (seq_len, beam_size, energy_sum.shape[-1])
      assert_almost_equal(blocks_energy_sum_tanh[:, 0], numpy.tanh(energy_sum), decimal=5)
      assert_equal(our_dec_frame_outputs["weight_feedback.output"].shape, (beam_size, seq_len if dec_step > 0 else 1, blocks_enc_ctx_out.shape[-1]))
      assert_equal(our_dec_frame_outputs["prev_s_transformed.output"].shape, (beam_size, blocks_enc_ctx_out.shape[-1]))
      our_energy_sum = our_dec_frame_outputs["energy_in.output"]
      assert our_energy_sum.shape == (beam_size, seq_len, blocks_enc_ctx_out.shape[-1])
      assert_almost_equal(our_energy_sum[0], energy_sum, decimal=4)
      blocks_energy = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention_energy_comp__energy_comp_apply_output"]
      assert blocks_energy.shape == (seq_len, beam_size, 1)
      energy = numpy.dot(numpy.tanh(energy_sum), blocks_params["decoder/sequencegenerator/att_trans/attention/energy_comp/linear.W"])
      assert energy.shape == (seq_len, 1)
      assert_almost_equal(blocks_energy[:, 0], energy, decimal=4)
      our_energy = our_dec_frame_outputs["energy.output"]
      assert our_energy.shape == (beam_size, seq_len, 1)
      assert_almost_equal(our_energy[0], energy, decimal=4)
      weights = softmax(energy[:, 0])
      assert weights.shape == (seq_len,)
      our_weights = our_dec_frame_outputs["att_weights.output"]
      assert our_weights.shape == (beam_size, seq_len, 1)
      assert_almost_equal(our_weights[0, :, 0], weights, decimal=4)
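      # Coverage-style update: accumulate the attention weights, scaled down by
      # twice the predicted per-position fertility.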
      accumulated_weights = last_accumulated_weights + weights / (2.0 * fertility)
      assert accumulated_weights.shape == (seq_len,)
      #blocks_accumulated_weights = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_take_glimpses_accumulated_weights"]
      #assert blocks_accumulated_weights.shape == (beam_size, seq_len)
      #assert_almost_equal(blocks_accumulated_weights[0], accumulated_weights, decimal=5)
      blocks_weights = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_compute_weights_output_0"]
      assert blocks_weights.shape == (seq_len, beam_size)
      assert_almost_equal(weights, blocks_weights[:, 0], decimal=4)
      our_accum_weights = our_dec_frame_outputs["accum_att_weights.output"]
      assert our_accum_weights.shape == (beam_size, seq_len, 1)
      weighted_avg = (weights[:, None] * blocks_encoder_out[:, 0]).sum(axis=0)  # att in our
      assert weighted_avg.shape == (blocks_encoder_out.shape[-1],)
      blocks_weighted_avg = blocks_frame_probs_outputs["decoder_sequencegenerator_att_trans_attention__attention_compute_weighted_averages_output_0"]
      assert blocks_weighted_avg.shape == (beam_size, blocks_encoder_out.shape[-1])
      assert_almost_equal(blocks_weighted_avg[0], weighted_avg, decimal=4)
      our_att = our_dec_frame_outputs["att.output"]
      assert our_att.shape == (beam_size, blocks_encoder_out.shape[-1])
      assert_almost_equal(our_att[0], weighted_avg, decimal=4)

      blocks_last_output = blocks_frame_probs_outputs["decoder_sequencegenerator__sequencegenerator_generate_outputs"]
      assert blocks_last_output.shape == (beam_size,)
      assert max(blocks_last_output[0], 0) == last_output
      last_target_embed = dec_lookup[last_output]
      if dec_step == 0:
        last_target_embed = numpy.zeros_like(last_target_embed)
      our_last_target_embed = our_dec_frame_outputs["prev:target_embed.output"]
      assert our_last_target_embed.shape == (beam_size, dec_lookup.shape[-1])
      assert_almost_equal(our_last_target_embed[0], last_target_embed, decimal=4)

      readout_in_state = numpy.dot(last_lstm_state, blocks_params["decoder/sequencegenerator/readout/merge/transform_states.W"])
      blocks_trans_state = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_states"]
      assert blocks_trans_state.shape == (beam_size, last_lstm_state.shape[0])
      assert_almost_equal(blocks_trans_state[0], readout_in_state, decimal=4)
      readout_in_feedback = numpy.dot(last_target_embed, blocks_params["decoder/sequencegenerator/readout/merge/transform_feedback.W"])
      blocks_trans_feedback = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_feedback"]
      assert blocks_trans_feedback.shape == (beam_size, readout_in_feedback.shape[0])
      assert_almost_equal(blocks_trans_feedback[0], readout_in_feedback, decimal=4)
      readout_in_weighted_avg = numpy.dot(weighted_avg, blocks_params["decoder/sequencegenerator/readout/merge/transform_weighted_averages.W"])
      blocks_trans_weighted_avg = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_weighted_averages"]
      assert blocks_trans_weighted_avg.shape == (beam_size, readout_in_weighted_avg.shape[0])
      assert_almost_equal(blocks_trans_weighted_avg[0], readout_in_weighted_avg, decimal=4)
      readout_in = readout_in_state + readout_in_feedback + readout_in_weighted_avg
      blocks_readout_in = blocks_frame_probs_outputs["decoder_sequencegenerator_readout_merge__merge_apply_output"]
      assert blocks_readout_in.shape == (beam_size, readout_in.shape[0])
      assert_almost_equal(blocks_readout_in[0], readout_in, decimal=4)
      readout_in += blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b"]
      assert readout_in.shape == (blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/maxout_bias.b"].shape[0],)
      our_readout_in = our_dec_frame_outputs["readout_in.output"]
      assert our_readout_in.shape == (beam_size, readout_in.shape[0])
      assert_almost_equal(our_readout_in[0], readout_in, decimal=4)
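      # Maxout with pool size 2: pair up consecutive units and keep the max of each pair.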
      readout = readout_in.reshape((readout_in.shape[0] // 2, 2)).max(axis=1)
      our_readout = our_dec_frame_outputs["readout.output"]
      assert our_readout.shape == (beam_size, readout.shape[0])
      assert_almost_equal(our_readout[0], readout, decimal=4)
      prob_logits = numpy.dot(readout, blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.W"]) + \
        blocks_params["decoder/sequencegenerator/readout/initializablefeedforwardsequence/softmax1.b"]
      assert prob_logits.ndim == 1
      blocks_prob_logits = blocks_frame_probs_outputs["decoder_sequencegenerator_readout__readout_readout_output_0"]
      assert blocks_prob_logits.shape == (beam_size, prob_logits.shape[0])
      assert_almost_equal(blocks_prob_logits[0], prob_logits, decimal=4)
      output_prob = softmax(prob_logits)
      log_output_prob = log_softmax(prob_logits)
      assert_almost_equal(numpy.log(output_prob), log_output_prob, decimal=4)
      our_output_prob = our_dec_frame_outputs["output_prob.output"]
      assert our_output_prob.shape == (beam_size, output_prob.shape[0])
      assert_almost_equal(our_output_prob[0], output_prob, decimal=4)
      blocks_nlog_prob = blocks_frame_probs_outputs["logprobs"]
      assert blocks_nlog_prob.shape == (beam_size, output_prob.shape[0])
      assert_almost_equal(blocks_nlog_prob[0], -log_output_prob, decimal=4)
      assert_almost_equal(our_dec_search_frame_outputs["scores_in_orig"][0], output_prob, decimal=4)
      assert_almost_equal(blocks_search_frame[b'logprobs'][0], -log_output_prob, decimal=4)
      #for b in range(beam_size):
      #  assert_almost_equal(-numpy.log(our_output_prob[b]), blocks_frame_probs_outputs["logprobs"][b], decimal=4)
      ref_output = numpy.argmax(output_prob)
      # Note: Don't take the readout.emit outputs. They are randomly sampled.
      blocks_dec_output = blocks_search_frame[b'outputs']
      assert blocks_dec_output.shape == (beam_size,)
      our_dec_output = our_dec_frame_outputs["output.output"]
      assert our_dec_output.shape == (beam_size,)
      print("Frame %i: Ref best greedy output symbol: %i" % (dec_step, int(ref_output)))
      print("Blocks labels:", blocks_dec_output.tolist())
      print("Our labels:", our_dec_output.tolist())
      # The following two assertions could fail if all the other beams have much better scores,
      # but this is unlikely.
      assert ref_output in blocks_dec_output
      assert ref_output in our_dec_output
      if dec_step == 0:
        # This assumes that the results are ordered by score which might not be true (see tf.nn.top_k).
        assert blocks_dec_output[0] == our_dec_output[0] == ref_output
      # We assume that the best hypothesis is the same in both. This might not hold if two
      # hypotheses have equal scores, and it also assumes ordering by score, which is not
      # guaranteed (see tf.nn.top_k). For the same reasons, the remaining entries might not
      # match perfectly either.
      assert our_dec_output[0] == blocks_dec_output[0]
      # Just follow the first beam.
      ref_output = blocks_dec_output[0]
      assert our_dec_search_frame_outputs["src_beam_idxs"].shape == (1, beam_size)
      assert our_dec_search_frame_outputs["scores"].shape == (1, beam_size)
      print("Blocks src_beam_idxs:", blocks_search_frame[b'indexes'].tolist())
      print("Our src_beam_idxs:", our_dec_search_frame_outputs["src_beam_idxs"][0].tolist())
      print("Blocks scores:", blocks_search_frame[b'chosen_costs'].tolist())
      print("Our scores:", our_dec_search_frame_outputs["scores"][0].tolist())
      if list(our_dec_search_frame_outputs["src_beam_idxs"][0]) != list(blocks_search_frame[b'indexes']):
        print("Warning, beams do not match.")
        print("Blocks scores base:", blocks_search_frame[b'scores_base'].flatten().tolist())
        print("Our scores base:", our_dec_search_frame_outputs["scores_base"].flatten().tolist())
        #print("Blocks score in orig top k:", sorted(blocks_search_frame[b'logprobs'].flatten())[:beam_size])
        #print("Our score in orig top k:", sorted(-numpy.log(our_dec_search_frame_outputs["scores_in_orig"].flatten()))[:beam_size])
        print("Blocks score in top k:", sorted((blocks_search_frame[b'logprobs'] * blocks_search_log[dec_step - 1][b'mask'][:, None]).flatten())[:beam_size])
        print("Our score in top k:", sorted(-our_dec_search_frame_outputs["scores_in"].flatten())[:beam_size])
        blocks_scores_combined = blocks_search_frame[b'next_costs']
        our_scores_combined = our_dec_search_frame_outputs["scores_combined"]
        print("Blocks scores combined top k:", sorted(blocks_scores_combined.flatten())[:beam_size])
        print("Our neg scores combined top k:", sorted(-our_scores_combined.flatten())[:beam_size])
        #raise Exception("beams mismatch")
      assert our_dec_search_frame_outputs["src_beam_idxs"][0][0] == blocks_search_frame[b'indexes'][0]
      beam_idx = our_dec_search_frame_outputs["src_beam_idxs"][0][0]
      if beam_idx != 0:
        print("Selecting different beam: %i." % beam_idx)
        # Just overwrite the needed states by Blocks outputs.
        accumulated_weights = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_attention__attention_take_glimpses_accumulated_weights"][0]
        weighted_avg = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_weighted_averages"][0]
        last_lstm_state = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_states"][0]
        last_lstm_cells = blocks_frame_state_outputs["decoder_sequencegenerator__sequencegenerator_generate_cells"][0]

      # From now on, use blocks_frame_state_outputs instead of blocks_frame_probs_outputs because
      # it will have the beam reordered.
      blocks_target_emb = blocks_frame_state_outputs["decoder_sequencegenerator_fork__fork_apply_feedback_decoder_input"]
      assert blocks_target_emb.shape == (beam_size, dec_lookup.shape[1])
      target_embed = dec_lookup[ref_output]
      assert target_embed.shape == (dec_lookup.shape[1],)
      assert_almost_equal(blocks_target_emb[0], target_embed)

      feedback_to_decoder = numpy.dot(target_embed, blocks_params["decoder/sequencegenerator/att_trans/feedback_to_decoder/fork_inputs.W"])
      context_to_decoder = numpy.dot(weighted_avg, blocks_params["decoder/sequencegenerator/att_trans/context_to_decoder/fork_inputs.W"])
      lstm_z = feedback_to_decoder + context_to_decoder
      assert lstm_z.shape == feedback_to_decoder.shape == context_to_decoder.shape == (last_lstm_state.shape[-1] * 4,)
      blocks_feedback_to_decoder = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_feedback_to_decoder__feedback_to_decoder_apply_inputs"]
      blocks_context_to_decoder = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_context_to_decoder__context_to_decoder_apply_inputs"]
      assert blocks_feedback_to_decoder.shape == blocks_context_to_decoder.shape == (beam_size, last_lstm_state.shape[-1] * 4)
      assert_almost_equal(blocks_feedback_to_decoder[0], feedback_to_decoder, decimal=4)
      assert_almost_equal(blocks_context_to_decoder[0], context_to_decoder, decimal=4)
      lstm_state, lstm_cells = calc_raw_lstm(
        lstm_z, blocks_params=blocks_params,
        prefix="decoder/sequencegenerator/att_trans/lstm_decoder.",
        last_state=last_lstm_state, last_cell=last_lstm_cells)
      assert lstm_state.shape == last_lstm_state.shape == lstm_cells.shape == last_lstm_cells.shape
      blocks_lstm_state = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_lstm_decoder__lstm_decoder_apply_states"]
      blocks_lstm_cells = blocks_frame_state_outputs["decoder_sequencegenerator_att_trans_lstm_decoder__lstm_decoder_apply_cells"]
      assert blocks_lstm_state.shape == blocks_lstm_cells.shape == (beam_size, last_lstm_state.shape[-1])
      assert_almost_equal(blocks_lstm_state[0], lstm_state, decimal=4)
      assert_almost_equal(blocks_lstm_cells[0], lstm_cells, decimal=4)
      our_lstm_cells = our_dec_frame_outputs["s.extra.state"][0]
      our_lstm_state = our_dec_frame_outputs["s.extra.state"][1]
      assert our_lstm_state.shape == our_lstm_cells.shape == (beam_size, lstm_state.shape[0])
      assert_almost_equal(our_lstm_state[0], lstm_state, decimal=4)
      assert_almost_equal(our_lstm_cells[0], lstm_cells, decimal=4)
      our_s = our_dec_frame_outputs["s.output"]
      assert our_s.shape == (beam_size, lstm_state.shape[0])
      assert_almost_equal(our_s[0], lstm_state, decimal=4)

      last_accumulated_weights = accumulated_weights
      last_lstm_state = lstm_state
      last_lstm_cells = lstm_cells
      last_output = ref_output
      if last_output == 0:
        print("Sequence finished, seq len %i." % dec_step)
        dec_seq_len = dec_step
        break
    assert dec_seq_len > 0
    print("All outputs seem to match.")
  else:
    print("blocks_debug_dump_output not specified. It will not compare the model outputs." % blocks_debug_dump_output)

  if dry_run:
    print("Dry-run, not saving model.")
  else:
    rnn.engine.save_model(our_model_fn)
  print("Finished importing.")