def make_lr_fn(base_learning_rate, steps_per_epoch):
    """Builds the learning-rate schedule selected by ``FLAGS.lr_schedule``.

    Args:
        base_learning_rate: Base learning rate forwarded to the schedule
            factory.
        steps_per_epoch: Steps per epoch forwarded to the schedule factory.

    Returns:
        The object returned by the matching ``lr_schedule.create_*`` factory.

    Raises:
        ValueError: If ``FLAGS.lr_schedule`` is not one of ``'constant'``,
            ``'stepped'`` or ``'cosine'``.
    """
    schedule_name = FLAGS.lr_schedule

    if schedule_name == 'constant':
        return lr_schedule.create_constant_learning_rate_schedule(
            base_learning_rate, steps_per_epoch)

    if schedule_name == 'stepped':
        if FLAGS.lr_sched_steps:
            # The flag holds a Python literal; literal_eval parses it without
            # executing arbitrary code.
            lr_sched_steps = ast.literal_eval(FLAGS.lr_sched_steps)
        else:
            # Fallback used when the flag is unset.
            # NOTE(review): presumably [epoch, lr multiplier] pairs -- confirm
            # against lr_schedule.create_stepped_learning_rate_schedule.
            lr_sched_steps = [[60, 0.2], [120, 0.04], [160, 0.008]]
        return lr_schedule.create_stepped_learning_rate_schedule(
            base_learning_rate, steps_per_epoch, lr_sched_steps)

    if schedule_name == 'cosine':
        return lr_schedule.create_cosine_learning_rate_schedule(
            base_learning_rate, steps_per_epoch, FLAGS.num_epochs)

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an opaque "NoneType is not callable" error.
    raise ValueError('Unknown LR schedule type {}'.format(schedule_name))
masked_layer_indices=FLAGS.masked_layer_indices) if FLAGS.optimizer == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.optimizer == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT: lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED: lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE: lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}') # Reuses the FLAX learning rate schedule framework for pruning rate schedule. pruning_fn_p = functools.partial( lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate, steps_per_epoch) if FLAGS.pruning_schedule:
def run_training():
    """Trains a model.

    Builds the dataset, model, optimizer and learning-rate schedule from
    ``FLAGS``, trains for ``FLAGS.epochs`` epochs via ``training.Trainer`` and
    logs the best metrics. On host 0 a summary writer is created under a
    freshly generated experiment directory and the best metrics are written
    to it as scalars.

    Raises:
        ValueError: If ``FLAGS.optimizer`` or ``FLAGS.lr_schedule`` names an
            unsupported option.
    """
    print('Logging to {}'.format(FLAGS.log_dir))

    # A random id keeps the results of separate runs in separate directories.
    work_unit_id = uuid.uuid4()
    experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

    logging.info('Saving experimental results to %s', experiment_dir)

    host_count = jax.host_count()
    local_device_count = jax.local_device_count()
    logging.info('Device count: %d, host count: %d, local device count: %d',
                 jax.device_count(), host_count, local_device_count)

    # Only host 0 owns a summary writer.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(experiment_dir)

    dataset = dataset_factory.create_dataset(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.batch_size_test,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    logging.info('Training %s on the %s dataset...', FLAGS.model,
                 FLAGS.dataset)

    rng = jax.random.PRNGKey(FLAGS.random_seed)

    # Prepend a leading dimension of size 1 to the dataset's example shape.
    input_shape = (1, ) + dataset.shape
    model, initial_state = model_factory.create_model(
        FLAGS.model,
        rng, ((input_shape, np.float32), ),
        num_classes=dataset.num_classes)

    if FLAGS.optimizer == 'Adam':
        optimizer = flax.optim.Adam(learning_rate=FLAGS.lr,
                                    weight_decay=FLAGS.weight_decay)
    elif FLAGS.optimizer == 'Momentum':
        optimizer = flax.optim.Momentum(learning_rate=FLAGS.lr,
                                        beta=FLAGS.momentum,
                                        weight_decay=FLAGS.weight_decay,
                                        nesterov=False)
    else:
        # Without this branch an unsupported value left `optimizer` unbound
        # and surfaced later as a NameError when building the trainer.
        raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))

    steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

    if FLAGS.lr_schedule == 'constant':
        lr_fn = lr_schedule.create_constant_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch)
    elif FLAGS.lr_schedule == 'stepped':
        # The flag holds a Python literal describing the schedule steps.
        lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
        lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, lr_schedule_steps)
    elif FLAGS.lr_schedule == 'cosine':
        lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, FLAGS.epochs)
    else:
        raise ValueError('Unknown LR schedule type {}'.format(
            FLAGS.lr_schedule))

    if jax.host_id() == 0:
        trainer = training.Trainer(
            optimizer,
            model,
            initial_state,
            dataset,
            rng,
            summary_writer=summary_writer,
        )
    else:
        trainer = training.Trainer(optimizer, model, initial_state, dataset,
                                   rng)

    _, best_metrics = trainer.train(FLAGS.epochs,
                                    lr_fn=lr_fn,
                                    update_iter=FLAGS.update_iterations,
                                    update_epoch=FLAGS.update_epoch)

    logging.info('Best metrics: %s', str(best_metrics))

    if jax.host_id() == 0:
        # Record the best metrics at the final global step, then close the
        # writer.
        for label, value in best_metrics.items():
            summary_writer.scalar('best/{}'.format(label), value,
                                  FLAGS.epochs * steps_per_epoch)
        summary_writer.close()
def make_lr_fun(base_lr, steps_per_epoch):
    """Creates a constant learning-rate schedule.

    Args:
        base_lr: Learning rate forwarded to the schedule factory.
        steps_per_epoch: Steps per epoch forwarded to the schedule factory.

    Returns:
        The result of
        ``lr_schedule.create_constant_learning_rate_schedule``.
    """
    build_schedule = lr_schedule.create_constant_learning_rate_schedule
    constant_schedule = build_schedule(base_lr, steps_per_epoch)
    return constant_schedule
def make_lr_fn(base_lr, steps_per_epoch):
    """Creates a constant learning-rate schedule with warmup.

    Args:
        base_lr: Learning rate forwarded to the schedule factory.
        steps_per_epoch: Steps per epoch forwarded to the schedule factory.

    Returns:
        The result of ``lr_schedule.create_constant_learning_rate_schedule``
        called with ``warmup_length=config.warmup_epochs``.
    """
    build_schedule = lr_schedule.create_constant_learning_rate_schedule
    schedule = build_schedule(
        base_lr,
        steps_per_epoch,
        warmup_length=config.warmup_epochs,
    )
    return schedule
def make_step_size_fn(steps_per_epoch):
    """Creates a constant schedule with a base value of 1.0 and warmup.

    Args:
        steps_per_epoch: Steps per epoch forwarded to the schedule factory.

    Returns:
        The result of ``lr_schedule.create_constant_learning_rate_schedule``
        called with a base value of ``1.0`` and
        ``warmup_length=config.warmup_epochs``.
    """
    build_schedule = lr_schedule.create_constant_learning_rate_schedule
    step_size_schedule = build_schedule(
        1.0,
        steps_per_epoch,
        warmup_length=config.warmup_epochs,
    )
    return step_size_schedule