def build_optimizer(self, clip=15.0, lr=5e-4, warmup=2000,
                    cosine_decay_steps=None,
                    optimizer_name="adabelief") -> GradientTransformation:
    """Assemble the optax transformation chain used to minimize the loss.

    The chain is built in this order: the adaptive rescaling selected by
    `optimizer_name`, the (negative) learning-rate scaling, an optional
    cosine decay multiplier and finally an optional element-wise clip.
    Because the clip comes last, it bounds the final update rather than the
    raw gradient.

    Args:
        clip: Element-wise bound passed to `optax.clip`; falsy or
            non-positive values disable clipping.
        lr: Positive learning rate. It is negated here so that applying the
            update performs descent.
        warmup: Forwarded as `warmup` to `util.linear_warmup_lr_schedule`
            (presumably the number of warm-up steps -- confirm against that
            helper). Falsy or non-positive values select a constant rate.
        cosine_decay_steps: If truthy and positive, length of a cosine decay
            from 1.0 down to 0.1 that multiplies the learning rate.
        optimizer_name: Either "adabelief" or "adam".

    Returns:
        The chained optax gradient transformation.

    Raises:
        ValueError: If `optimizer_name` is not one of the supported names.
    """
    chain = []
    if optimizer_name == "adabelief":
        chain.append(util.scale_by_belief())
    elif optimizer_name == "adam":
        chain.append(optax.scale_by_adam())
    else:
        # An explicit raise (rather than `assert 0`) still fires under
        # `python -O`, where asserts are stripped and an unknown name would
        # otherwise silently produce a chain without any adaptive rescaling.
        raise ValueError(
            f"Unknown optimizer_name {optimizer_name!r}; "
            "expected 'adabelief' or 'adam'.")

    # Make sure to use the negative learning rate so that we minimize
    if warmup and warmup > 0:
        warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                  warmup=warmup, lr_decay=1.0, lr=-lr)
        chain.append(optax.scale_by_schedule(warmup_schedule))
    else:
        chain.append(optax.scale(-lr))

    if cosine_decay_steps and cosine_decay_steps > 0:
        cosine_lr = optax.cosine_decay_schedule(
            init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
        chain.append(optax.scale_by_schedule(cosine_lr))

    if clip and clip > 0:
        chain.append(optax.clip(clip))

    return optax.chain(*chain)
def make_optimizer(momentum=True, schedule_fn=lambda x: -1e-3):
    """Build SGD, optionally with momentum, scaled by `schedule_fn`.

    Args:
        momentum: When truthy, a momentum accumulator with decay 0.9 (no
            Nesterov correction) is placed before the schedule scaling.
        schedule_fn: Maps the step count to the multiplier applied to the
            update. The default is the constant -1e-3, i.e. the sign needed
            for descent is already part of the schedule.

    Returns:
        The chained optax gradient transformation.
    """
    transforms = []
    if momentum:
        transforms.append(optax.trace(decay=0.9, nesterov=False))  # momentum
    transforms.append(optax.scale_by_schedule(schedule_fn))
    return optax.chain(*transforms)
def sgd_momentum(learning_rate_fn: optax.Schedule,
                 momentum: float = 0.,
                 nesterov: bool = False) -> optax.GradientTransformation:
    """Build SGD with momentum driven by a learning-rate schedule.

    Args:
        learning_rate_fn: Maps the step count to the learning rate.
        momentum: Decay of the momentum accumulator.
        nesterov: Whether to use the Nesterov variant of momentum.

    Returns:
        A chain of momentum accumulation, schedule scaling and a final sign
        flip so that applying the update performs descent.
    """
    accumulate = optax.trace(decay=momentum, nesterov=nesterov)
    apply_schedule = optax.scale_by_schedule(learning_rate_fn)
    flip_sign = optax.scale(-1.)
    return optax.chain(accumulate, apply_schedule, flip_sign)
def make_optimizer():
    """Build SGD with momentum configured from FLAGS and a custom schedule.

    Momentum decay and the Nesterov switch are read from
    `FLAGS.optimizer_momentum` and `FLAGS.optimizer_use_nesterov`; the step
    size comes from the module-level `lr_schedule`. The trailing scale by -1
    turns the scaled gradient into a descent step.
    """
    momentum = optax.trace(
        decay=FLAGS.optimizer_momentum,
        nesterov=FLAGS.optimizer_use_nesterov)
    transforms = (
        momentum,
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )
    return optax.chain(*transforms)
def scaled_sgld(key: np.ndarray,
                schedule_fn: callable = optax.constant_schedule(1.)):
    """Build an SGLD gradient transformation with a scheduled step size.

    Args:
        key: PRNG key stored in the initial state; it is split on every
            update to obtain the key handed to `add_noise`.
        schedule_fn: Maps the step count to the step size.

    Returns:
        An `optax.GradientTransformation` (an `(init_fn, update_fn)` pair).
    """

    def init_fn(params):
        # The state does not depend on the parameters: only the step counter
        # and the PRNG key are tracked.
        return ScaledSGLDState(count=0, key=key)

    def update_fn(updates, state, params=None):
        """Scale the incoming updates by -stepsize and add noise.

        Per the original contract this yields
        `-stepsize * updates + sqrt(2 * stepsize) * z` with `z` standard
        normal (assuming `add_noise` draws standard-normal noise scaled by
        its last argument -- confirm against that helper).
        """
        count, key = state
        stepsize = schedule_fn(count)
        count += 1
        # `jax.tree_util.tree_map` is available in every JAX release; the
        # top-level `jax.tree_map` alias has been removed from recent ones.
        updates = jax.tree_util.tree_map(lambda g: -stepsize * g, updates)
        key, subkey = random.split(key)
        # TODO: either throw error when stepsize < 0 or put np.abs(stepsize)
        # under the square root.
        noisy_updates = add_noise(subkey, updates, np.sqrt(2 * stepsize))
        return noisy_updates, ScaledSGLDState(count=count, key=key)

    return optax.GradientTransformation(init_fn, update_fn)
def test_regularized_training(self):
    """Check that a regularization penalty added to the training loss works."""
    np.random.seed(0)

    # The task is to recover w from x and
    #   y = x . w + noise
    # under the prior belief that w is sparse. x is a wide matrix (fewer
    # examples than dimensions), so without the sparsity assumption the
    # problem is underdetermined.
    n_rows, n_cols = 8, 10
    x = np.random.randn(n_rows, n_cols).astype(np.float32)
    true_w = np.zeros((n_cols, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(n_rows, 2)

    # Baseline: the plain least squares estimate, which is not very accurate.
    least_squares_w, *_ = np.linalg.lstsq(x, y, rcond=None)
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Improved estimate: solve the L1 regularized problem
    #   argmin_w ||x . w - y||_2^2 + c ||w||_1.
    def l1_penalty(w):
        return 4.0 * hk_util.l1_loss(w)

    def forward(batch):
        layer = hk_util.Linear(2, use_bias=False, w_regularizer=l1_penalty)
        return layer(batch['x'])

    model = hk_util.transform(forward)

    def penalized_loss(params, batch):
        """Squared-error training loss plus the L1 penalty term."""
        prediction, penalties = model.apply(params, None, batch)
        return hk_util.l2_loss(prediction - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)

    # Gradient descent with decreasing learning rate.
    def step_size(i):
        return -0.05 / jnp.sqrt(1 + i)

    optimizer = optax.chain(
        optax.trace(decay=0.0, nesterov=False),
        optax.scale_by_schedule(step_size))
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
        grads = jax.grad(penalized_loss)(params, batch)
        updates, next_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), next_opt_state

    for _ in range(1000):
        params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
def get_optimizer(config):
    """Build Adam with a linear warm-up multiplied by an exponential decay.

    Args:
        config: Mapping providing 'warmup_iter', 'adam_lr', 'decay_steps',
            'lr_decay_rate', 'adam_beta_1', 'adam_beta_2' and 'adam_eps'.

    Returns:
        A chain of Adam rescaling, the warm-up factor, the decaying learning
        rate and a final sign flip for descent.
    """
    warmup_iter = config['warmup_iter']

    # Linear ramp (power=1) from 1 / warmup_iter up to 1.
    warm_up_poly = optax.polynomial_schedule(
        init_value=1 / warmup_iter,
        end_value=1,
        power=1,
        transition_steps=warmup_iter)

    # The decay begins at step 0, i.e. it runs concurrently with the warm-up
    # rather than starting once the warm-up has finished.
    exp_decay = optax.exponential_decay(
        init_value=config['adam_lr'],
        transition_steps=config['decay_steps'],
        decay_rate=config['lr_decay_rate'],
        transition_begin=0)

    adam = optax.scale_by_adam(
        b1=config['adam_beta_1'],
        b2=config['adam_beta_2'],
        eps=config['adam_eps'])

    transforms = [
        adam,
        optax.scale_by_schedule(warm_up_poly),
        optax.scale_by_schedule(exp_decay),
        optax.scale(-1),
    ]
    return optax.chain(*transforms)
def get(self) -> optax.GradientTransformation:
    """Return the optax optimizer selected by this configuration.

    Three combinations are supported: any optimizer name containing "adam"
    (constant base learning rate), exactly "sgd" with the "linear" schedule,
    and any name containing "sgd" with the "step" schedule. Both SGD variants
    use Nesterov momentum of `1 - self.one_minus_momentum`.

    Raises:
        ValueError: If no supported optimizer/schedule combination matches.
    """

    def nesterov_sgd(schedule):
        # Shared tail of both SGD variants: momentum, schedule, sign flip.
        return optax.chain(
            optax.trace(decay=1 - self.one_minus_momentum, nesterov=True),
            optax.scale_by_schedule(schedule),
            optax.scale(-1),
        )

    if "adam" in self.optimizer:
        return optax.adam(self.base_learning_rate)

    # NOTE(review): this branch tests equality while the "step" branch below
    # tests containment -- confirm the asymmetry is intended.
    if "sgd" == self.optimizer and self.lr_schedule == "linear":
        decay_epochs = self.epochs - self.lr_warmup_epochs
        schedule = warm_up_polynomial_schedule(
            base_learning_rate=self.base_learning_rate,
            end_learning_rate=self.final_decay_factor * self.base_learning_rate,
            decay_steps=(self.n_batches * decay_epochs),
            warmup_steps=self.n_batches * self.lr_warmup_epochs,
            decay_power=1.0,
        )
        return nesterov_sgd(schedule)

    if "sgd" in self.optimizer and self.lr_schedule == "step":
        # Rescale the configured decay epochs from the default epoch budget
        # to the actual number of epochs.
        scaled_decay_epochs = [
            (int(start_epoch_str) * self.epochs) // DEFAULT_NUM_EPOCHS
            for start_epoch_str in self.lr_decay_epochs
        ]
        schedule = warm_up_piecewise_constant_schedule(
            steps_per_epoch=self.n_batches,
            base_learning_rate=self.base_learning_rate,
            decay_ratio=self.lr_decay_ratio,
            decay_epochs=scaled_decay_epochs,
            warmup_epochs=self.lr_warmup_epochs,
        )
        return nesterov_sgd(schedule)

    raise ValueError("No optimizer specified.")
def _create_jax_optimizer(self):
    """Build the optax chain for plain gradient descent.

    With a `LearningRateSchedule` the chain scales by the schedule and then
    flips the sign; with a scalar learning rate it is a single scale by the
    negated rate.
    """
    import optax

    learning_rate = self.learning_rate
    if not isinstance(learning_rate, LearningRateSchedule):
        return optax.chain(optax.scale(-1.0 * learning_rate))

    schedule = learning_rate._create_jax_schedule()
    return optax.chain(optax.scale_by_schedule(schedule), optax.scale(-1.0))
def build_model(params, tpu_name, region, preemptible, version=1):
    """Provision a TPU, start Ray workers on it and build the TPU cluster.

    The function creates the TPU, waits for it to report READY/HEALTHY,
    starts a local Ray head plus one worker per TPU connection, stores the
    optimizer under `params["optimizer"]` and wraps the chosen transformer
    class in a `TPUCluster`.

    Args:
        params: Model/training configuration. Reads "cores_per_replica",
            "tpu_size", "warmup_steps", "anneal_steps", "lr", "end_lr",
            "weight_decay" and, optionally, "gradient_accumulation_steps"
            (default 1). It is mutated: "optimizer" is written.
        tpu_name: Name handed to the TPU helper functions.
        region: Region handed to the TPU helper functions.
        preemptible: Forwarded to `create_tpu`.
        version: 1 selects `CausalTransformer`, 2 `CausalTransformerV2`.

    Returns:
        The constructed `TPUCluster`.

    Raises:
        AssertionError: If the TPU size is unsupported, the TPU does not
            become ready, or the connection count does not match the size.
        Exception: If `version` is neither 1 nor 2.
    """
    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    cores_per_replica = params["cores_per_replica"]
    tpu_size = params["tpu_size"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    assert tpu_size in [8, 32, 128, 256, 512]

    create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible)
    # The wait must run outside the assert: under `python -O` asserts are
    # stripped, which would otherwise skip waiting for the TPU altogether.
    tpu_ready = wait_til(tpu_name, region,
                         {'state': 'READY', 'health': 'HEALTHY'})
    assert tpu_ready, "TPU did not reach the READY/HEALTHY state"

    conns = get_connection(tpu_name, region)
    assert len(conns) * 8 == tpu_size, "wrong size TPU for config"

    head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
    address = head_info['redis_address']

    # Start one Ray worker per TPU connection, in parallel.
    with multiprocessing.pool.ThreadPool(processes=len(conns)) as p:
        p.map(functools.partial(start_ray, address=address, version=version),
              conns)

    opt = optax.chain(
        # Average the gradient over the accumulation steps.
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(
            util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr))
    )

    params["optimizer"] = opt

    if version == 2:
        model_fn = functools.partial(CausalTransformerV2, params)
    elif version == 1:
        model_fn = functools.partial(CausalTransformer, params)
    else:
        raise Exception(f"Version {version} does not exist")

    t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica),
                   len(conns), model_fn, version=version)
    return t
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):
    """Initialise model parameters and bundle them with the optimizer.

    The optimizer clips the gradient by global norm, applies Adam rescaling,
    adds weight decay and finally scales by `lr_schedule_fn`.

    NOTE(review): the chain contains no sign flip, so `lr_schedule_fn`
    presumably has to return negative values for descent -- confirm.

    Args:
        rng: PRNG key used to initialise the parameters.
        model: Object exposing `init` and `apply`.
        img_size: Side length of the dummy input of shape
            `(1, img_size, img_size, 3)` used for initialisation.
        lr_schedule_fn: Maps the step count to the update multiplier.
        weight_decay: Coefficient passed to `optax.additive_weight_decay`.
        max_norm: Global-norm clipping threshold.

    Returns:
        The created `TrainState`.
    """
    transforms = [
        optax.clip_by_global_norm(max_norm),
        optax.scale_by_adam(),
        optax.additive_weight_decay(weight_decay),
        optax.scale_by_schedule(lr_schedule_fn),
    ]
    tx = optax.chain(*transforms)

    dummy_images = jax.numpy.ones((1, img_size, img_size, 3))
    params = model.init(rng, dummy_images, is_training=False)

    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def __init__(self,
             learning_rate_fn: Union[float, int, optax.Schedule],
             normalize_fn: Optional[NormalizeFn] = None):
    """Set up SGD with an optional per-leaf gradient normalisation.

    Args:
        learning_rate_fn: Either a scalar learning rate or a schedule that
            maps the step count to the learning rate.
        normalize_fn: Optional function applied to every leaf of the
            gradient tree before the learning rate is applied; when omitted
            the gradients pass through unchanged.
    """
    # Accept schedules, as well as scalar values.
    if isinstance(learning_rate_fn, (float, int)):
        lr = float(learning_rate_fn)
        learning_rate_fn = lambda _: lr

    # Normalization.
    def update_fn(updates, state, params=None):
        del params
        # `jax.tree_util.tree_map` is available in every JAX release; the
        # top-level `jax.tree_map` alias has been removed from recent ones.
        updates = jax.tree_util.tree_map(
            normalize_fn or (lambda x: x), updates)
        return updates, state

    gradient_transformation = optax.chain(
        optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn),
        optax.scale_by_schedule(learning_rate_fn),
        optax.scale(-1.))

    super(SGD, self).__init__(gradient_transformation)
def __init__(
    self,
    *optimizer: optax.GradientTransformation,
    lr_schedule: tp.Optional[LRScheduler] = None,
    steps_per_epoch: tp.Union[int, jnp.ndarray, np.ndarray, None] = None,
    **kwargs,
):
    r"""
    Arguments:
        optimizer: An optax `GradientTransformation` object. If more than one
            is passed via `*args`, they are grouped using `optax.chain`.
        lr_schedule: An optional callable of the form
            `def lr_schedule(step: int, epoch: Optional[int]) -> float`
            that returns the learning rate at each time step. If
            `steps_per_epoch` is given then `epoch` is computed from `step`,
            otherwise `epoch` is None.
        steps_per_epoch: The number of steps in an epoch, needed to calculate
            `epoch` from `step`.
    """
    if not optimizer:
        raise ValueError("Must pass atleast 1 optimizer, got 0")

    if lr_schedule is not None:
        # Keep a handle on the user's schedule: `lr_schedule` is rebound to
        # the single-argument wrapper further down.
        base_schedule = lr_schedule

        def lr_schedule_(step: jnp.ndarray) -> jnp.ndarray:
            epoch: tp.Any = None
            if steps_per_epoch is not None:
                epoch = step // steps_per_epoch
            return base_schedule(step, epoch)

        # The schedule scaling is appended after all user transformations.
        combined = optax.chain(*optimizer,
                               optax.scale_by_schedule(lr_schedule_))
        lr_schedule = lr_schedule_
    elif len(optimizer) == 1:
        combined = optimizer[0]
    else:
        combined = optax.chain(*optimizer)

    self.optimizer = combined
    self.lr_schedule = lr_schedule
def _create_jax_optimizer(self):
    """Build the optax chain for Adam with decayed weights.

    The chain is Adam rescaling, weight decay, then the learning rate with a
    negative sign. The learning rate, scalar or schedule, is applied after
    the Adam rescaling: Adam normalises its input, so a scale factor applied
    before it would largely cancel out and the schedule would have almost no
    effect on the step size.
    """
    import optax

    process = [
        optax.scale_by_adam(b1=self.beta1,
                            b2=self.beta2,
                            eps=self.epsilon,
                            eps_root=0.0),
        optax.add_decayed_weights(self.weight_decay, None),
    ]
    if isinstance(self.learning_rate, LearningRateSchedule):
        scheduler = self.learning_rate._create_jax_schedule()
        process.append(optax.scale_by_schedule(scheduler))
        process.append(optax.scale(-1.0))
    else:
        lr = self.learning_rate
        process.append(optax.scale(-1.0 * lr))
    return optax.chain(*process)
def _create_jax_optimizer(self):
    """Build the optax chain for RMSProp with optional momentum.

    The chain is RMS rescaling, an optional momentum accumulator, then the
    learning rate with a negative sign. The learning rate, scalar or
    schedule, is applied after the RMS rescaling: that rescaling normalises
    its input, so a scale factor applied before it would largely cancel out.
    """
    import optax

    process = [
        optax.scale_by_rms(decay=self.decay,
                           eps=self.epsilon,
                           initial_scale=0.0)
    ]
    # Momentum is added only when it is set AND non-zero. The previous `or`
    # made this condition always true, so `optax.trace` was built even for
    # `momentum=None`.
    if self.momentum is not None and self.momentum != 0.0:
        process.append(optax.trace(decay=self.momentum, nesterov=False))

    if isinstance(self.learning_rate, LearningRateSchedule):
        scheduler = self.learning_rate._create_jax_schedule()
        process.append(optax.scale_by_schedule(scheduler))
        process.append(optax.scale(-1.0))
    else:
        lr = self.learning_rate
        process.append(optax.scale(-1.0 * lr))
    return optax.chain(*process)
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: ppo_networks.PPONetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    """Create the PPO learner together with its optimizer.

    The optimizer clips gradients by global norm, applies Adam rescaling and
    scales by the negated learning rate. `self._config.learning_rate` may be
    a scalar or a callable schedule. `environment_spec` and `replay_client`
    are accepted for interface compatibility and are not used.
    """
    del environment_spec, replay_client

    config = self._config
    transforms = [
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.scale_by_adam(eps=config.adam_epsilon),
    ]
    if callable(config.learning_rate):
        transforms.append(optax.scale_by_schedule(config.learning_rate))
        transforms.append(optax.scale(-1))
    else:
        transforms.append(optax.scale(-config.learning_rate))
    optimizer = optax.chain(*transforms)

    return learning.PPOLearner(
        ppo_networks=networks,
        iterator=dataset,
        discount=config.discount,
        entropy_cost=config.entropy_cost,
        value_cost=config.value_cost,
        max_abs_reward=config.max_abs_reward,
        ppo_clipping_epsilon=config.ppo_clipping_epsilon,
        clip_value=config.clip_value,
        gae_lambda=config.gae_lambda,
        counter=counter,
        random_key=random_key,
        optimizer=optimizer,
        num_epochs=config.num_epochs,
        num_minibatches=config.num_minibatches,
        logger=logger_fn('learner'),
    )
def make_optimizer(optimizer_config, lr_schedule):
    """Construct the optax optimizer with given LR schedule.

    Args:
        optimizer_config: Config providing `max_norm`, `optimizer` ('adam' or
            'lamb'), `weight_decay`, the matching `adam_kwargs` or
            `lamb_kwargs` and, optionally, `decay_pos_embs`.
        lr_schedule: Maps the step count to the learning rate.

    Returns:
        The chained optax gradient transformation.

    Raises:
        ValueError: If `optimizer_config.optimizer` is not supported.
    """
    # Learned position embeddings are decayed unless `decay_pos_embs` is
    # explicitly set to a falsy value.
    decay_pos_embs = (optimizer_config.get('decay_pos_embs') is None or
                      optimizer_config.decay_pos_embs)
    weight_decay_exclude_names = (
        ['b'] if decay_pos_embs else ['pos_embs', 'b'])

    transforms = []
    if optimizer_config.max_norm > 0:
        transforms.append(
            optax.clip_by_global_norm(optimizer_config.max_norm))

    if optimizer_config.optimizer == 'adam':
        # See: https://arxiv.org/abs/1412.6980
        transforms.append(optax.scale_by_adam(**optimizer_config.adam_kwargs))
        transforms.append(
            add_weight_decay(optimizer_config.weight_decay,
                             exclude_names=weight_decay_exclude_names))
    elif optimizer_config.optimizer == 'lamb':
        # See: https://arxiv.org/abs/1904.00962
        transforms.append(optax.scale_by_adam(**optimizer_config.lamb_kwargs))
        transforms.append(
            add_weight_decay(optimizer_config.weight_decay,
                             exclude_names=weight_decay_exclude_names))
        transforms.append(optax.scale_by_trust_ratio())
    else:
        raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')

    # Scale by the (negative) learning rate.
    transforms.append(optax.scale_by_schedule(lr_schedule))
    transforms.append(optax.scale(-1))

    return optax.chain(*transforms)
def build_optimizer(lr, momentum, steps_per_epoch, n_epochs, nesterov,
                    warmup_epochs=5):
    """Build SGD whose update is scaled by a warm-up/cosine multiplier.

    The multiplier at each step is the smaller of a linear warm-up (0 to 1
    over `warmup_epochs` epochs) and a cosine decay spanning all
    `n_epochs * steps_per_epoch` steps.

    Args:
        lr: Base learning rate handed to `optax.sgd`.
        momentum: Momentum handed to `optax.sgd`.
        steps_per_epoch: Number of optimizer steps per epoch.
        n_epochs: Total number of epochs.
        nesterov: Whether `optax.sgd` uses Nesterov momentum.
        warmup_epochs: Length of the linear warm-up, in epochs.

    Returns:
        The chained optax gradient transformation.
    """
    total_steps = n_epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch

    cosine_schedule = optax.cosine_decay_schedule(
        1, decay_steps=total_steps, alpha=1e-10)
    warmup_schedule = optax.polynomial_schedule(
        init_value=0.0,
        end_value=1.0,
        power=1,
        transition_steps=warmup_steps,
    )

    def schedule(step):
        return jnp.minimum(cosine_schedule(step), warmup_schedule(step))

    sgd = optax.sgd(lr, momentum, nesterov=nesterov)
    return optax.chain(sgd, optax.scale_by_schedule(schedule))
def make_optimizer(lr_schedule, momentum_decay):
    """Build SGD with (non-Nesterov) momentum and a scheduled learning rate.

    Args:
        lr_schedule: Maps the step count to the learning rate.
        momentum_decay: Decay of the momentum accumulator.

    Returns:
        A chain of momentum, schedule scaling and a sign flip for descent.
    """
    transforms = (
        optax.trace(decay=momentum_decay, nesterov=False),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )
    return optax.chain(*transforms)
# Fragment of a larger setup routine: reads hyperparameters from `params`,
# builds the optimizer and lays the available devices out as a 2-D mesh.
# `gradient_accumulation_steps` and `cores_per_replica` are defined before
# this fragment.
total_steps = params["total_steps"]

# Positional-encoding variant; only these three values are accepted.
pe = params["pe"]
assert pe in ["fixed", "rotary", "t5"]

warmup_steps = params["warmup_steps"]
anneal_steps = params["anneal_steps"]
lr = params["lr"]
end_lr = params["end_lr"]
weight_decay = params["weight_decay"]

opt = optax.chain(
    # Average the gradient over the accumulation steps.
    optax.scale(1 / gradient_accumulation_steps),
    clip_by_global_norm(1),
    optax.scale_by_adam(),
    additive_weight_decay(weight_decay),
    # Flip the sign for descent, then apply the learning-rate schedule.
    optax.scale(-1),
    optax.scale_by_schedule(
        util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)))

# The optimizer travels with the rest of the configuration.
params["optimizer"] = opt

start = time.time()
tpu_size = jax.device_count()
if tpu_size < cores_per_replica:
    msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
    raise ValueError(msg)
print(f"jax devices: {tpu_size}")
print(f"jax runtime initialized in {time.time() - start:.06}s")

# Arrange the devices as (replicas, cores per replica).
mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# pick initial ckpt - based on tuning vs train from scratch
def main(_):
    """Run the job selected by `FLAGS.job_mode`.

    Supported modes are 'train', 'eval', 'sample' and 'retrieve'; any other
    value does nothing.
    """
    # Create the dataset.
    tokenizer = utils.init_tokenizer(FLAGS.dataset)
    graph_tokenizer = utils.init_graph_tokenizer()
    dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
    has_graph = FLAGS.model_type == 'graph2text'
    local_devices = jax.local_devices()
    num_gpus = min(FLAGS.num_gpus, len(local_devices))

    def make_iterator(**split_options):
        # Every mode shares the tokenizers, data version and debug flag.
        dataset = dataset_class(tokenizer=tokenizer,
                                graph_tokenizer=graph_tokenizer,
                                version=FLAGS.graph_data_version,
                                debug=FLAGS.debug,
                                **split_options)
        return iter(dataset)

    def make_eval_updater():
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn, optimizer=None, devices=devices,
                          has_graph=has_graph)
        return CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    job_mode = FLAGS.job_mode
    if job_mode == 'train':
        train_iter = make_iterator(batch_size=FLAGS.train_batch_size,
                                   subset='train',
                                   timesteps=FLAGS.train_timesteps,
                                   shuffle_data=True,
                                   repeat=True)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.train_memory_size)
        lr_fn = functools.partial(utils.schedule,
                                  lr_schedule=FLAGS.lr_schedule,
                                  init_lr=FLAGS.init_lr,
                                  min_lr_ratio=FLAGS.min_lr_ratio,
                                  max_steps=FLAGS.max_steps)
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.grad_clip),
            optax.scale_by_adam(),
            optax.scale_by_schedule(lr_fn),
            optax.scale(-1))
        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
        updater = Updater(loss_fn, optimizer,
                          devices=local_devices[:num_gpus],
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _train(updater, train_iter, num_gpus)
    elif job_mode == 'eval':
        eval_iter = make_iterator(batch_size=FLAGS.eval_batch_size,
                                  subset=FLAGS.eval_subset,
                                  timesteps=FLAGS.eval_timesteps,
                                  shuffle_data=False,
                                  repeat=False)
        updater = make_eval_updater()
        _eval(updater, eval_iter)
    elif job_mode == 'sample':
        eval_iter = make_iterator(batch_size=1,
                                  subset=FLAGS.eval_subset,
                                  timesteps=FLAGS.sample_length,
                                  shuffle_data=False,
                                  repeat=True)
        _sample(eval_iter, tokenizer, local_devices[:num_gpus])
    elif job_mode == 'retrieve':
        eval_iter = make_iterator(batch_size=1,
                                  subset=FLAGS.eval_subset,
                                  timesteps=FLAGS.eval_timesteps,
                                  shuffle_data=False,
                                  repeat=False,
                                  graph_retrieval_dataset=True)
        updater = make_eval_updater()
        _retrieve(updater, eval_iter)
def make_adam_optimizer(lr_schedule, b1=0.9, b2=0.999, eps=1e-8):
    """Build an Adam optimizer that ascends the objective.

    Args:
        lr_schedule: Maps the step count to the learning rate.
        b1: Adam first-moment decay.
        b2: Adam second-moment decay.
        eps: Adam numerical-stability constant.

    Returns:
        Adam rescaling followed by the schedule scaling.
    """
    # No sign flip on purpose: we maximize log-prob instead of minimizing a
    # loss.
    adam = optax.scale_by_adam(b1=b1, b2=b2, eps=eps)
    step_size = optax.scale_by_schedule(lr_schedule)
    return optax.chain(adam, step_size)
def make_sgd_optimizer(lr_schedule, momentum_decay):
    """Build an SGD-with-momentum optimizer that ascends the objective.

    Args:
        lr_schedule: Maps the step count to the learning rate.
        momentum_decay: Decay of the (non-Nesterov) momentum accumulator.

    Returns:
        Momentum accumulation followed by the schedule scaling.
    """
    # No sign flip on purpose: we maximize log-prob instead of minimizing a
    # loss.
    momentum = optax.trace(decay=momentum_decay, nesterov=False)
    step_size = optax.scale_by_schedule(lr_schedule)
    return optax.chain(momentum, step_size)
def _scale_by_learning_rate(learning_rate, flip_sign=True):
    """Return a transformation that scales updates by the learning rate.

    Args:
        learning_rate: Scalar, or a callable mapping the step count to the
            learning rate.
        flip_sign: When true the rate is negated, which is what descent
            needs.

    Returns:
        `optax.scale_by_schedule` for a callable rate, `optax.scale`
        otherwise.
    """
    sign = -1 if flip_sign else 1
    if not callable(learning_rate):
        return optax.scale(sign * learning_rate)

    def signed_schedule(count):
        return sign * learning_rate(count)

    return optax.scale_by_schedule(signed_schedule)
def main(_):
    """Train the network, logging summaries and pickling checkpoints.

    Everything is configured through `FLAGS`. Scalars and images are written
    through a tensorboard `SummaryWriter` to `FLAGS.output_dir`, and
    `[params, state, sn_state]` is pickled there every 5000 steps and once
    more after the last step.
    """
    # Make the network
    model = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    # `sn_fn` only exists when spectral normalisation is requested; every use
    # below is guarded by the same flag.
    if FLAGS.spectral_norm > 0:
        sn_fn = hk.transform_with_state(
            lambda x: SNParamsTree(ignore_regex='[^?!.]*b$|[^?!.]*offset$',
                                   val=FLAGS.spectral_norm)(x))

    # Initialisation
    # Adam at the base learning rate, further scaled by the module-level
    # `lr_schedule`.
    optimizer = optax.chain(optax.adam(learning_rate=FLAGS.learning_rate),
                            optax.scale_by_schedule(lr_schedule))

    rng_seq = hk.PRNGSequence(42)
    # With the Gaussian prior the network input has one extra channel (see
    # the concatenation in `score_fn`).
    if FLAGS.gaussian_prior:
        last_dim = 2
    else:
        last_dim = 1
    params, state = model.init(next(rng_seq),
                               jnp.zeros((1, FLAGS.map_size, FLAGS.map_size,
                                          last_dim)),
                               jnp.zeros((1, 1, 1, 1)),
                               is_training=True)
    opt_state = optimizer.init(params)

    if FLAGS.spectral_norm > 0:
        _, sn_state = sn_fn.init(next(rng_seq), params)
    else:
        # Placeholder so that `update` and the checkpoints keep one layout.
        sn_state = 0.

    # If the Gaussian prior is used, load the theoretical power spectrum
    pixel_size = jnp.pi * FLAGS.resolution / 180. / 60.  #rad/pixel
    if FLAGS.gaussian_prior:
        ps_data = onp.load(FLAGS.gaussian_path).astype('float32')
        ell = jnp.array(ps_data[0, :])
        # massivenu: channel 4
        ps_halofit = jnp.array(ps_data[1, :] /
                               pixel_size**2)  # normalisation by pixel size
        # convert to pixel units of our simple power spectrum calculator
        kell = ell / 2 / jnp.pi * 360 * pixel_size / FLAGS.map_size
        # Interpolate the Power Spectrum in Fourier Space
        power_map = jnp.array(
            make_power_map(ps_halofit, FLAGS.map_size, kps=kell))

    def score_fn(params, state, batch, is_training=True):
        """Apply the model to `batch`, optionally with the Gaussian prior.

        Returns `(batch, res, state, gaussian_score)`; without the Gaussian
        prior `gaussian_score` is all zeros with the shape of `res`.
        """
        if FLAGS.gaussian_prior:
            # If requested, first compute the Gaussian prior
            gaussian_score = gaussian_prior_score(batch['y'][..., 0],
                                                  batch['s'][..., 0],
                                                  power_map)
            gaussian_score = jnp.expand_dims(gaussian_score, axis=-1)
            # The prior score, weighted by |s|^2, is fed to the network as an
            # additional input channel.
            net_input = jnp.concatenate(
                [batch['y'],
                 jnp.abs(batch['s'])**2 * gaussian_score], axis=-1)
            res, state = model.apply(params,
                                     state,
                                     net_input,
                                     batch['s'],
                                     is_training=is_training)
        else:
            res, state = model.apply(params,
                                     state,
                                     batch['y'],
                                     batch['s'],
                                     is_training=is_training)
            gaussian_score = jnp.zeros_like(res)
        return batch, res, state, gaussian_score

    # Training loss
    def loss_fn(params, state, rng_key, batch):
        """Mean of (u + s * (res + gaussian_score))**2; `rng_key` is unused."""
        _, res, state, gaussian_score = score_fn(params, state, batch)
        loss = jnp.mean((batch['u'] + batch['s'] *
                         (res + gaussian_score))**2)
        return loss, state

    @jax.jit
    def update(params, state, sn_state, rng_key, opt_state, batch):
        """One optimizer step, followed by spectral normalisation if enabled."""
        (loss, state), grads = jax.value_and_grad(loss_fn,
                                                  has_aux=True)(params, state,
                                                                rng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        if FLAGS.spectral_norm > 0:
            new_params, new_sn_state = sn_fn.apply(None, sn_state, None,
                                                   new_params)
        else:
            new_sn_state = sn_state
        return loss, new_params, state, new_sn_state, new_opt_state

    train = load_dataset(FLAGS.dataset, FLAGS.batch_size, FLAGS.map_size,
                         FLAGS.noise_dist_std, FLAGS.train_split)

    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)

    print('training begins')
    for step in range(FLAGS.training_steps):
        loss, params, state, sn_state, opt_state = update(
            params, state, sn_state, next(rng_seq), opt_state, next(train))

        # Log the training loss every 50 steps.
        if step % 50 == 0:
            summary_writer.scalar('train_loss', loss, step)
            print(step, loss)

        # Every 500 steps, log images for the first example of a fresh batch.
        if step % 500 == 0:
            # Running denoiser on a batch of images
            batch, res, _, gs = score_fn(params,
                                         state,
                                         next(train),
                                         is_training=False)
            # Images are clipped to [0, 0.1] and multiplied by 10 before
            # being written.
            summary_writer.image(
                'score/target',
                onp.clip(batch['x'][0, :, :, 0], 0, 0.1) * 10., step)
            summary_writer.image(
                'score/input',
                onp.clip(batch['y'][0, :, :, 0], 0, 0.1) * 10., step)
            summary_writer.image('score/score',
                                 res[0, :, :, 0] + gs[0, :, :, 0], step)
            summary_writer.image(
                'score/denoised',
                onp.clip(
                    batch['y'][0, :, :, 0] + batch['s'][0, :, :, 0]**2 *
                    (res[0, :, :, 0] + gs[0, :, :, 0]), 0, 0.1) * 10., step)
            summary_writer.image(
                'score/gaussian_denoised',
                onp.clip(
                    batch['y'][0, :, :, 0] +
                    batch['s'][0, :, :, 0]**2 * gs[0, :, :, 0], 0, 0.1) * 10.,
                step)
            print(step)

        # Periodic checkpoint.
        if step % 5000 == 0:
            with open(FLAGS.output_dir + '/model-%d.pckl' % step,
                      'wb') as file:
                pickle.dump([params, state, sn_state], file)

    summary_writer.flush()

    # Final checkpoint.
    with open(FLAGS.output_dir + '/model-final.pckl', 'wb') as file:
        pickle.dump([params, state, sn_state], file)
def solve_sdp_dual(verif_instance, key=None, opt=None, num_steps=10000,
                   verbose=False, eval_every=1000, use_exact_eig_eval=True,
                   use_exact_eig_train=False, n_iter_lanczos=30, scl=-1.0,
                   lr_init=1e-3, steps_per_anneal=100, anneal_factor=1.0,
                   num_anneals=3, opt_name='adam', gd_momentum=0.9,
                   add_diagnostic_stats=False, opt_multiplier_fn=None,
                   init_dual_vars=None, init_opt_state=None,
                   opt_dual_vars=None, kappa_reg_weight=None,
                   kappa_zero_after=None, device_type=None, save_best_k=1,
                   include_opt_state=False):
    # pylint: disable=g-doc-return-or-yield, g-doc-args
    """Compute verified lower bound via dual of SDP relaxation.

    NOTE: This method exposes many hyperparameter options, and the method
    signature is subject to change. We instead suggest using
    ``solve_sdp_dual_simple`` instead if you need a stable interface.

    The dual variables are initialised (zeros, then `init_duals_ibp`, then
    optionally `init_dual_vars`), optimised for `num_steps` gradient steps
    with a projection after every step, and the loss is re-evaluated every
    `eval_every` steps and once at the end.

    Returns:
        `(final_loss, info)`, or `(final_loss, info, opt_state)` when
        `include_opt_state` is true. `info` collects the logged statistics,
        the final dual variables and optimizer state, the per-step loss log
        and the `save_best_k` best `(loss, dual_vars)` pairs.
    """
    # NB: Whereas the rest of the code in this library is fairly top-down
    # readable, avoids excessive `if` statements, tries to make the code look
    # like the formalism, etc, this is not the case for this method.
    # This is essentially the outer loop, and includes all the
    # debugging/logging/optimization tricks we need to get/debug good results.
    #
    # NB: Time profiling: On toy VerifInstances, JIT compilation dominates time
    # cost: JIT compilation takes ~12s, then we do ~3000 steps/sec.
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
    assert isinstance(verif_instance,
                      utils.SdpDualVerifInstance), 'invalid type'

    key = key if key is not None else jax.random.PRNGKey(0)

    # `jax.tree_util.tree_map` is available in every JAX release; the
    # top-level `jax.tree_map` alias has been removed from recent ones.
    dual_vars = jax.tree_util.tree_map(
        lambda s: None if s is None else jnp.zeros(s),
        verif_instance.dual_shapes)
    dual_vars = init_duals_ibp(verif_instance, dual_vars)

    if init_dual_vars is not None:
        # Casting, here for Colab. Essentially same as
        # `dual_vars = init_dual_vars`
        dual_vars = utils.structure_like(init_dual_vars, dual_vars)
    if opt_dual_vars is not None:
        opt_dual_vars = utils.structure_like(opt_dual_vars, dual_vars)

    # Create optimizer
    if opt is None:
        if isinstance(steps_per_anneal, (float, int)):
            anneal_steps = [
                steps_per_anneal * (i + 1) for i in range(num_anneals)
            ]
        else:
            anneal_steps = np.cumsum(steps_per_anneal)
        anneal_steps = jnp.array(anneal_steps)

        def lr_schedule(t):
            # Number of annealing boundaries already passed, capped at
            # `num_anneals`.
            cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
            return lr_init * jnp.float_power(anneal_factor, cur_epoch)

        # The base optimizer is built with learning rate 1; the effective
        # rate comes from `lr_schedule`.
        opt_class = getattr(optax, opt_name)
        base_opt = (opt_class(1., momentum=gd_momentum)
                    if opt_name == 'sgd' else opt_class(1.))
        if opt_multiplier_fn:
            # NB: Interface very specific to tree.map_structure_with_path
            # Example:
            #   opt_multiplier_fn=lambda path: 0.1 if 'lam' in path else 1.0
            opt_multipliers = tree.map_structure_with_path(
                lambda path, v: opt_multiplier_fn(path), dual_vars)
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule),
                              utils.scale_by_variable_opt(opt_multipliers))
        else:
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))

    # Define loss function
    def loss(dual_vars, loss_scl=scl, exact=use_exact_eig_train):
        return _loss(dual_vars, loss_scl, exact)

    @functools.partial(jax.jit, static_argnums=(1, 2), backend=device_type)
    def _loss(dual_var, loss_scl, exact):
        loss_val, step_info = dual_fun(verif_instance, dual_var, key,
                                       n_iter=n_iter_lanczos, exact=exact,
                                       scl=loss_scl, include_info=True)
        step_info['loss_val'] = loss_val
        return loss_val, step_info

    # Define a compiled update step
    grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

    @functools.partial(jax.jit, backend=device_type)
    def grad_step(params, opt_state):
        g, info = grad(params)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        info['g'] = g
        info['updates'] = updates
        return new_params, new_opt_state, info

    # Optimize parameters in a loop
    opt_state = opt.init(dual_vars)
    if init_opt_state:
        opt_state = utils.structure_like(init_opt_state, opt_state)
    info = collections.defaultdict(list)
    loss_log = []
    store_best = []
    recent_eig_vecs = collections.deque(maxlen=10)
    best_loss = 1e9
    last_H = None
    start_i = 0

    # Main loop
    for i in range(start_i, num_steps):
        dual_vars_prev = dual_vars
        dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
        loss_val = step_info['loss_val']
        print(f'Iter {i}: Loss {loss_val}')
        best_loss = min(best_loss, loss_val)
        if add_diagnostic_stats:
            info['dual_vars'].append(dual_vars_prev)
            eig_vec = step_info['eig_vec']
            cosine_sims = []
            for prev_eig_vec in recent_eig_vecs:
                # NOTE(review): a cosine similarity would divide by the
                # product of the norms without the square root; the two agree
                # only for unit-norm vectors -- confirm.
                denom = jnp.sqrt(
                    jnp.linalg.norm(eig_vec) * jnp.linalg.norm(prev_eig_vec))
                eig_sim = jnp.sum(prev_eig_vec * eig_vec) / denom
                cosine_sims.append(abs(float(eig_sim)))
            info['c_lambda'].append(float(step_info['c_lambda']))
            info['past_10_cosine_sims'].append(np.array(cosine_sims))
            info['g'].append(step_info['g'])
            info['updates'].append(step_info['updates'])
            if use_exact_eig_train:
                # The info is for -H, so to get smallest for H, take negative
                # of max
                eig_vals = -step_info['eig_info'][0][-1:-20:-1]
                cur_H = step_info['eig_info'][2]
                diff_H = (0 if last_H is None
                          else np.linalg.norm(cur_H - last_H))
                last_H = cur_H
                info['diff_H'].append(float(diff_H))
                info['smallest_20_eig_vals'].append(eig_vals)
            recent_eig_vecs.appendleft(eig_vec)

        loss_log.append(loss_val)
        # Keep the `save_best_k` lowest-loss (loss, dual_vars) pairs, sorted
        # in ascending order so that the worst retained entry is last.
        if len(store_best) < save_best_k:
            store_best.append((loss_val, dual_vars_prev))
            store_best.sort(key=lambda x: x[0])
        elif loss_val < store_best[-1][0]:
            store_best[-1] = (loss_val, dual_vars_prev)
            store_best.sort(key=lambda x: x[0])

        # Regularization of kappa
        if kappa_reg_weight is not None and kappa_reg_weight >= 0:
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            mask = jnp.ones_like(onehot) - onehot
            dual_vars[-1] -= mask * kappa_reg_weight
        if (kappa_zero_after is not None and kappa_zero_after >= 0 and
                i > kappa_zero_after):
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            dual_vars[-1] *= onehot

        dual_vars = project_duals(dual_vars, verif_instance.dual_types)

        if opt_dual_vars:
            # `tree_map` accepts several trees; the separate
            # `jax.tree_multimap` has been removed from JAX.
            distance_to_opt = jax.tree_util.tree_map(
                lambda x, y: jnp.linalg.norm(x - y), dual_vars, opt_dual_vars)
            info['distance_to_opt'].append(distance_to_opt)

        if i % eval_every == 0:
            dual_val, _ = loss(dual_vars, loss_scl=-1,
                               exact=use_exact_eig_eval)
            info['steps'].append(i)
            info['loss_vals'].append(float(dual_val))
            if verbose:
                print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

    final_loss = float(
        loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)[0])
    info['final_dual_vars'] = dual_vars
    info['final_opt_state'] = opt_state
    info['final_loss'] = final_loss
    info['loss_log'] = loss_log
    info['store_best'] = store_best
    if include_opt_state:
        return final_loss, info, opt_state
    else:
        return final_loss, info
# Fragment of a larger setup routine: reads the model configuration from
# `params`, installs a sampler and an optimizer and looks up the latest
# checkpoint. `cores_per_replica` and `gradient_accumulation_steps` are
# defined before this fragment.
assert cores_per_replica <= 8

bucket = params["bucket"]
model_dir = params["model_dir"]
layers = params["layers"]
d_model = params["d_model"]
n_heads = params["n_heads"]
n_vocab = params["n_vocab"]
seq = params["seq"]
norm = params["norm"]

params["sampler"] = nucleaus_sample
# The schedule arguments and the weight decay are all zero here; presumably
# the optimizer is only a placeholder because this path does not train --
# TODO confirm.
opt = optax.chain(optax.scale(1 / gradient_accumulation_steps),
                  clip_by_global_norm(1), optax.scale_by_adam(),
                  optax.additive_weight_decay(0), optax.scale(-1),
                  optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)))

params["optimizer"] = opt

start = time.time()
print(f"jax devices: {jax.device_count()}")
print(f"jax runtime initialized in {time.time() - start:.06}s")

# Arrange the devices as (replicas, cores per replica).
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# NOTE(review): the builtin `open` cannot read `gs://` URLs; presumably
# `open` is rebound to a GCS-aware implementation elsewhere in the file --
# confirm.
with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
    meta = json.load(f)

# The last entry of the checkpoint list is used.
ckpt_step = meta["checkpoints"][-1]
print(f"using checkpoint {ckpt_step}")
def scale_by_learning_rate(learning_rate: ScalarOrSchedule):
    """Return a transformation scaling updates by the negated learning rate.

    Args:
        learning_rate: Scalar, or a callable mapping the step count to the
            learning rate.

    Returns:
        `optax.scale_by_schedule` for a callable rate, `optax.scale`
        otherwise; in both cases the sign is flipped for descent.
    """
    if not callable(learning_rate):
        return optax.scale(-learning_rate)

    def negated_schedule(count):
        return -learning_rate(count)

    return optax.scale_by_schedule(negated_schedule)
def create_optimizer(config):
    """Creates the optimizer associated to a config.

    The chain contains, in order: optional gradient clipping, the base
    optimizer combined with the learning-rate schedule (Adam comes before
    the schedule, SGD after it), optional decoupled weight decay and
    optional freezing of parameters matched by regex.

    Raises:
        ValueError: If the deprecated 'gradient_clip' field is present.
        NotImplementedError: If `config.optimizer` is neither "adam" nor
            "sgd" (case-insensitive).
    """
    # Gradient clipping either by norm `gradient_norm_clip` or by absolute
    # value `gradient_value_clip`.
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")
    clip_by_norm = "gradient_norm_clip" in config
    clip_by_value = "gradient_value_clip" in config
    assert not (clip_by_norm and clip_by_value), (
        "Gradient clipping by norm and by value are exclusive.")

    ops = []
    if clip_by_norm:
        ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
    if clip_by_value:
        ops.append(optax.clip(config.gradient_value_clip))

    # Define the learning rate schedule.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))

    # Scale some parameters matching a regex by a multiplier. Config field
    # `scaling_learning_rate_by_regex` is a list of pairs
    # (regex: str, multiplier: float).
    schedule_ops = [optax.scale_by_schedule(schedule_fn)]
    for regex, multiplier in config.get("scaling_learning_rate_by_regex", []):
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
    schedule_optimizer = optax.chain(*schedule_ops)

    optimizer_name = config.optimizer.lower()
    if optimizer_name == "adam":
        ops.append(optax.adam(config.learning_rate))
        ops.append(schedule_optimizer)
    elif optimizer_name == "sgd":
        ops.append(schedule_optimizer)
        ops.append(optax.sgd(config.learning_rate, momentum=config.momentum))
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        ops.append(
            utils.decoupled_weight_decay(decay=config.weight_decay,
                                         step_size_fn=schedule_fn))

    # Freeze parameters that match the given regexes (if any). A single
    # string is treated as a one-element list.
    freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
    if isinstance(freeze_weights_regexes, str):
        freeze_weights_regexes = [freeze_weights_regexes]
    ops.extend(utils.freeze(reg) for reg in freeze_weights_regexes)

    return optax.chain(*ops)
def update(
    self,
    gradient: Weights,
    state: GenericGradientState,
    parameters: Optional[Weights]
) -> Tuple[Weights, GenericGradientState]:
    """Scale `gradient` by the step-size schedule.

    The work is delegated to `scale_by_schedule(self.step_size_fn)`, which
    is run on the unwrapped `state.data`; its result is re-wrapped with
    `GenericGradientState.wrap`.
    """
    transformation = scale_by_schedule(self.step_size_fn)
    result = transformation.update(gradient, state.data, parameters)
    return GenericGradientState.wrap(*result)