コード例 #1
0
def pretrainFromConfig(config):
    """
    Create the pretrain scheme requested by the "pretrain" config entry.

    Pretraining is enabled when the entry is ``True``, ``"default"`` or a
    non-empty dict. The dict read from the "pretrain" entry supplies keyword
    options for ``Pretrain``; the separate ``pretrain_*`` config keys only
    fill in options which that dict does not define yet.

    :type config: Config.Config
    :rtype: Pretrain | None
    :return: a Pretrain instance, or None if the entry is unset or falsy
    :raises Exception: for any other truthy "pretrain" value
    """
    pretrain_spec = config.bool_or_other("pretrain", None)
    enabled = (
        pretrain_spec == "default"
        or (isinstance(pretrain_spec, dict) and pretrain_spec)
        or pretrain_spec is True)
    if not enabled:
        if pretrain_spec:
            raise Exception("unknown pretrain type: %s" % pretrain_spec)
        return None

    init_args = LayerNetwork.init_args_from_config(config)
    net_json = LayerNetwork.json_from_config(config)
    # NOTE(review): if get_of_type() hands back the config's own dict, the
    # setdefault() calls below modify it in place -- confirm this is intended.
    opts = config.get_of_type("pretrain", dict, {})

    def read_repetitions():
        # A typed config entry is taken as-is, otherwise it is parsed as int list.
        if config.is_typed("pretrain_repetitions"):
            return config.typed_value("pretrain_repetitions")
        return config.int_list("pretrain_repetitions", None)

    # (config key, Pretrain option name, reader for the config value)
    fallbacks = (
        ("pretrain_copy_output_layer", "copy_output_layer",
         lambda: config.bool_or_other("pretrain_copy_output_layer",
                                      "ifpossible")),
        ("pretrain_greedy", "greedy",
         lambda: config.bool("pretrain_greedy", None)),
        ("pretrain_repetitions", "repetitions", read_repetitions),
        ("pretrain_construction_algo", "construction_algo",
         lambda: config.value("pretrain_construction_algo", None)),
    )
    for config_key, opt_name, read_value in fallbacks:
        if config.has(config_key):
            # setdefault: an option given in the "pretrain" dict takes precedence.
            opts.setdefault(opt_name, read_value())

    return Pretrain(original_network_json=net_json,
                    network_init_args=init_args,
                    **opts)
コード例 #2
0
def pretrainFromConfig(config):
  """
  Create the pretrain scheme selected by the "pretrain" config entry.

  An empty entry means no pretraining, "default" builds a ``Pretrain``
  object from the ``pretrain_*`` config options.

  :type config: Config.Config
  :rtype: Pretrain | None
  :raises Exception: if the entry is neither empty nor "default"
  """
  mode = config.value("pretrain", "")
  if mode != "default":
    if mode == "":
      return None
    raise Exception("unknown pretrain type: %s" % mode)

  init_args = LayerNetwork.init_args_from_config(config)
  net_json = LayerNetwork.json_from_config(config)
  pretrain_kwargs = dict(
    original_network_json=net_json,
    network_init_args=init_args,
    copy_output_layer=config.bool_or_other("pretrain_copy_output_layer", "ifpossible"),
    greedy=config.bool("pretrain_greedy", None))
  # A typed config entry is taken as-is, otherwise it is parsed as int list.
  if config.is_typed("pretrain_repetitions"):
    pretrain_kwargs["repetitions"] = config.typed_value("pretrain_repetitions")
  else:
    pretrain_kwargs["repetitions"] = config.int_list("pretrain_repetitions", None)
  pretrain_kwargs["construction_algo"] = config.value("pretrain_construction_algo", None)
  return Pretrain(**pretrain_kwargs)
コード例 #3
0
  def init_network_from_config(self, config):
    """
    Build the network described by the config and, if a model file was
    found for the epoch, load its weights.

    Sets ``model_filename``, ``pretrain``, ``max_seqs`` and ``epoch`` on the
    instance. Exits the process with status 1 if the weights cannot be found.

    :param Config.Config config:
    """
    self.model_filename = config.value('model', None)
    self.pretrain = pretrainFromConfig(config)
    self.max_seqs = config.int('max_seqs', -1)

    loaded_epoch, weights_path = self.get_epoch_model(config)
    assert weights_path or self.start_epoch
    self.epoch = loaded_epoch or self.start_epoch

    # With pretraining, the network description is taken from the pretrain
    # scheme for the current epoch. Per the original note, this is only needed
    # when an existing model is loaded; self.init_train_epoch() initializes a
    # new model otherwise.
    net_dict = (
      self.pretrain.get_network_json_for_epoch(self.epoch) if self.pretrain
      else LayerNetwork.json_from_config(config))
    self._init_network(net_desc=net_dict, epoch=self.epoch)

    if not weights_path:
      return
    print("loading weights from", weights_path, file=log.v2)
    try:
      self.network.load_params_from_file(weights_path, session=self.tf_session)
    except tf.errors.NotFoundError:
      print("Exiting now because model cannot be loaded.", file=log.v1)
      sys.exit(1)
コード例 #4
0
ファイル: test_Network.py プロジェクト: rwth-i6/returnn
def test_enc_dec1_init():
  # Smoke test: the enc-dec config yields a non-empty network description,
  # and a network can be constructed from it.
  cfg = Config()
  cfg.load_file(StringIO(config_enc_dec1_json))

  net_json = LayerNetwork.json_from_config(cfg)
  assert_true(net_json)

  net = LayerNetwork.from_json_and_config(net_json, cfg)
  assert_true(net)
コード例 #5
0
def test_enc_dec1_init():
    # Smoke test: the enc-dec config yields a non-empty network description,
    # and a network can be constructed from it.
    cfg = Config()
    cfg.load_file(StringIO(config_enc_dec1_json))

    net_json = LayerNetwork.json_from_config(cfg)
    assert_true(net_json)

    net = LayerNetwork.from_json_and_config(net_json, cfg)
    assert_true(net)
コード例 #6
0
    def init_network_from_config(self, config):
        """
        Build the network described by the config and, if a model file was
        found for the epoch, load its weights.

        Sets ``pretrain`` and ``max_seqs`` on the instance.

        :param config: the config to read the network setup from
        """
        self.pretrain = pretrainFromConfig(config)
        self.max_seqs = config.int('max_seqs', -1)

        loaded_epoch, weights_path = self.get_epoch_model(config)
        assert weights_path or self.start_epoch

        # With pretraining, the network description is taken from the pretrain
        # scheme for the epoch. Per the original note, this is only needed when
        # an existing model is loaded; self.init_train_epoch() initializes a
        # new model otherwise.
        net_dict = (
            self.pretrain.get_network_json_for_epoch(
                loaded_epoch or self.start_epoch)
            if self.pretrain
            else LayerNetwork.json_from_config(config))
        self._init_network(net_desc=net_dict,
                           epoch=loaded_epoch or self.start_epoch)

        if not weights_path:
            return
        print("loading weights from", weights_path, file=log.v2)
        self.network.load_params_from_file(weights_path,
                                           session=self.tf_session)
コード例 #7
0
ファイル: test_Network.py プロジェクト: rwth-i6/returnn
def test_enc_dec1_hdf():
  """
  Save the enc-dec network to an HDF file, load the topology back and
  check that the layer names of both networks match.
  """
  # mkstemp() creates the file atomically; mktemp() is deprecated because
  # another process could claim the name before we open it.
  fd, filename = tempfile.mkstemp(prefix="crnn-model-test")
  os.close(fd)  # h5py opens the file by name below
  try:
    config = Config()
    config.load_file(StringIO(config_enc_dec1_json))
    network_json = LayerNetwork.json_from_config(config)
    assert_true(network_json)
    network = LayerNetwork.from_json_and_config(network_json, config)
    assert_true(network)

    model = h5py.File(filename, "w")
    try:
      network.save_hdf(model, epoch=42)
    finally:
      model.close()

    loaded_model = h5py.File(filename, "r")
    try:
      loaded_net = LayerNetwork.from_hdf_model_topology(loaded_model)
      assert_true(loaded_net)
      assert_equal(sorted(network.hidden.keys()), sorted(loaded_net.hidden.keys()))
      assert_equal(sorted(network.y.keys()), sorted(loaded_net.y.keys()))
      assert_equal(sorted(network.j.keys()), sorted(loaded_net.j.keys()))
    finally:
      # Close the handle before removing the file (required on Windows).
      loaded_model.close()
  finally:
    # Remove the temp file even if an assertion above failed.
    os.remove(filename)
コード例 #8
0
ファイル: Pretrain.py プロジェクト: atuxhe/returnn
def pretrainFromConfig(config):
  """
  Create the pretrain scheme selected by the "pretrain" config entry.

  An empty entry means no pretraining, "default" builds a ``Pretrain``
  object from the ``pretrain_*`` config options.

  :type config: Config.Config
  :rtype: Pretrain | None
  :raises Exception: if the entry is neither empty nor "default"
  """
  mode = config.value("pretrain", "")
  if mode != "default":
    if mode == "":
      return None
    raise Exception("unknown pretrain type: %s" % mode)

  init_args = LayerNetwork.init_args_from_config(config)
  net_json = LayerNetwork.json_from_config(config)
  # Read the remaining options in a fixed order, then hand them over together.
  pretrain_kwargs = dict(
    original_network_json=net_json,
    network_init_args=init_args,
    copy_output_layer=config.bool_or_other("pretrain_copy_output_layer", "ifpossible"),
    greedy=config.bool("pretrain_greedy", None),
    repetitions=config.int_list("pretrain_repetitions", None),
    construction_algo=config.value("pretrain_construction_algo", None))
  return Pretrain(**pretrain_kwargs)
コード例 #9
0
def test_enc_dec1_hdf():
    """
    Save the enc-dec network to an HDF file, load the topology back and
    check that the layer names of both networks match.
    """
    # mkstemp() creates the file atomically; mktemp() is deprecated because
    # another process could claim the name before we open it.
    fd, filename = tempfile.mkstemp(prefix="crnn-model-test")
    os.close(fd)  # h5py opens the file by name below
    try:
        config = Config()
        config.load_file(StringIO(config_enc_dec1_json))
        network_json = LayerNetwork.json_from_config(config)
        assert_true(network_json)
        network = LayerNetwork.from_json_and_config(network_json, config)
        assert_true(network)

        model = h5py.File(filename, "w")
        try:
            network.save_hdf(model, epoch=42)
        finally:
            model.close()

        loaded_model = h5py.File(filename, "r")
        try:
            loaded_net = LayerNetwork.from_hdf_model_topology(loaded_model)
            assert_true(loaded_net)
            assert_equal(sorted(network.hidden.keys()),
                         sorted(loaded_net.hidden.keys()))
            assert_equal(sorted(network.y.keys()),
                         sorted(loaded_net.y.keys()))
            assert_equal(sorted(network.j.keys()),
                         sorted(loaded_net.j.keys()))
        finally:
            # Close the handle before removing the file (required on Windows).
            loaded_model.close()
    finally:
        # Remove the temp file even if an assertion above failed.
        os.remove(filename)