Code example #1
0
File: EngineBase.py Project: rwth-i6/returnn
  def get_epoch_model(cls, config):
    """
    Determine the model checkpoint to start from, together with its epoch.

    Considers the ``start_epoch``, ``load``, ``load_epoch``, ``import_model_train_epoch1`` and
    ``task`` config options, and the models returned by ``cls.get_existing_models(config)``
    (a mapping epoch -> model filename). The result is cached in ``cls._epoch_model``.

    :type config: Config.Config
    :returns (epoch, modelFilename)
    :rtype: (int|None, str|None)
    """
    # XXX: We cache it, although this is wrong if we have changed the config.
    # Any non-empty tuple is truthy, so even a cached (None, None) is returned here.
    if cls._epoch_model:
      return cls._epoch_model

    start_epoch_mode = config.value('start_epoch', 'auto')
    if start_epoch_mode == 'auto':
      start_epoch = None
    else:
      start_epoch = int(start_epoch_mode)
      assert start_epoch >= 1

    load_model_epoch_filename = config.value('load', '')
    if load_model_epoch_filename:
      assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())

    import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
    if import_model_train_epoch1:
      assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())

    existing_models = cls.get_existing_models(config)
    if not load_model_epoch_filename:
      # 'load_epoch' picks one of the existing models by its epoch number.
      # Epochs start at 1, so a non-positive value (e.g. an explicit load_epoch = 0)
      # is treated as "not set" instead of failing the membership assertion below.
      load_epoch = config.int("load_epoch", -1)
      if load_epoch > 0:
        assert load_epoch in existing_models
        load_model_epoch_filename = existing_models[load_epoch]
        assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch

    # Only use this when we don't train.
    # For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
    # In all other cases, we use the model specified by 'load'.
    if load_model_epoch_filename and (config.value('task', 'train') != 'train' or start_epoch is not None):
      epoch = model_epoch_from_filename(load_model_epoch_filename)
      if config.value('task', 'train') == 'train' and start_epoch is not None:
        # Ignore the epoch. To keep it consistent with the case below.
        epoch = None
      epoch_model = (epoch, load_model_epoch_filename)

    # In case of training, always first consider existing models.
    # This is because if we rerun RETURNN training, we usually don't want to train from scratch
    # but resume where we stopped last time.
    elif existing_models:
      epoch_model = sorted(existing_models.items())[-1]
      if load_model_epoch_filename:
        print("note: there is a 'load' which we ignore because of existing model", file=log.v4)

    elif config.value('task', 'train') == 'train' and import_model_train_epoch1 and start_epoch in [None, 1]:
      epoch_model = (0, import_model_train_epoch1)

    # Now, consider this also in the case when we train, as an initial model import.
    elif load_model_epoch_filename:
      # Don't use the model epoch as the start epoch in training.
      # We use this as an import for training.
      epoch_model = (model_epoch_from_filename(load_model_epoch_filename), load_model_epoch_filename)

    else:
      epoch_model = (None, None)

    if start_epoch == 1:
      if epoch_model[0]:  # existing model
        print("warning: there is an existing model: %s" % (epoch_model,), file=log.v4)
        epoch_model = (None, None)
    elif (start_epoch or 0) > 1:  # start_epoch is None in 'auto' mode
      if epoch_model[0]:
        if epoch_model[0] != start_epoch - 1:
          print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
        # Continue from the model of the epoch right before start_epoch.
        # Raises KeyError if there is no existing model for epoch start_epoch - 1.
        epoch_model = start_epoch - 1, existing_models[start_epoch - 1]

    cls._epoch_model = epoch_model
    return epoch_model
Code example #2
0
    def get_epoch_model(cls, config):
        """
        Determine the model checkpoint to start from, together with its epoch.

        Considers the ``start_epoch``, ``load``, ``import_model_train_epoch1`` and ``task``
        config options and the models returned by ``cls.get_existing_models(config)``.
        The result is cached in ``cls._epoch_model`` after the first call.

        :type config: Config.Config
        :returns (epoch, modelFilename)
        :rtype: (int|None, str|None)
        """
        # XXX: We cache it, although this is wrong if we have changed the config.
        # Any non-empty tuple is truthy, so even a cached (None, None) is returned here.
        if cls._epoch_model:
            return cls._epoch_model

        start_epoch_mode = config.value('start_epoch', 'auto')
        if start_epoch_mode == 'auto':
            start_epoch = None
        else:
            start_epoch = int(start_epoch_mode)
            assert start_epoch >= 1

        load_model_epoch_filename = config.value('load', '')
        import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
        if load_model_epoch_filename or import_model_train_epoch1:
            # Both options name a model in the same way, so validate both with the
            # same file postfix (".meta" when TensorFlow is the selected backend).
            fn_postfix = ""
            if BackendEngine.is_tensorflow_selected():
                fn_postfix = ".meta"
            if load_model_epoch_filename:
                assert os.path.exists(load_model_epoch_filename + fn_postfix)
            if import_model_train_epoch1:
                assert os.path.exists(import_model_train_epoch1 + fn_postfix)

        # The existing models are (epoch, filename) pairs -- presumably a list sorted by
        # epoch. Look them up by epoch number, not by list position. dict() also accepts
        # a mapping epoch -> filename.
        models_by_epoch = dict(cls.get_existing_models(config))

        is_train = config.value('task', 'train') == 'train'

        # Only use this when we don't train.
        # For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
        # In all other cases, we use the model specified by 'load'.
        if load_model_epoch_filename and (not is_train or start_epoch is not None):
            epoch = model_epoch_from_filename(load_model_epoch_filename)
            if is_train and start_epoch is not None:
                # Ignore the epoch. To keep it consistent with the case below.
                epoch = None
            epoch_model = (epoch, load_model_epoch_filename)

        # In case of training, always first consider existing models.
        # This is because if we rerun CRNN training, we usually don't want to train from scratch
        # but resume where we stopped last time.
        elif models_by_epoch:
            epoch_model = sorted(models_by_epoch.items())[-1]
            if load_model_epoch_filename:
                print("note: there is a 'load' which we ignore because of existing model", file=log.v4)

        elif is_train and import_model_train_epoch1 and start_epoch in [None, 1]:
            epoch_model = (0, import_model_train_epoch1)

        # Now, consider this also in the case when we train, as an initial model import.
        elif load_model_epoch_filename:
            # Don't use the model epoch as the start epoch in training.
            # We use this as an import for training.
            epoch_model = (
                model_epoch_from_filename(load_model_epoch_filename),
                load_model_epoch_filename)

        else:
            epoch_model = (None, None)

        if start_epoch == 1:
            if epoch_model[0]:  # existing model
                print("warning: there is an existing model: %s" % (epoch_model, ), file=log.v4)
                epoch_model = (None, None)
        # start_epoch is None in 'auto' mode; `None > 1` would raise TypeError on Python 3.
        elif (start_epoch or 0) > 1:
            if epoch_model[0]:
                if epoch_model[0] != start_epoch - 1:
                    print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
                # Continue from the model of epoch start_epoch - 1 (looked up by epoch number).
                epoch_model = (start_epoch - 1, models_by_epoch[start_epoch - 1])

        cls._epoch_model = epoch_model
        return epoch_model
Code example #3
0
File: EngineBase.py Project: wj199031738/returnn
  def get_epoch_model(cls, config):
    """
    Work out which checkpoint to start from, together with its epoch number.

    Looks at the ``start_epoch`` option ('auto' or an int >= 1), ``load`` / ``load_epoch``,
    ``import_model_train_epoch1``, ``task``, and the models from ``cls.get_existing_models(config)``
    (a mapping epoch -> filename). A trailing ".meta" on ``load`` and
    ``import_model_train_epoch1`` is stripped before use.

    :type config: Config.Config
    :returns (epoch, modelFilename)
    :rtype: (int|None, str|None)
    """
    meta_suffix = ".meta"

    def without_meta_suffix(path):
      # Accept both "xxx.meta" and plain "xxx".
      return path[:-len(meta_suffix)] if path.endswith(meta_suffix) else path

    requested_start = config.value('start_epoch', 'auto')
    start_epoch = None
    if requested_start != 'auto':
      start_epoch = int(requested_start)
      assert start_epoch >= 1

    load_path = without_meta_suffix(config.value('load', ''))
    if load_path:
      assert os.path.exists(load_path + get_model_filename_postfix())

    import_path = without_meta_suffix(config.value('import_model_train_epoch1', ''))
    if import_path:
      assert os.path.exists(import_path + get_model_filename_postfix())

    existing_models = cls.get_existing_models(config)
    if not load_path:
      wanted_epoch = config.int("load_epoch", -1)
      if wanted_epoch > 0:  # 0 (or the -1 default) means "not set"
        assert wanted_epoch in existing_models
        load_path = existing_models[wanted_epoch]
        assert model_epoch_from_filename(load_path) == wanted_epoch

    training = config.value('task', 'train') == 'train'

    def pick_candidate():
      # Non-train task, or training with an explicit start_epoch: 'load' takes precedence.
      if load_path and (not training or start_epoch is not None):
        load_file_epoch = model_epoch_from_filename(load_path)
        if training and start_epoch is not None:
          # Drop the epoch, matching the import-for-training case further down.
          load_file_epoch = None
        return load_file_epoch, load_path
      # Otherwise prefer resuming from the newest existing model.
      if existing_models:
        newest = sorted(existing_models.items())[-1]
        if load_path:
          print("note: there is a 'load' which we ignore because of existing model", file=log.v4)
        return newest
      if training and import_path and start_epoch in [None, 1]:
        return 0, import_path
      # Last resort: 'load' serves as an initial model import.
      if load_path:
        return model_epoch_from_filename(load_path), load_path
      return None, None

    epoch_model = pick_candidate()
    found_epoch = epoch_model[0]
    if start_epoch is not None and found_epoch:
      if start_epoch == 1:
        print("warning: there is an existing model: %s" % (epoch_model,), file=log.v4)
        epoch_model = (None, None)
      elif start_epoch > 1:
        if found_epoch != start_epoch - 1:
          print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
        epoch_model = start_epoch - 1, existing_models[start_epoch - 1]
    return epoch_model