Example #1
0
  def get_epoch_model(cls, config):
    """
    Select the model checkpoint (epoch and filename) to start from, based on the config.

    The result is cached in ``cls._epoch_model``. Once set, later calls return the cached
    value and do not look at ``config`` again.

    Selection order (first match wins):

    1. The model given by 'load' (or, if 'load' is empty, the existing model selected by 'load_epoch'),
       if the task is not 'train' or an explicit 'start_epoch' is set.
       For task 'train' with an explicit 'start_epoch', the returned epoch is None.
    2. The existing model with the highest epoch.
    3. For task 'train' with 'import_model_train_epoch1' set and 'start_epoch' being auto or 1:
       ``(0, import_model_train_epoch1)``.
    4. The model given by 'load', with the epoch taken from its filename.
    5. ``(None, None)``.

    Afterwards an explicit 'start_epoch' is applied:
    with ``start_epoch == 1``, a selected model with a non-zero epoch is dropped;
    with ``start_epoch > 1``, a selected model with a non-zero epoch is replaced
    by the existing model of epoch ``start_epoch - 1``.

    :type config: Config.Config
    :return: (epoch, model_filename)
    :rtype: (int|None, str|None)
    :raises KeyError: if ``start_epoch > 1`` requires the existing model of epoch ``start_epoch - 1``
      but there is no such existing model.
    """
    # XXX: We cache it, although this is wrong if we have changed the config.
    if cls._epoch_model:
      return cls._epoch_model

    start_epoch_mode = config.value('start_epoch', 'auto')
    if start_epoch_mode == 'auto':
      start_epoch = None
    else:
      start_epoch = int(start_epoch_mode)
      assert start_epoch >= 1

    # The task is queried in several branches below; look it up once.
    is_training = config.value('task', 'train') == 'train'

    # Explicitly configured model files must exist on disk (filename + model postfix).
    load_model_epoch_filename = config.value('load', '')
    if load_model_epoch_filename:
      assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())

    import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
    if import_model_train_epoch1:
      assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())

    existing_models = cls.get_existing_models(config)
    # Without an explicit 'load', 'load_epoch' selects one of the existing models instead.
    if not load_model_epoch_filename:
      if config.has("load_epoch"):
        load_epoch = config.int("load_epoch", 0)
        assert load_epoch in existing_models
        load_model_epoch_filename = existing_models[load_epoch]
        assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch

    # Only use this when we don't train.
    # For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
    # In all other cases, we use the model specified by 'load'.
    if load_model_epoch_filename and (not is_training or start_epoch is not None):
      epoch = model_epoch_from_filename(load_model_epoch_filename)
      if is_training and start_epoch is not None:
        # Ignore the epoch. To keep it consistent with the case below.
        epoch = None
      epoch_model = (epoch, load_model_epoch_filename)

    # In case of training, always first consider existing models.
    # This is because we reran RETURNN training, we usually don't want to train from scratch
    # but resume where we stopped last time.
    elif existing_models:
      # Sorting the (epoch, filename) items puts the highest epoch last.
      epoch_model = sorted(existing_models.items())[-1]
      if load_model_epoch_filename:
        print("note: there is a 'load' which we ignore because of existing model", file=log.v4)

    elif is_training and import_model_train_epoch1 and start_epoch in [None, 1]:
      # Epoch 0 is falsy, so the start_epoch handling below leaves this entry untouched.
      epoch_model = (0, import_model_train_epoch1)

    # Now, consider this also in the case when we train, as an initial model import.
    elif load_model_epoch_filename:
      # NOTE(review): the original comment here said the model epoch is not used as the start epoch
      # in training, but the epoch from the filename is what gets returned. Confirm the intent.
      epoch_model = (model_epoch_from_filename(load_model_epoch_filename), load_model_epoch_filename)

    else:
      epoch_model = (None, None)

    if start_epoch == 1:
      if epoch_model[0]:  # existing model
        print("warning: there is an existing model: %s" % (epoch_model,), file=log.v4)
        epoch_model = (None, None)
    elif (start_epoch or 0) > 1:
      if epoch_model[0]:
        prev_epoch = start_epoch - 1
        if epoch_model[0] != prev_epoch:
          print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
        # Fail with a descriptive error instead of a bare KeyError(prev_epoch):
        # the warning above goes to the log only and might not be visible.
        if prev_epoch not in existing_models:
          raise KeyError(
            "start_epoch %i requires an existing model of epoch %i, but existing model epochs are %r" % (
              start_epoch, prev_epoch, sorted(existing_models)))
        epoch_model = prev_epoch, existing_models[prev_epoch]

    cls._epoch_model = epoch_model
    return epoch_model
Example #2
0
  def get_epoch_model(cls, config):
    """
    Select the model checkpoint (epoch and filename) to start from, based on the config.

    The result is cached in ``cls._epoch_model``. Once set, later calls return the cached
    value and do not look at ``config`` again.

    Selection order (first match wins):

    1. The model given by 'load' (or, if 'load' is empty, the existing model selected by 'load_epoch'),
       if the task is not 'train' or an explicit 'start_epoch' is set.
       For task 'train' with an explicit 'start_epoch', the returned epoch is None.
    2. The existing model with the highest epoch.
    3. For task 'train' with 'import_model_train_epoch1' set and 'start_epoch' being auto or 1:
       ``(0, import_model_train_epoch1)``.
    4. The model given by 'load', with the epoch taken from its filename.
    5. ``(None, None)``.

    Afterwards an explicit 'start_epoch' is applied:
    with ``start_epoch == 1``, a selected model with a non-zero epoch is dropped;
    with ``start_epoch > 1``, a selected model with a non-zero epoch is replaced
    by the existing model of epoch ``start_epoch - 1``.

    :type config: Config.Config
    :return: (epoch, model_filename)
    :rtype: (int|None, str|None)
    :raises KeyError: if ``start_epoch > 1`` requires the existing model of epoch ``start_epoch - 1``
      but there is no such existing model.
    """
    # XXX: We cache it, although this is wrong if we have changed the config.
    if cls._epoch_model:
      return cls._epoch_model

    start_epoch_mode = config.value('start_epoch', 'auto')
    if start_epoch_mode == 'auto':
      start_epoch = None
    else:
      start_epoch = int(start_epoch_mode)
      assert start_epoch >= 1

    # The task is queried in several branches below; look it up once.
    is_training = config.value('task', 'train') == 'train'

    # Explicitly configured model files must exist on disk (filename + model postfix).
    load_model_epoch_filename = config.value('load', '')
    if load_model_epoch_filename:
      assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())

    import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
    if import_model_train_epoch1:
      assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())

    existing_models = cls.get_existing_models(config)
    # Without an explicit 'load', 'load_epoch' selects one of the existing models instead.
    if not load_model_epoch_filename:
      if config.has("load_epoch"):
        load_epoch = config.int("load_epoch", 0)
        assert load_epoch in existing_models
        load_model_epoch_filename = existing_models[load_epoch]
        assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch

    # Only use this when we don't train.
    # For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
    # In all other cases, we use the model specified by 'load'.
    if load_model_epoch_filename and (not is_training or start_epoch is not None):
      epoch = model_epoch_from_filename(load_model_epoch_filename)
      if is_training and start_epoch is not None:
        # Ignore the epoch. To keep it consistent with the case below.
        epoch = None
      epoch_model = (epoch, load_model_epoch_filename)

    # In case of training, always first consider existing models.
    # This is because we reran CRNN training, we usually don't want to train from scratch
    # but resume where we stopped last time.
    elif existing_models:
      # Sorting the (epoch, filename) items puts the highest epoch last.
      epoch_model = sorted(existing_models.items())[-1]
      if load_model_epoch_filename:
        print("note: there is a 'load' which we ignore because of existing model", file=log.v4)

    elif is_training and import_model_train_epoch1 and start_epoch in [None, 1]:
      # Epoch 0 is falsy, so the start_epoch handling below leaves this entry untouched.
      epoch_model = (0, import_model_train_epoch1)

    # Now, consider this also in the case when we train, as an initial model import.
    elif load_model_epoch_filename:
      # NOTE(review): the original comment here said the model epoch is not used as the start epoch
      # in training, but the epoch from the filename is what gets returned. Confirm the intent.
      epoch_model = (model_epoch_from_filename(load_model_epoch_filename), load_model_epoch_filename)

    else:
      epoch_model = (None, None)

    if start_epoch == 1:
      if epoch_model[0]:  # existing model
        print("warning: there is an existing model: %s" % (epoch_model,), file=log.v4)
        epoch_model = (None, None)
    elif (start_epoch or 0) > 1:
      if epoch_model[0]:
        prev_epoch = start_epoch - 1
        if epoch_model[0] != prev_epoch:
          print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
        # Fail with a descriptive error instead of a bare KeyError(prev_epoch):
        # the warning above goes to the log only and might not be visible.
        if prev_epoch not in existing_models:
          raise KeyError(
            "start_epoch %i requires an existing model of epoch %i, but existing model epochs are %r" % (
              start_epoch, prev_epoch, sorted(existing_models)))
        epoch_model = prev_epoch, existing_models[prev_epoch]

    cls._epoch_model = epoch_model
    return epoch_model