Ejemplo n.º 1
0
def make_lr_fn(base_learning_rate, steps_per_epoch):
    """Build the learning-rate schedule selected by ``FLAGS.lr_schedule``.

    Args:
        base_learning_rate: Base learning rate, forwarded unchanged to the
            selected ``lr_schedule`` factory.
        steps_per_epoch: Number of training steps per epoch, forwarded
            unchanged to the selected ``lr_schedule`` factory.

    Returns:
        The object returned by the matching
        ``lr_schedule.create_*_learning_rate_schedule`` factory.

    Raises:
        ValueError: If ``FLAGS.lr_schedule`` is not one of ``'constant'``,
            ``'stepped'`` or ``'cosine'``.
    """
    schedule_type = FLAGS.lr_schedule

    if schedule_type == 'constant':
        return lr_schedule.create_constant_learning_rate_schedule(
            base_learning_rate, steps_per_epoch)

    if schedule_type == 'stepped':
        if FLAGS.lr_sched_steps:
            # The flag carries a Python literal (e.g. a list of pairs) as text.
            lr_sched_steps = ast.literal_eval(FLAGS.lr_sched_steps)
        else:
            # Built-in default used when the flag is empty/unset.
            lr_sched_steps = [[60, 0.2], [120, 0.04], [160, 0.008]]

        return lr_schedule.create_stepped_learning_rate_schedule(
            base_learning_rate, steps_per_epoch, lr_sched_steps)

    if schedule_type == 'cosine':
        return lr_schedule.create_cosine_learning_rate_schedule(
            base_learning_rate, steps_per_epoch, FLAGS.num_epochs)

    # Fail loudly instead of silently returning None for a mistyped flag.
    raise ValueError('Unknown LR schedule type {}'.format(schedule_type))
Ejemplo n.º 2
0
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)

  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

  if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT:
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch)
  elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED:
    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, lr_schedule_steps)
  elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE:
    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, FLAGS.epochs)
  else:
    raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')

  # Reuses the FLAX learning rate schedule framework for pruning rate schedule.
  pruning_fn_p = functools.partial(
      lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate,
      steps_per_epoch)
  if FLAGS.pruning_schedule:
    pruning_schedule = ast.literal_eval(FLAGS.pruning_schedule)
    if isinstance(pruning_schedule, collections.Mapping):
      pruning_rate_fn = {
          f'MaskedModule_{layer_num}': pruning_fn_p(schedule)
Ejemplo n.º 3
0
def run_training():
    """Trains a model.

    All configuration is read from ``FLAGS``. The function creates the
    dataset, model, optimizer and learning-rate schedule, runs
    ``training.Trainer.train`` and logs the best metrics. On the host whose
    ``jax.host_id()`` is 0 it also writes those metrics through a TensorBoard
    ``SummaryWriter`` rooted at a fresh, UUID-named experiment directory.

    Raises:
        ValueError: If ``FLAGS.optimizer`` is neither ``'Adam'`` nor
            ``'Momentum'``, or if ``FLAGS.lr_schedule`` is not one of
            ``'constant'``, ``'stepped'`` or ``'cosine'``.
    """
    print('Logging to {}'.format(FLAGS.log_dir))
    work_unit_id = uuid.uuid4()
    experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

    logging.info('Saving experimental results to %s', experiment_dir)

    host_count = jax.host_count()
    local_device_count = jax.local_device_count()
    logging.info('Device count: %d, host count: %d, local device count: %d',
                 jax.device_count(), host_count, local_device_count)

    # Only host 0 owns a summary writer; every later use is guarded by this.
    is_primary_host = jax.host_id() == 0
    if is_primary_host:
        summary_writer = tensorboard.SummaryWriter(experiment_dir)

    dataset = dataset_factory.create_dataset(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.batch_size_test,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    logging.info('Training %s on the %s dataset...', FLAGS.model,
                 FLAGS.dataset)

    rng = jax.random.PRNGKey(FLAGS.random_seed)

    # A leading dimension of 1 is prepended to the dataset's example shape.
    input_shape = (1, ) + dataset.shape
    model, initial_state = model_factory.create_model(
        FLAGS.model,
        rng, ((input_shape, np.float32), ),
        num_classes=dataset.num_classes)

    if FLAGS.optimizer == 'Adam':
        optimizer = flax.optim.Adam(learning_rate=FLAGS.lr,
                                    weight_decay=FLAGS.weight_decay)
    elif FLAGS.optimizer == 'Momentum':
        optimizer = flax.optim.Momentum(learning_rate=FLAGS.lr,
                                        beta=FLAGS.momentum,
                                        weight_decay=FLAGS.weight_decay,
                                        nesterov=False)
    else:
        # Without this branch `optimizer` would be unbound and the failure
        # would surface later as an UnboundLocalError.
        raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))

    steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

    if FLAGS.lr_schedule == 'constant':
        lr_fn = lr_schedule.create_constant_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch)
    elif FLAGS.lr_schedule == 'stepped':
        # The flag carries a Python literal as text.
        lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
        lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, lr_schedule_steps)
    elif FLAGS.lr_schedule == 'cosine':
        lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, FLAGS.epochs)
    else:
        raise ValueError('Unknown LR schedule type {}'.format(
            FLAGS.lr_schedule))

    # Non-primary hosts omit the keyword so Trainer's own default applies.
    trainer_kwargs = {}
    if is_primary_host:
        trainer_kwargs['summary_writer'] = summary_writer
    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng,
                               **trainer_kwargs)

    _, best_metrics = trainer.train(FLAGS.epochs,
                                    lr_fn=lr_fn,
                                    update_iter=FLAGS.update_iterations,
                                    update_epoch=FLAGS.update_epoch)

    logging.info('Best metrics: %s', str(best_metrics))

    if is_primary_host:
        # Best metrics are recorded at the step index of the end of training.
        final_step = FLAGS.epochs * steps_per_epoch
        for label, value in best_metrics.items():
            summary_writer.scalar('best/{}'.format(label), value, final_step)
        summary_writer.close()
Ejemplo n.º 4
0
 def make_clf_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
     """Return a stepped LR schedule for the classifier with fixed steps.

     Args:
         base_lr: Base learning rate, forwarded to the schedule factory.
         steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

     Returns:
         The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
     """
     # Presumably [epoch, LR scale factor] pairs - confirm against lr_schedule.
     clf_schedule_steps = [[60, 0.2], [75, 0.04], [90, 0.008]]
     return lr_schedule.create_stepped_learning_rate_schedule(
         base_lr, steps_per_epoch, clf_schedule_steps)
Ejemplo n.º 5
0
 def make_moco_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
     """Return a stepped LR schedule built from a fixed two-entry step list.

     Args:
         base_lr: Base learning rate, forwarded to the schedule factory.
         steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

     Returns:
         The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
     """
     # Presumably [epoch, LR scale factor] pairs - confirm against lr_schedule.
     moco_schedule_steps = [[120, 0.1], [160, 0.01]]
     return lr_schedule.create_stepped_learning_rate_schedule(
         base_lr, steps_per_epoch, moco_schedule_steps)
Ejemplo n.º 6
0
 def make_lr_fun(base_lr, steps_per_epoch):
     """Return a stepped LR schedule using ``lr_sched_steps`` from outer scope.

     Args:
         base_lr: Base learning rate, forwarded to the schedule factory.
         steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

     Returns:
         The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
     """
     make_schedule = lr_schedule.create_stepped_learning_rate_schedule
     return make_schedule(base_lr, steps_per_epoch, lr_sched_steps)
Ejemplo n.º 7
0
 def make_clf_lr_fun(base_lr, steps_per_epoch):
     """Return a stepped LR schedule for the classifier, with warmup.

     The step list ``lr_clf_sched_steps`` comes from the enclosing scope and
     the warmup length from ``FLAGS.lr_clf_sched_warmup``.

     Args:
         base_lr: Base learning rate, forwarded to the schedule factory.
         steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

     Returns:
         The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
     """
     make_schedule = lr_schedule.create_stepped_learning_rate_schedule
     return make_schedule(
         base_lr, steps_per_epoch, lr_clf_sched_steps,
         warmup_length=FLAGS.lr_clf_sched_warmup)
Ejemplo n.º 8
0
 def make_lr_fn(base_lr, steps_per_epoch):
   """Return a stepped LR schedule with warmup taken from ``config``.

   The step list ``lr_sched_steps`` comes from the enclosing scope and the
   warmup length from ``config.warmup_epochs``.

   Args:
     base_lr: Base learning rate, forwarded to the schedule factory.
     steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

   Returns:
     The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
   """
   make_schedule = lr_schedule.create_stepped_learning_rate_schedule
   return make_schedule(
       base_lr, steps_per_epoch, lr_sched_steps,
       warmup_length=config.warmup_epochs)
Ejemplo n.º 9
0
 def make_step_size_fn(steps_per_epoch):
   """Return a stepped schedule whose base value is fixed at 1.0.

   The step list ``lr_sched_steps`` comes from the enclosing scope and the
   warmup length from ``config.warmup_epochs``.

   Args:
     steps_per_epoch: Steps per epoch, forwarded to the schedule factory.

   Returns:
     The result of ``lr_schedule.create_stepped_learning_rate_schedule``.
   """
   # NOTE(review): a base of 1.0 presumably makes the schedule a pure
   # multiplier for a separately supplied step size - confirm with callers.
   unit_base = 1.0
   make_schedule = lr_schedule.create_stepped_learning_rate_schedule
   return make_schedule(
       unit_base, steps_per_epoch, lr_sched_steps,
       warmup_length=config.warmup_epochs)