Example #1
0
def add_trainer(
    experiment: cfg.ExperimentConfig,
    train_batch_size: int,
    eval_batch_size: int,
    learning_rate: float = 0.0001,
    train_epochs: int = 50,
    num_train_examples: int = YT8M_TRAIN_EXAMPLES,
    num_val_examples: int = YT8M_VAL_EXAMPLES,
):
    """Add and config a trainer to the experiment config.

    Mutates ``experiment`` in place: sets the global batch sizes on the
    train/validation data configs and installs a ``TrainerConfig`` with an
    Adam optimizer, exponentially decaying learning rate, and linear warmup.

    Args:
        experiment: The experiment config to update (mutated in place).
        train_batch_size: Global batch size for training.
        eval_batch_size: Global batch size for evaluation.
        learning_rate: Initial learning rate of the exponential schedule.
        train_epochs: Number of epochs to train for.
        num_train_examples: Size of the training dataset; must be positive.
        num_val_examples: Size of the validation dataset; must be positive.

    Returns:
        The same (mutated) experiment config.

    Raises:
        ValueError: If either dataset size is not positive.
    """
    # Report the size that actually failed validation (the parameter), not
    # just the data config — the config may not reflect the checked value.
    if num_train_examples <= 0:
        raise ValueError('Wrong train dataset size {!r} for {!r}'.format(
            num_train_examples, experiment.task.train_data))
    if num_val_examples <= 0:
        raise ValueError('Wrong validation dataset size {!r} for {!r}'.format(
            num_val_examples, experiment.task.validation_data))
    experiment.task.train_data.global_batch_size = train_batch_size
    experiment.task.validation_data.global_batch_size = eval_batch_size
    steps_per_epoch = num_train_examples // train_batch_size
    # Fixed loop length; summaries, checkpoints and validation all run on
    # the same 500-step cadence.
    steps_per_loop = 500
    experiment.trainer = cfg.TrainerConfig(
        steps_per_loop=steps_per_loop,
        summary_interval=steps_per_loop,
        checkpoint_interval=steps_per_loop,
        train_steps=train_epochs * steps_per_epoch,
        validation_steps=num_val_examples // eval_batch_size,
        validation_interval=steps_per_loop,
        optimizer_config=optimization.OptimizationConfig({
            'optimizer': {
                'type': 'adam',
                'adam': {}
            },
            'learning_rate': {
                'type': 'exponential',
                'exponential': {
                    'initial_learning_rate': learning_rate,
                    'decay_rate': 0.95,
                    # Decay every ~1.5 epochs, starting after the warmup.
                    'decay_steps': int(steps_per_epoch * 1.5),
                    'offset': 500,
                }
            },
            'warmup': {
                'linear': {
                    'name': 'linear',
                    'warmup_learning_rate': 0,
                    'warmup_steps': 500,
                },
                'type': 'linear',
            }
        }))
    return experiment
def add_trainer(experiment: cfg.ExperimentConfig,
                train_batch_size: int,
                eval_batch_size: int,
                learning_rate: float = 1.6,
                train_epochs: int = 44,
                warmup_epochs: int = 5):
  """Add and config a trainer to the experiment config.

  Mutates `experiment` in place: sets the global batch sizes and installs
  a trainer running SGD with Nesterov momentum under a cosine-decay
  learning-rate schedule with linear warmup. All per-loop intervals
  (summary, checkpoint, validation) are one epoch long.
  """
  train_data = experiment.task.train_data
  val_data = experiment.task.validation_data
  if train_data.num_examples <= 0:
    raise ValueError('Wrong train dataset size {!r}'.format(train_data))
  if val_data.num_examples <= 0:
    raise ValueError('Wrong validation dataset size {!r}'.format(val_data))

  train_data.global_batch_size = train_batch_size
  val_data.global_batch_size = eval_batch_size

  epoch_steps = train_data.num_examples // train_batch_size
  total_train_steps = train_epochs * epoch_steps
  eval_steps = val_data.num_examples // eval_batch_size

  # Optimizer spec built separately for readability.
  opt_spec = {
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9,
              'nesterov': True,
          }
      },
      'learning_rate': {
          'type': 'cosine',
          'cosine': {
              'initial_learning_rate': learning_rate,
              'decay_steps': total_train_steps,
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': warmup_epochs * epoch_steps,
              'warmup_learning_rate': 0
          }
      }
  }

  experiment.trainer = cfg.TrainerConfig(
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      train_steps=total_train_steps,
      validation_steps=eval_steps,
      validation_interval=epoch_steps,
      optimizer_config=optimization.OptimizationConfig(opt_spec))
  return experiment
Example #3
0
def add_trainer(
    experiment: cfg.ExperimentConfig,
    train_batch_size: int,
    eval_batch_size: int,
    learning_rate: float = 0.005,
    train_epochs: int = 44,
    num_train_examples: int = YT8M_TRAIN_EXAMPLES,
    num_val_examples: int = YT8M_VAL_EXAMPLES,
):
    """Add and config a trainer to the experiment config.

    Mutates ``experiment`` in place: sets the global batch sizes and
    installs a trainer using Adam with an exponentially decaying learning
    rate. Intervals (summary, checkpoint, validation) are one epoch long.

    Args:
        experiment: The experiment config to update (mutated in place).
        train_batch_size: Global batch size for training.
        eval_batch_size: Global batch size for evaluation.
        learning_rate: Initial learning rate of the exponential schedule.
        train_epochs: Number of epochs to train for.
        num_train_examples: Size of the training dataset; must be positive.
            Defaults to the YT8M training set size.
        num_val_examples: Size of the validation dataset; must be positive.
            Defaults to the YT8M validation set size.

    Returns:
        The same (mutated) experiment config.

    Raises:
        ValueError: If either dataset size is not positive.
    """
    # Validate the (parameterized) dataset sizes rather than the module
    # constants directly — checking the constants was a no-op once they were
    # set, and callers can now override the sizes for other datasets.
    if num_train_examples <= 0:
        raise ValueError('Wrong train dataset size {!r}'.format(
            experiment.task.train_data))
    if num_val_examples <= 0:
        raise ValueError('Wrong validation dataset size {!r}'.format(
            experiment.task.validation_data))
    experiment.task.train_data.global_batch_size = train_batch_size
    experiment.task.validation_data.global_batch_size = eval_batch_size
    steps_per_epoch = num_train_examples // train_batch_size
    experiment.trainer = cfg.TrainerConfig(
        steps_per_loop=steps_per_epoch,
        summary_interval=steps_per_epoch,
        checkpoint_interval=steps_per_epoch,
        train_steps=train_epochs * steps_per_epoch,
        validation_steps=num_val_examples // eval_batch_size,
        validation_interval=steps_per_epoch,
        optimizer_config=optimization.OptimizationConfig({
            'optimizer': {
                'type': 'adam',
                'adam': {}
            },
            'learning_rate': {
                'type': 'exponential',
                'exponential': {
                    'initial_learning_rate': learning_rate,
                    'decay_rate': 0.95,
                    # NOTE(review): fixed decay horizon independent of the
                    # dataset/batch size — confirm this is intentional.
                    'decay_steps': 1500000,
                }
            },
        }))
    return experiment