def build_optimizer(self, clip=15.0, lr=5e-4, warmup=2000, cosine_decay_steps=None,
                    optimizer_name="adabelief") -> GradientTransformation:
    """Build an optax gradient-transformation chain.

    The chain is, in order: an adaptive scaling step (AdaBelief or Adam), a
    ``-lr`` scale (optionally via a linear-warmup schedule), an optional cosine
    decay multiplier, and an optional element-wise ``optax.clip``.

    Args:
      clip: Bound passed to ``optax.clip``; the clip is skipped when ``clip``
        is falsy or non-positive.
      lr: Learning rate (positive). It is applied as ``-lr`` so that
        ``optax.apply_updates`` minimizes the loss.
      warmup: Warmup length passed to ``util.linear_warmup_lr_schedule``; when
        falsy or non-positive a constant ``optax.scale(-lr)`` is used instead.
      cosine_decay_steps: When positive, the updates are additionally multiplied
        by a cosine schedule going from 1.0 to 0.1 (``alpha=1e-1``) over this
        many steps.
      optimizer_name: ``"adabelief"`` or ``"adam"``.

    Returns:
      ``optax.chain`` of the configured transformations.

    Raises:
      ValueError: If ``optimizer_name`` is not recognised.
    """
    if optimizer_name == "adabelief":
        chain = [util.scale_by_belief()]
    elif optimizer_name == "adam":
        chain = [optax.scale_by_adam()]
    else:
        # Previously `assert 0`: stripped under `python -O`, which then silently
        # produced a chain without any adaptive scaling step.
        raise ValueError(
            f"Unknown optimizer_name {optimizer_name!r}; "
            "expected 'adabelief' or 'adam'.")

    # Make sure to use the negative learning rate so that we minimize.
    if warmup and warmup > 0:
        warmup_schedule = partial(util.linear_warmup_lr_schedule, warmup=warmup,
                                  lr_decay=1.0, lr=-lr)
        chain.append(optax.scale_by_schedule(warmup_schedule))
    else:
        chain.append(optax.scale(-lr))

    if cosine_decay_steps and cosine_decay_steps > 0:
        # Multiplicative factor on top of the lr scale above.
        cosine_lr = optax.cosine_decay_schedule(
            init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
        chain.append(optax.scale_by_schedule(cosine_lr))

    if clip and clip > 0:
        # NOTE(review): the clip comes after the learning-rate scaling, so it
        # bounds the final per-element update rather than the raw gradient —
        # confirm this is intended.
        chain.append(optax.clip(clip))
    return optax.chain(*chain)
def create_learning_rate_fn(config,):
    """Create learning rate schedule.

    Linear warmup from 0 to ``config.train.lr_init`` over
    ``config.train.warmup_steps`` steps, followed by a decay phase selected by
    ``config.train.scheduler``:

    * ``"linear"``: linear decay from ``lr_init`` to 0 over
      ``max_steps - warmup_steps`` steps.
    * ``"cosine"``: cosine decay from ``lr_init`` over
      ``max(max_steps - warmup_steps, 1)`` steps.
    * ``"step"``: ``lr_init`` halved every 50000 steps of the decay phase.

    Returns:
      An optax schedule (step -> learning rate).

    Raises:
      NotImplementedError: If ``config.train.scheduler`` is not one of the above.
    """
    # Linear warmup
    warmup_fn = optax.linear_schedule(
        init_value=0.,
        end_value=config.train.lr_init,
        transition_steps=config.train.warmup_steps)

    scheduler = config.train.scheduler
    if scheduler == "linear":
        decay_fn = optax.linear_schedule(
            init_value=config.train.lr_init,
            end_value=0.,
            transition_steps=config.train.max_steps - config.train.warmup_steps)
    elif scheduler == "cosine":
        cosine_steps = max(config.train.max_steps - config.train.warmup_steps, 1)
        decay_fn = optax.cosine_decay_schedule(
            init_value=config.train.lr_init, decay_steps=cosine_steps)
    elif scheduler == "step":
        def decay_fn(count):
            # optax.join_schedules evaluates this with the step count measured
            # from the warmup boundary.
            return config.train.lr_init * (0.5**(count // 50000))
    else:
        raise NotImplementedError(
            f"Unsupported learning-rate scheduler: {scheduler!r}")

    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[config.train.warmup_steps])
    return schedule_fn
def get_learning_rate_schedule(
        total_batch_size, steps_per_epoch, total_steps, optimizer_config):
    """Build the learning rate schedule function.

    Args:
      total_batch_size: Global batch size, forwarded to ``_get_batch_scaled_lr``.
      steps_per_epoch: Steps per epoch; converts cosine warmup epochs to steps.
      total_steps: Total number of training steps.
      optimizer_config: Config providing ``base_lr``, ``scale_by_batch``,
        ``schedule_type`` and the per-schedule ``*_kwargs`` sub-configs.

    Returns:
      An optax schedule (step -> learning rate).

    Raises:
      ValueError: If ``schedule_type`` is not 'steps', 'cosine' or
        'constant_cosine'.
    """
    base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr,
                                   optimizer_config.scale_by_batch)

    schedule_type = optimizer_config.schedule_type
    if schedule_type == 'steps':
        step_kwargs = optimizer_config.step_decay_kwargs
        decay_rate = step_kwargs.decay_rate
        boundaries_and_scales = {}
        # sorted() instead of list.sort(): does not mutate the caller's config
        # and also accepts boundaries stored as an immutable tuple.
        for boundary in sorted(step_kwargs.decay_boundaries):
            step = int(boundary * total_steps)
            # If two fractions round to the same step, compound their decays
            # instead of letting one dict entry overwrite the other.
            boundaries_and_scales[step] = (
                boundaries_and_scales.get(step, 1.0) * decay_rate)
        schedule_fn = optax.piecewise_constant_schedule(
            init_value=base_lr, boundaries_and_scales=boundaries_and_scales)
    elif schedule_type == 'cosine':
        cosine_kwargs = optimizer_config.cosine_decay_kwargs
        warmup_steps = cosine_kwargs.warmup_epochs * steps_per_epoch
        # Batch scale the other lr values as well:
        init_value = _get_batch_scaled_lr(
            total_batch_size, cosine_kwargs.init_value,
            optimizer_config.scale_by_batch)
        end_value = _get_batch_scaled_lr(
            total_batch_size, cosine_kwargs.end_value,
            optimizer_config.scale_by_batch)
        schedule_fn = optax.warmup_cosine_decay_schedule(
            init_value=init_value,
            peak_value=base_lr,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            end_value=end_value)
    elif schedule_type == 'constant_cosine':
        cc_kwargs = optimizer_config.constant_cosine_decay_kwargs
        # Convert end_value to alpha, used by cosine_decay_schedule. When
        # base_lr is 0 the whole schedule is 0, so alpha is irrelevant; guard
        # the division instead of raising ZeroDivisionError.
        # NOTE(review): unlike the 'cosine' branch, end_value is not
        # batch-scaled here — confirm this is intended.
        alpha = cc_kwargs.end_value / base_lr if base_lr else 0.0
        # Number of steps spent in constant phase.
        constant_steps = int(cc_kwargs.constant_fraction * total_steps)
        # optax.cosine_decay_schedule requires positive decay_steps; clamp so
        # constant_fraction == 1.0 does not crash.
        decay_steps = max(total_steps - constant_steps, 1)
        constant_phase = optax.constant_schedule(value=base_lr)
        decay_phase = optax.cosine_decay_schedule(
            init_value=base_lr,
            decay_steps=decay_steps,
            alpha=alpha)
        schedule_fn = optax.join_schedules(
            schedules=[constant_phase, decay_phase],
            boundaries=[constant_steps])
    else:
        raise ValueError(f'Unknown learning rate schedule: {schedule_type}')

    return schedule_fn
def create_learning_rate_fn(workload: spec.Workload,
                            hparams: spec.Hyperparameters):
    """Create learning rate schedule.

    Linear warmup from 0 to ``hparams.learning_rate`` over
    ``hparams.warmup_steps`` steps, then cosine decay from
    ``hparams.learning_rate`` over the remaining ``workload.step_hint`` steps
    (at least 1).

    Returns:
      An optax schedule (step -> learning rate).
    """
    warmup_fn = optax.linear_schedule(
        init_value=0.,
        end_value=hparams.learning_rate,
        transition_steps=hparams.warmup_steps)
    # optax.cosine_decay_schedule rejects non-positive decay_steps; clamp to 1
    # (as the other schedule builders do) so warmup_steps >= step_hint does
    # not crash.
    decay_steps = max(workload.step_hint - hparams.warmup_steps, 1)
    cosine_fn = optax.cosine_decay_schedule(
        init_value=hparams.learning_rate, decay_steps=decay_steps)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn], boundaries=[hparams.warmup_steps])
    return schedule_fn
def get_cosine_schedule(
        max_learning_rate: float, total_steps: int,
        warmup_steps: int = 0) -> optax.Schedule:
    """Builds a cosine decay schedule with initial warm-up.

    Linear warmup from 0 to ``max_learning_rate`` over ``warmup_steps`` steps,
    then cosine decay from ``max_learning_rate`` over
    ``total_steps - warmup_steps`` steps. If no steps remain after warmup, only
    the warmup schedule is returned.
    """
    warmup = optax.linear_schedule(init_value=0., end_value=max_learning_rate,
                                   transition_steps=warmup_steps)
    # `<=` rather than `<`: with total_steps == warmup_steps the cosine phase
    # would get decay_steps=0, which optax.cosine_decay_schedule rejects.
    if total_steps <= warmup_steps:
        return warmup
    return optax.join_schedules([
        warmup,
        optax.cosine_decay_schedule(init_value=max_learning_rate,
                                    decay_steps=total_steps - warmup_steps),
    ], [warmup_steps])
def create_learning_rate_fn(hparams: spec.Hyperparameters,
                            steps_per_epoch: int):
    """Create learning rate schedule.

    The peak rate is ``hparams.learning_rate`` scaled linearly by the batch
    size returned by ``get_batch_size('imagenet')`` relative to 256. The
    schedule warms up linearly from 0 over ``hparams.warmup_epochs`` epochs,
    then cosine-decays over ``max(num_epochs - warmup_epochs, 1)`` epochs.

    Returns:
      An optax schedule (step -> learning rate).
    """
    # The annotation previously read `spec.Hyperparamters` (typo); annotations
    # are evaluated at definition time unless postponed, so the typo could fail
    # on import. It now matches the `spec.Hyperparameters` name used elsewhere.
    base_learning_rate = (
        hparams.learning_rate * get_batch_size('imagenet') / 256.)
    warmup_steps = hparams.warmup_epochs * steps_per_epoch
    warmup_fn = optax.linear_schedule(
        init_value=0.,
        end_value=base_learning_rate,
        transition_steps=warmup_steps)
    cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)
    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[warmup_steps])
    return schedule_fn
def create_learning_rate_fn(config: ml_collections.ConfigDict,
                            base_learning_rate: float,
                            steps_per_epoch: int):
    """Return a linear-warmup then cosine-decay learning-rate schedule.

    The rate ramps linearly from 0 to ``base_learning_rate`` during the first
    ``config.warmup_epochs`` epochs. It then follows a cosine decay from
    ``base_learning_rate`` lasting ``max(config.num_epochs -
    config.warmup_epochs, 1)`` epochs. Epoch counts are converted to steps
    with ``steps_per_epoch``.
    """
    ramp_steps = config.warmup_epochs * steps_per_epoch
    anneal_steps = (max(config.num_epochs - config.warmup_epochs, 1)
                    * steps_per_epoch)

    ramp_up = optax.linear_schedule(
        init_value=0.,
        end_value=base_learning_rate,
        transition_steps=ramp_steps,
    )
    anneal = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=anneal_steps,
    )
    return optax.join_schedules(schedules=[ramp_up, anneal],
                                boundaries=[ramp_steps])
def build_optimizer(lr, momentum, steps_per_epoch, n_epochs, nesterov,
                    warmup_epochs=5):
    """SGD (optionally Nesterov) whose updates are scaled by a warmup/cosine factor.

    The per-step factor is the minimum of a linear ramp 0 -> 1 over
    ``warmup_epochs * steps_per_epoch`` steps and a cosine decay from 1 over
    ``n_epochs * steps_per_epoch`` steps (``alpha=1e-10``).
    """
    total_steps = n_epochs * steps_per_epoch
    ramp_steps = warmup_epochs * steps_per_epoch

    decay = optax.cosine_decay_schedule(1, decay_steps=total_steps, alpha=1e-10)
    ramp = optax.polynomial_schedule(
        init_value=0.0,
        end_value=1.0,
        power=1,
        transition_steps=ramp_steps,
    )

    def lr_factor(step):
        # Take the smaller of the two factors at every step.
        return jnp.minimum(decay(step), ramp(step))

    base = optax.sgd(lr, momentum, nesterov=nesterov)
    return optax.chain(base, optax.scale_by_schedule(lr_factor))
def test_loop_over_loader(self, variant):
    """Scan `loop_over_loader` over one shuffled epoch with AdamW + cosine decay.

    Asserts the resulting loss matches a previously recorded value.
    """
    x_train, _ = self.dataset.data['train']
    batch_size = 2
    epochs = 1

    # set up optimizer and lr scheduler
    scheduler = 'Cosine'
    optimizer_params = {'learning_rate': 1e-3, 'weight_decay': 1e-4}
    lr_mult = 1.0
    steps_per_epoch = float(x_train.shape[0] / batch_size)
    decay_steps = int((epochs + 1) * steps_per_epoch)
    # print() does not perform %-interpolation; the previous
    # `print('...%s', value)` printed the literal "%s" followed by the value.
    print(f'steps_per_epoch: {steps_per_epoch}')
    print(f'decay_steps: {decay_steps}')
    cosine_scheduler_fn = optax.cosine_decay_schedule(
        init_value=optimizer_params['learning_rate'], decay_steps=decay_steps)
    optimizer_params['learning_rate'] = cosine_scheduler_fn
    print(f'optimizer_params: {optimizer_params}')
    optim = optax.adamw(**optimizer_params)
    optim_state = optim.init(self.sim)

    loop_over_loader_partial = functools.partial(
        loop_over_loader, optim=optim, rollout_fn=rollout, scheduler=scheduler)
    variant_loop_over_loader = variant(loop_over_loader_partial)

    prng_key = jax.random.PRNGKey(0)
    x, y, prng_key = get_shuffled_and_batched_data(self.dataset, batch_size,
                                                   'train', prng_key)
    (self.sim, optim_state, lr_mult, loss), _ = jax.lax.scan(
        variant_loop_over_loader, (self.sim, optim_state, lr_mult, 0.), (x, y))
    print('test_loop_over_loader loss:' + str(loss))
    self.assertTrue(jnp.allclose(float(loss), 0.89504164))
def train_controller(
    controller,
    sim,
    pip_feed="parallel",  # or "sequential"
    mode="multipip",  # or "singular"
    duration=0.87,
    dt=0.03,
    epochs=100,
    use_noise=False,
    optimizer=optax.adamw,
    optimizer_params=None,
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    scheduler="Cosine",
    tensorboard_dir=None,
    model_parameters=None,  # used for tensorboard
    print_loss=1,
):
    """train controller.

    Trains ``controller`` by differentiating simulated rollouts through ``sim``:

    * ``pip_feed="parallel"``: one optimizer update per epoch, computed by
      ``rollout_parallel`` over all pips at once.
    * ``pip_feed="sequential"``: one optimizer update per pip per epoch,
      computed by ``rollout``.

    Every ``print_loss`` epochs the controller is scored with
    ``test_controller``; the score is printed and, if ``tensorboard_dir`` is
    set, written to a metric writer.

    Args:
      optimizer_params: kwargs for ``optimizer``; defaults to
        ``{"learning_rate": 1e-3, "weight_decay": 1e-4}``. Never mutated.
      model_parameters: hyperparameters recorded to tensorboard; defaults to
        ``{}``.

    Returns:
      ``(controller, per_step_loss, score)``: the trained controller, the last
      training loss divided by the number of time steps, and the last
      ``test_controller`` score. The last two are ``None`` when ``epochs == 0``.

    Raises:
      ValueError: If ``mode`` or ``pip_feed`` is not a supported value.
    """
    # None sentinels replace the previous mutable default dicts.
    if optimizer_params is None:
        optimizer_params = {"learning_rate": 1e-3, "weight_decay": 1e-4}
    if model_parameters is None:
        model_parameters = {}
    # Validate up front: previously bad values surfaced later as NameError /
    # UnboundLocalError on `pips` or `steps_per_epoch`.
    if mode not in ("multipip", "singular"):
        raise ValueError(f"Unknown mode: {mode!r}")
    if pip_feed not in ("parallel", "sequential"):
        raise ValueError(f"Unknown pip_feed: {pip_feed!r}")

    peep = 5
    if mode == "multipip":
        pips = [10, 15, 20, 25, 30, 35]
    else:
        pips = [35]

    # setup optimizer (on a copy, so the caller's dict is never modified)
    optim_params = copy.deepcopy(optimizer_params)
    if scheduler == "Cosine":
        steps_per_epoch = 1 if pip_feed == "parallel" else len(pips)
        decay_steps = int(epochs * steps_per_epoch)
        print("steps_per_epoch:" + str(steps_per_epoch))
        print("decay_steps:" + str(decay_steps))
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optim_params["learning_rate"], decay_steps=decay_steps)
        optim_params["learning_rate"] = cosine_scheduler_fn
        print("optim_params:" + str(optim_params))
    optim = optimizer(**optim_params)
    optim_state = optim.init(controller)

    # setup Tensorboard writer
    if tensorboard_dir is not None:
        trial_name = str(model_parameters)
        write_path = tensorboard_dir + trial_name
        summary_writer = metric_writers.create_default_writer(
            logdir=write_path, just_logging=jax.process_index() != 0)
        summary_writer.write_hparams(model_parameters)

    tt = jnp.linspace(0, duration, int(duration / dt))
    losses = []
    # Defined up front so epochs == 0 returns instead of raising.
    per_step_loss = None
    score = None
    for epoch in range(epochs):
        if pip_feed == "parallel":
            value, grad = jax.value_and_grad(rollout_parallel)(
                controller, sim, tt, use_noise, peep, jnp.array(pips), loss_fn)
            updates, optim_state = optim.update(grad, optim_state, controller)
            controller = optax.apply_updates(controller, updates)
            per_step_loss = value / len(tt)
            losses.append(per_step_loss)
            if epoch % print_loss == 0:
                # make new controller with trained parameters and normal clamp
                score = test_controller(controller, sim, pips, peep)
                print(f"Epoch: {epoch}\tLoss: {score:.2f}")
                if tensorboard_dir is not None:
                    summary_writer.write_scalars(epoch, {"score": score})
        elif pip_feed == "sequential":
            for pip in pips:
                value, grad = jax.value_and_grad(rollout)(
                    controller, sim, tt, use_noise, peep, pip, loss_fn,
                    jnp.array(0.))
                updates, optim_state = optim.update(grad, optim_state,
                                                    controller)
                controller = optax.apply_updates(controller, updates)
                per_step_loss = value / len(tt)
                losses.append(per_step_loss)
                if epoch % print_loss == 0:
                    # make new controller with trained parameters and normal
                    # clamp
                    score = test_controller(controller, sim, pips, peep)
                    print(f"Epoch: {epoch}, pip: {pip}\tLoss: {score:.2f}")
                    if tensorboard_dir is not None:
                        # NOTE(review): this logs the test score under the key
                        # "per_step_loss" (the parallel branch uses "score") —
                        # confirm the tag is intended.
                        summary_writer.write_scalars(
                            epoch, {"per_step_loss": score})

    if tensorboard_dir is not None:
        summary_writer.flush()
    return controller, per_step_loss, score
def train_simulator(
    dataset,
    model,
    num_boundary_models,
    activation_fn_name,
    R,
    C,
    # idx 0 to num_boundary_models-1 are boundary models,
    # idx num_boundary_models is default_model
    train_key="train",
    test_key="test",
    batch_size=512,
    epochs=500,
    optimizer=optax.adamw,
    optimizer_params=None,
    patience=10,
    lr_decay_factor=0.1,
    scheduler="ReduceLROnPlateau",  # or "Cosine"
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    print_loss=10,
    use_tensorboard=False,
    mode="train",
    user_name="alexjyu-brain",
    tb_dir=None,
):
    """train simulator.

    Runs ``epochs + 1`` epochs. Each epoch scans ``loop_over_loader`` over
    shuffled mini-batches of ``dataset.data[train_key]``, then evaluates the
    full train and test splits with ``map_rollout_over_batch``.

    Learning-rate control:

    * ``"ReduceLROnPlateau"``: ``lr_mult`` is multiplied by ``lr_decay_factor``
      after ``patience`` consecutive epochs whose loss increased.
    * ``"Cosine"``: ``optimizer_params["learning_rate"]`` is used as the initial
      value of an ``optax.cosine_decay_schedule``.

    Note: ``dataset.data["train"]`` and ``dataset.data["test"]`` are converted
    to jnp arrays in place.
    NOTE(review): ``loss_fn``, ``R``, ``C``, ``num_boundary_models`` and
    ``user_name`` are not used in this body — confirm whether ``loss_fn``
    should be forwarded to ``loop_over_loader``.

    Args:
      optimizer_params: kwargs for ``optimizer``; defaults to
        ``{"learning_rate": 1e-3, "weight_decay": 1e-4}``. Never mutated.

    Returns:
      ``(model, test_loss)``: the trained model and the last epoch's test loss.

    Raises:
      ValueError: If ``scheduler`` is not a supported value.
    """
    if optimizer_params is None:
        optimizer_params = {"learning_rate": 1e-3, "weight_decay": 1e-4}
    # Work on a copy: the Cosine branch replaces "learning_rate" with a
    # schedule, which previously mutated the caller's dict (and the shared
    # default), so a later Cosine call built a schedule around a schedule.
    optimizer_params = copy.deepcopy(optimizer_params)
    if scheduler not in ("ReduceLROnPlateau", "Cosine"):
        # Previously this surfaced as UnboundLocalError on `optim`.
        raise ValueError(f"Unknown scheduler: {scheduler!r}")

    # evaluate on these at end of epoch
    for key in ["train", "test"]:
        dataset.data[key] = (jnp.array(dataset.data[key][0]),
                             jnp.array(dataset.data[key][1]))
    X_train, y_train = dataset.data[train_key]
    X_test, y_test = dataset.data[test_key]

    # set up optimizer and lr scheduler
    lr_mult = 1.0
    if scheduler == "ReduceLROnPlateau":
        optim = optimizer(**optimizer_params)
        patience_cnt = 0
        prev_loss = float("inf")
    else:  # "Cosine"
        steps_per_epoch = float(X_train.shape[0] / batch_size)
        decay_steps = int((epochs + 1) * steps_per_epoch)
        logging.info("steps_per_epoch: %s", str(steps_per_epoch))
        logging.info("decay_steps: %s", str(decay_steps))
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optimizer_params["learning_rate"],
            decay_steps=decay_steps)
        optimizer_params["learning_rate"] = cosine_scheduler_fn
        logging.info("optimizer_params: %s", str(optimizer_params))
        optim = optimizer(**optimizer_params)
    optim_state = optim.init(model)
    loop_over_loader_partial = functools.partial(
        loop_over_loader, optim=optim, rollout_fn=rollout, scheduler=scheduler)

    # Tensorboard writer. It stays None when use_tensorboard is set but
    # mode != "train"; previously the writer was then undefined (NameError).
    summary_writer = None
    if use_tensorboard:
        config = copy.deepcopy(model.default_model_parameters)
        del config["activation_fn"]
        config["activation_fn_name"] = activation_fn_name
        if mode == "train":
            file_name = str(config)
            write_path = tb_dir + file_name
            # The CLU writer provides write_hparams/write_scalars/flush. It was
            # previously overwritten by `tensorboard.SummaryWriter(write_path)`,
            # whose API does not have those methods.
            summary_writer = metric_writers.create_default_writer(
                logdir=write_path, just_logging=jax.process_index() != 0)
            summary_writer.write_hparams(dict(config))

    # Main Training Loop
    prng_key = jax.random.PRNGKey(0)
    for epoch in range(epochs + 1):
        if epoch % 10 == 0:
            logging.info("epoch: %s", str(epoch))
        X, y, prng_key = get_shuffled_and_batched_data(dataset, batch_size,
                                                       train_key, prng_key)
        if epoch == 0:
            logging.info("X.shape: %s", str(X.shape))
            logging.info("y.shape: %s", str(y.shape))

        (model, optim_state, lr_mult, loss), _ = jax.lax.scan(
            loop_over_loader_partial, (model, optim_state, lr_mult, 0.), (X, y))

        if scheduler == "ReduceLROnPlateau":
            if loss > prev_loss:
                patience_cnt = patience_cnt + 1
            else:
                patience_cnt = 0
            if patience_cnt == patience:
                lr_mult = lr_mult * lr_decay_factor
                patience_cnt = 0
            prev_loss = loss

        if epoch % print_loss == 0:
            if scheduler == "ReduceLROnPlateau":
                logging.info("loss: %s", str(loss))
                logging.info("prev_loss: %s", str(prev_loss))
                logging.info("patience_cnt: %s", str(patience_cnt))
                logging.info("lr_mult: %s", str(lr_mult))

        # expensive end-of-epoch eval, just for intuition
        train_loss = map_rollout_over_batch(model, (X_train, y_train), rollout)
        # cross-validation
        test_loss = map_rollout_over_batch(model, (X_test, y_test), rollout)
        if epoch % print_loss == 0:
            logging.info(
                f"Epoch {epoch:2d}: train={train_loss.item():.5f}, test_loss={test_loss.item():.5f}"
            )
            logging.info("-----------------------------------")
        if summary_writer is not None:
            summary_writer.write_scalars(epoch, {"train_loss": train_loss})
            summary_writer.write_scalars(epoch, {"test_loss": test_loss})

    if summary_writer is not None:
        summary_writer.flush()
    logging.info("finished looping over epochs")
    return model, test_loss
def train(base_dir, config):
    """Train function.

    Builds the dataset, ``MLPEncoder`` and the module named by
    ``config.train.module.name``. Restores (or initialises) a checkpointed
    ``TrainState`` under ``base_dir / 'train'``. Then runs
    ``train_step_<config.train.method>`` until ``config.num_train_steps``,
    writing metrics every ``config.log_metrics_every`` steps. On SIGTERM a
    checkpoint is saved; a final checkpoint is saved after the loop.

    Returns:
      The final ``TrainState``.

    Raises:
      ValueError: If the configured module class or train method is not
        defined at module level.
    """
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))
    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')
    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    # Resolve names explicitly: globals().get() returns None for a typo, which
    # previously failed later with an opaque "'NoneType' is not callable".
    module_cls = globals().get(module_class)
    if module_cls is None:
        raise ValueError(f'Unknown module class: {module_class!r}')
    module = module_cls(encoder, **module_config)

    train_step_name = f'train_step_{train_method}'
    train_step = globals().get(train_step_name)
    if train_step is None:
        raise ValueError(f'Unknown train method {train_method!r} '
                         f'(no function named {train_step_name}).')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    # Kept as a chain so the optimizer-state structure (and therefore existing
    # checkpoints) is unchanged.
    optim = optax.chain(optax.adam(lr))

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        # Reads the enclosing `state` at signal time, i.e. the latest state.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
def evaluate(base_dir, config, *, train_state):
    """Eval function.

    Fits a ``LinearModule`` (``config.eval.num_tasks`` outputs) with Adam and a
    cosine-decay learning rate for ``config.num_eval_steps`` steps via
    ``evaluate_step``. ``train_state`` is passed through unchanged. The probe
    state is checkpointed under ``base_dir / 'eval'`` on SIGTERM. Finally, the
    probe is applied to the ``phi`` outputs of ``train_state`` on a
    ``dataset.EvalDataset`` batch, and the mean ``optax.l2_loss`` is written as
    ``test_loss``.
    """
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))
    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim,)))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply, params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix. This now runs
        # inside ensure_flushes: previously the test_loss write happened after
        # the context had exited, so it was never guaranteed to be flushed.
        # NOTE(review): ds_key is also the key passed to dataset.get_dataset
        # above — confirm reusing it for EvalDataset is intended.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(
                state.apply_fn, in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
def _create_jax_schedule(self):
    """Return an optax cosine-decay schedule configured from this object.

    Uses ``self.initial_rate`` as the initial value, ``self.decay_steps`` as the
    decay length and ``self.alpha`` as the final fraction of the initial value.
    optax is imported inside the method (presumably to keep it an optional
    dependency — confirm).
    """
    from optax import cosine_decay_schedule

    schedule_kwargs = dict(
        init_value=self.initial_rate,
        decay_steps=self.decay_steps,
        alpha=self.alpha,
    )
    return cosine_decay_schedule(**schedule_kwargs)