Esempio n. 1
0
    def build_optimizer(self,
                        clip=15.0,
                        lr=5e-4,
                        warmup=2000,
                        cosine_decay_steps=None,
                        optimizer_name="adabelief") -> GradientTransformation:
        """Assemble the optax transformation chain used for training.

        The chain is built in this order: gradient preconditioner, negative
        learning-rate scaling (optionally with linear warm-up), optional
        cosine-decay multiplier, optional ``optax.clip``.

        Args:
            clip: Bound handed to ``optax.clip``. Falsy or non-positive values
                disable clipping.
            lr: Learning rate. It is negated internally so that the resulting
                updates minimise the loss.
            warmup: Warm-up length handed to
                ``util.linear_warmup_lr_schedule``. Falsy or non-positive
                values select a constant ``-lr`` scale instead.
            cosine_decay_steps: Number of steps of the cosine-decay multiplier
                (which goes from 1.0 down to 0.1). Falsy or non-positive
                values disable it.
            optimizer_name: ``"adabelief"`` or ``"adam"``.

        Returns:
            The chained gradient transformation.

        Raises:
            ValueError: If ``optimizer_name`` is not a supported name.
        """
        # An explicit raise is used instead of `assert 0`: assertions are
        # removed under `python -O`, which would silently build a chain
        # without any preconditioner.
        if optimizer_name == "adabelief":
            chain = [util.scale_by_belief()]
        elif optimizer_name == "adam":
            chain = [optax.scale_by_adam()]
        else:
            raise ValueError(
                f"Unsupported optimizer_name {optimizer_name!r}; "
                "expected 'adabelief' or 'adam'.")

        # Make sure to use the negative learning rate so that we minimize
        if warmup and warmup > 0:
            warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                      warmup=warmup,
                                      lr_decay=1.0,
                                      lr=-lr)
            chain.append(optax.scale_by_schedule(warmup_schedule))
        else:
            chain.append(optax.scale(-lr))

        if cosine_decay_steps and cosine_decay_steps > 0:
            # Multiplicative factor on top of the learning rate above.
            cosine_lr = optax.cosine_decay_schedule(
                init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
            chain.append(optax.scale_by_schedule(cosine_lr))

        if clip and clip > 0:
            # NOTE(review): clipping is last in the chain, so it bounds the
            # final (already learning-rate-scaled) updates rather than the raw
            # gradients -- confirm this is intended.
            chain.append(optax.clip(clip))

        return optax.chain(*chain)
Esempio n. 2
0
def make_optimizer(momentum=True, schedule_fn = lambda x:-1e-3):
    """Build SGD, optionally preceded by heavy-ball momentum.

    Updates are multiplied by ``schedule_fn(step)``. The default schedule is
    the constant ``-1e-3``; it is negative, so adding the update to the
    parameters moves against the gradient.
    """
    stages = []
    if momentum:
        # Classic (non-Nesterov) momentum accumulator with decay 0.9.
        stages.append(optax.trace(decay=0.9, nesterov=False))
    stages.append(optax.scale_by_schedule(schedule_fn))
    return optax.chain(*stages)
Esempio n. 3
0
def sgd_momentum(learning_rate_fn: optax.Schedule,
                 momentum: float = 0.,
                 nesterov: bool = False) -> optax.GradientTransformation:
  """Return SGD with (optionally Nesterov) momentum and a scheduled step size.

  The gradient is accumulated by ``optax.trace``, multiplied by
  ``learning_rate_fn(step)`` and finally negated.
  """
  accumulate = optax.trace(decay=momentum, nesterov=nesterov)
  apply_learning_rate = optax.scale_by_schedule(learning_rate_fn)
  negate = optax.scale(-1.)
  return optax.chain(accumulate, apply_learning_rate, negate)
Esempio n. 4
0
def make_optimizer():
  """Build momentum SGD driven by the module-level ``lr_schedule``.

  The momentum decay and the Nesterov switch are read from
  ``FLAGS.optimizer_momentum`` and ``FLAGS.optimizer_use_nesterov``.
  """
  momentum = optax.trace(
      decay=FLAGS.optimizer_momentum,
      nesterov=FLAGS.optimizer_use_nesterov)
  apply_learning_rate = optax.scale_by_schedule(lr_schedule)
  # The final sign flip turns the scaled gradient into a descent step.
  return optax.chain(momentum, apply_learning_rate, optax.scale(-1))
Esempio n. 5
0
def scaled_sgld(key: np.ndarray,
                schedule_fn: callable = optax.constant_schedule(1.)):
    """Build an SGLD gradient transformation with a scheduled step size.

    Every update is ``-stepsize * g`` plus noise added by ``add_noise`` with
    scale ``sqrt(2 * |stepsize|)``, where ``stepsize = schedule_fn(count)``.
    (``add_noise`` is defined elsewhere; it is expected to draw standard
    normal noise.)

    Args:
        key: PRNG key stored in the initial state; it is split on every
            update.
        schedule_fn: Maps the number of updates performed so far to the step
            size.

    Returns:
        An ``optax.GradientTransformation`` wrapping ``(init_fn, update_fn)``.
    """

    def init_fn(params):
        # The state does not depend on the parameters.
        del params
        return ScaledSGLDState(count=0, key=key)

    def update_fn(updates, state, params=None):
        """Return the noisy scaled updates and the advanced state."""
        del params
        count, key = state
        stepsize = schedule_fn(count)
        # `jax.tree_util.tree_map` is available in every JAX release, unlike
        # the deprecated top-level `jax.tree_map` alias.
        updates = jax.tree_util.tree_map(lambda g: -stepsize * g, updates)
        key, subkey = random.split(key)
        # The absolute value keeps the noise scale real when a schedule
        # yields a negative step size (previously this produced NaN).
        noise_scale = np.sqrt(2 * abs(stepsize))
        noisy_updates = add_noise(subkey, updates, noise_scale)
        return noisy_updates, ScaledSGLDState(count=count + 1, key=key)

    return optax.GradientTransformation(init_fn, update_fn)
Esempio n. 6
0
  def test_regularized_training(self):
    """Test that adding regularization penalty to the training loss works.

    A sparse weight matrix is recovered from fewer examples than dimensions,
    once with plain least squares and once by gradient descent on an
    L1-regularized loss; the regularized estimate must be the more accurate
    one.
    """
    # Fixed seed: the thresholds asserted at the end depend on these draws.
    np.random.seed(0)
    # Set up the problem of recovering w given x and
    #   y = x . w + noise
    # with the a priori assumption that w is sparse. There are fewer examples
    # than dimensions (x is a wide matrix), so the problem is underdetermined
    # without the sparsity assumption.
    num_examples, num_dim = 8, 10
    x = np.random.randn(num_examples, num_dim).astype(np.float32)
    true_w = np.zeros((num_dim, 2), np.float32)
    # Only five of the 20 entries of true_w are non-zero.
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

    # Get the least squares estimate for w. It isn't very accurate.
    least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Get a better estimate by solving the L1 regularized problem
    #  argmin_w ||x . w - y||_2^2 + c ||w||_1.
    w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
    def model_fun(batch):
      x = batch['x']
      return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

    model = hk_util.transform(model_fun)

    def loss_fun(params, batch):
      """Training loss with L1 regularization penalty term."""
      # `model.apply` returns the prediction together with the penalties.
      y_predicted, penalties = model.apply(params, None, batch)
      return hk_util.l2_loss(y_predicted - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)
    optimizer = optax.chain(  # Gradient descent with decreasing learning rate.
        optax.trace(decay=0.0, nesterov=False),
        # Negative step size -0.05 / sqrt(1 + i), so the update descends.
        optax.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
      grads = jax.grad(loss_fun)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state

    # Full-batch training: the same batch is used for all 1000 steps.
    for _ in range(1000):
      params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
def get_optimizer(config):
    """Build Adam with a linear warm-up multiplied by an exponential decay.

    Keys read from ``config``: ``warmup_iter``, ``adam_lr``, ``decay_steps``,
    ``lr_decay_rate``, ``adam_beta_1``, ``adam_beta_2`` and ``adam_eps``.
    """
    warmup_iters = config['warmup_iter']
    # Linear ramp (power 1) from 1/warmup_iter up to 1 over the warm-up.
    warmup_factor = optax.polynomial_schedule(
        init_value=1 / warmup_iters,
        end_value=1,
        power=1,
        transition_steps=warmup_iters)
    # transition_begin=0: the decay starts at step 0, not after the warm-up.
    decayed_lr = optax.exponential_decay(
        init_value=config['adam_lr'],
        transition_steps=config['decay_steps'],
        decay_rate=config['lr_decay_rate'],
        transition_begin=0)
    adam_scaling = optax.scale_by_adam(b1=config['adam_beta_1'],
                                       b2=config['adam_beta_2'],
                                       eps=config['adam_eps'])
    stages = [
        adam_scaling,
        optax.scale_by_schedule(warmup_factor),
        optax.scale_by_schedule(decayed_lr),
        optax.scale(-1),  # descend
    ]
    return optax.chain(*stages)
Esempio n. 8
0
    def get(self) -> optax.GradientTransformation:
        """Return the optax optimizer selected by this configuration.

        * any optimizer name containing ``"adam"``: ``optax.adam`` with the
          base learning rate;
        * ``"sgd"`` with the ``"linear"`` schedule: Nesterov momentum SGD with
          a warm-up polynomial schedule;
        * a name containing ``"sgd"`` with the ``"step"`` schedule: Nesterov
          momentum SGD with a warm-up piecewise-constant schedule.

        Raises:
            ValueError: For any other optimizer / schedule combination.
        """
        if "adam" in self.optimizer:
            return optax.adam(self.base_learning_rate)

        # NOTE(review): the linear branch requires the name to equal "sgd"
        # while the step branch only requires it to contain "sgd" -- confirm
        # whether this asymmetry is intended.
        if "sgd" == self.optimizer and self.lr_schedule == "linear":
            schedule = warm_up_polynomial_schedule(
                base_learning_rate=self.base_learning_rate,
                end_learning_rate=self.final_decay_factor *
                self.base_learning_rate,
                decay_steps=(self.n_batches *
                             (self.epochs - self.lr_warmup_epochs)),
                warmup_steps=self.n_batches * self.lr_warmup_epochs,
                decay_power=1.0,
            )
        elif "sgd" in self.optimizer and self.lr_schedule == "step":
            # Rescale the configured decay epochs to the actual run length.
            decay_epochs = []
            for start_epoch_str in self.lr_decay_epochs:
                decay_epochs.append(
                    (int(start_epoch_str) * self.epochs) // DEFAULT_NUM_EPOCHS)
            schedule = warm_up_piecewise_constant_schedule(
                steps_per_epoch=self.n_batches,
                base_learning_rate=self.base_learning_rate,
                decay_ratio=self.lr_decay_ratio,
                decay_epochs=decay_epochs,
                warmup_epochs=self.lr_warmup_epochs,
            )
        else:
            raise ValueError("No optimizer specified.")

        # Both SGD variants share the same momentum / schedule / negate chain.
        momentum = 1 - self.one_minus_momentum
        return optax.chain(
            optax.trace(decay=momentum, nesterov=True),
            optax.scale_by_schedule(schedule),
            optax.scale(-1),
        )
Esempio n. 9
0
 def _create_jax_optimizer(self):
     """Build the optax chain for plain gradient descent.

     A ``LearningRateSchedule`` is converted to a JAX schedule and followed
     by a sign flip; any other learning rate is folded into a single
     ``optax.scale(-lr)``.
     """
     import optax
     if isinstance(self.learning_rate, LearningRateSchedule):
         schedule = self.learning_rate._create_jax_schedule()
         stages = [optax.scale_by_schedule(schedule), optax.scale(-1.0)]
     else:
         stages = [optax.scale(-1.0 * self.learning_rate)]
     return optax.chain(*stages)
Esempio n. 10
0
def build_model(params, tpu_name, region, preemptible, version=1):
    """Provision a TPU, start Ray workers on it and return a ``TPUCluster``.

    The optimizer built here is stored in ``params["optimizer"]``, i.e. the
    ``params`` mapping passed in is mutated.

    Args:
        params: Configuration mapping. Required keys: ``cores_per_replica``,
            ``tpu_size``, ``warmup_steps``, ``anneal_steps``, ``lr``,
            ``end_lr`` and ``weight_decay``; ``gradient_accumulation_steps``
            is optional and defaults to 1.
        tpu_name: Name of the TPU to create and connect to.
        region: Region handed to the TPU helper functions.
        preemptible: Forwarded to ``create_tpu``.
        version: Model version; 1 selects ``CausalTransformer`` and 2 selects
            ``CausalTransformerV2``.

    Returns:
        The ``TPUCluster`` built for the requested model.

    Raises:
        KeyError: If a required key is missing from ``params``.
        ValueError: If ``version`` or ``tpu_size`` is unsupported, or the
            number of connections does not match ``tpu_size``.
        RuntimeError: If ``wait_til`` reports that the TPU did not become
            ready and healthy.
    """
    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    cores_per_replica = params["cores_per_replica"]
    tpu_size = params["tpu_size"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    # Validate the cheap inputs before any TPU is created or Ray is started,
    # so a bad configuration does not leave provisioned resources behind.
    # Explicit raises are used instead of `assert`, which is stripped under
    # `python -O`.
    if tpu_size not in (8, 32, 128, 256, 512):
        raise ValueError(
            f"Unsupported tpu_size {tpu_size}; "
            "expected one of 8, 32, 128, 256, 512")
    if version not in (1, 2):
        raise ValueError(f"Version {version} does not exist")

    create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible)
    # This call must always run: it is what blocks until the TPU is usable.
    if not wait_til(tpu_name, region, {'state': 'READY', 'health': 'HEALTHY'}):
        raise RuntimeError(
            f"TPU {tpu_name} did not become READY and HEALTHY")

    conns = get_connection(tpu_name, region)

    # The original check assumes 8 cores per connection.
    if len(conns) * 8 != tpu_size:
        raise ValueError("wrong size TPU for config")

    head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
    address = head_info['redis_address']

    # Start the Ray worker on every connection in parallel.
    with multiprocessing.pool.ThreadPool(processes=len(conns)) as p:
        p.map(functools.partial(start_ray, address=address, version=version), conns)

    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr))
    )

    params["optimizer"] = opt

    model_cls = CausalTransformerV2 if version == 2 else CausalTransformer
    model_fn = functools.partial(model_cls, params)

    t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica), len(conns), model_fn, version=version)
    return t
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):
    """Initialise the model parameters and wrap them in a ``TrainState``.

    The optimizer clips gradients by global norm, applies Adam scaling, adds
    weight decay and finally multiplies by ``lr_schedule_fn(step)``.

    NOTE(review): the chain contains no sign flip, so ``lr_schedule_fn``
    presumably returns negative values -- confirm against the caller.
    """
    stages = (
        optax.clip_by_global_norm(max_norm),
        optax.scale_by_adam(),
        optax.additive_weight_decay(weight_decay),
        optax.scale_by_schedule(lr_schedule_fn),
    )
    tx = optax.chain(*stages)

    # A single all-ones image batch is used to initialise the parameters.
    dummy_images = jax.numpy.ones((1, img_size, img_size, 3))
    params = model.init(rng, dummy_images, is_training=False)

    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
Esempio n. 12
0
    def __init__(self,
                 learning_rate_fn: Union[float, int, optax.Schedule],
                 normalize_fn: Optional[NormalizeFn] = None):
        """Create SGD with an optional per-leaf normalisation of the updates.

        Args:
            learning_rate_fn: Either a scalar learning rate or a schedule
                mapping the step count to a learning rate.
            normalize_fn: Optional function applied to every leaf of the
                updates before the learning rate; when omitted the updates
                are left untouched.
        """
        # Accept schedules, as well as scalar values.
        if isinstance(learning_rate_fn, (float, int)):
            lr = float(learning_rate_fn)
            learning_rate_fn = lambda _: lr

        # Resolve the optional normaliser once rather than on every update.
        normalize = normalize_fn or (lambda x: x)

        def update_fn(updates, state, params=None):
            del params
            # `jax.tree_util.tree_map` works on every JAX release, unlike the
            # deprecated top-level `jax.tree_map` alias.
            return jax.tree_util.tree_map(normalize, updates), state

        gradient_transformation = optax.chain(
            optax.GradientTransformation(lambda _: optax.EmptyState(),
                                         update_fn),
            optax.scale_by_schedule(learning_rate_fn), optax.scale(-1.))
        super(SGD, self).__init__(gradient_transformation)
Esempio n. 13
0
    def __init__(
        self,
        *optimizer: optax.GradientTransformation,
        lr_schedule: tp.Optional[LRScheduler] = None,
        steps_per_epoch: tp.Union[int, jnp.ndarray, np.ndarray, None] = None,
        **kwargs,
    ):
        r"""
        Arguments:
            optimizer: One or more optax `GradientTransformation` objects; when several are passed via `*args` they
                are grouped using `optax.chain`.
            lr_schedule: An optional callable of the form `def lr_schedule(step: int, epoch: Optional[int]) -> float`
                that returns the learning rate at each time step. If `steps_per_epoch` is given the epoch is computed
                from the step, otherwise epoch is None.
            steps_per_epoch: The number of steps in an epoch, needed to calculate `epoch` from `step`.
        """
        if len(optimizer) == 0:
            raise ValueError("Must pass atleast 1 optimizer, got 0")

        if lr_schedule is None:
            # Without a schedule a single transformation is stored as-is.
            if len(optimizer) == 1:
                combined = optimizer[0]
            else:
                combined = optax.chain(*optimizer)
        else:
            # Keep a handle on the user's schedule; `lr_schedule` is rebound
            # to the single-argument wrapper below.
            user_schedule = lr_schedule

            def lr_schedule_(step: jnp.ndarray) -> jnp.ndarray:
                if steps_per_epoch is not None:
                    epoch: tp.Any = step // steps_per_epoch
                else:
                    epoch = None
                return user_schedule(step, epoch)

            combined = optax.chain(
                *optimizer,
                optax.scale_by_schedule(lr_schedule_),
            )
            lr_schedule = lr_schedule_

        self.optimizer = combined
        self.lr_schedule = lr_schedule
Esempio n. 14
0
    def _create_jax_optimizer(self):
        """Build the optax chain for Adam with decoupled weight decay.

        Order of the chain: Adam scaling, decayed weights, learning rate
        (schedule or constant) and sign flip.

        The learning-rate schedule is applied *after* the Adam scaling.
        Adam divides the first moment by the square root of the second
        moment, which cancels (up to ``eps``) any factor applied to the
        gradients beforehand; scaling before Adam therefore left the
        schedule with almost no effect on the step size.
        """
        import optax
        process = [
            optax.scale_by_adam(b1=self.beta1,
                                b2=self.beta2,
                                eps=self.epsilon,
                                eps_root=0.0),
            optax.add_decayed_weights(self.weight_decay, None),
        ]
        if isinstance(self.learning_rate, LearningRateSchedule):
            scheduler = self.learning_rate._create_jax_schedule()
            process.append(optax.scale_by_schedule(scheduler))
            process.append(optax.scale(-1.0))
        else:
            lr = self.learning_rate
            process.append(optax.scale(-1.0 * lr))
        return optax.chain(*process)
Esempio n. 15
0
    def _create_jax_optimizer(self):
        """Build the optax chain for RMSprop with optional momentum.

        Order of the chain: RMS scaling, momentum (only when a non-zero
        momentum is configured), learning rate (schedule or constant) and
        sign flip.

        The learning-rate schedule is applied *after* the RMS scaling: the
        RMS normalisation cancels (up to ``eps``) any factor applied to the
        gradients beforehand, so scaling first left the schedule with almost
        no effect on the step size.
        """
        import optax
        process = [
            optax.scale_by_rms(decay=self.decay,
                               eps=self.epsilon,
                               initial_scale=0.0),
        ]
        # The original `is not None or != 0.0` test was always true, which
        # passed `decay=None` to `optax.trace` when momentum was unset.
        if self.momentum is not None and self.momentum != 0.0:
            process.append(optax.trace(decay=self.momentum, nesterov=False))
        if isinstance(self.learning_rate, LearningRateSchedule):
            scheduler = self.learning_rate._create_jax_schedule()
            process.append(optax.scale_by_schedule(scheduler))
            process.append(optax.scale(-1.0))
        else:
            lr = self.learning_rate
            process.append(optax.scale(-1.0 * lr))
        return optax.chain(*process)
Esempio n. 16
0
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: ppo_networks.PPONetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        """Create the PPO learner with a clipped Adam optimizer.

        The configured learning rate may be a constant or a callable
        schedule. ``environment_spec`` and ``replay_client`` are accepted but
        unused.
        """
        del environment_spec, replay_client

        config = self._config
        learning_rate = config.learning_rate

        # Shared front of the chain: global-norm clipping, then Adam scaling.
        stages = [
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.scale_by_adam(eps=config.adam_epsilon),
        ]
        if callable(learning_rate):
            stages.append(optax.scale_by_schedule(learning_rate))
            stages.append(optax.scale(-1))
        else:
            stages.append(optax.scale(-learning_rate))
        optimizer = optax.chain(*stages)

        learner_kwargs = dict(
            ppo_networks=networks,
            iterator=dataset,
            discount=config.discount,
            entropy_cost=config.entropy_cost,
            value_cost=config.value_cost,
            max_abs_reward=config.max_abs_reward,
            ppo_clipping_epsilon=config.ppo_clipping_epsilon,
            clip_value=config.clip_value,
            gae_lambda=config.gae_lambda,
            counter=counter,
            random_key=random_key,
            optimizer=optimizer,
            num_epochs=config.num_epochs,
            num_minibatches=config.num_minibatches,
            logger=logger_fn('learner'),
        )
        return learning.PPOLearner(**learner_kwargs)
Esempio n. 17
0
def make_optimizer(optimizer_config, lr_schedule):
  """Build the optax optimizer ('adam' or 'lamb') for the given LR schedule.

  Weight decay skips parameters named 'b'; position embeddings are also
  skipped when ``decay_pos_embs`` is set to a falsy, non-None value.
  Gradients are clipped by global norm first when ``max_norm`` is positive.

  Raises:
    ValueError: If ``optimizer_config.optimizer`` is neither 'adam' nor
      'lamb'.
  """
  # Learned position embeddings are decayed by default (flag unset or true).
  decay_pos_embs = (optimizer_config.get('decay_pos_embs') is None or
                    optimizer_config.decay_pos_embs)
  excluded_names = ['b'] if decay_pos_embs else ['pos_embs', 'b']

  stages = []
  if optimizer_config.max_norm > 0:
    stages.append(optax.clip_by_global_norm(optimizer_config.max_norm))

  kind = optimizer_config.optimizer
  if kind == 'adam':
    # See: https://arxiv.org/abs/1412.6980
    stages.append(optax.scale_by_adam(**optimizer_config.adam_kwargs))
    stages.append(
        add_weight_decay(
            optimizer_config.weight_decay, exclude_names=excluded_names))
  elif kind == 'lamb':
    # See: https://arxiv.org/abs/1904.00962
    stages.append(optax.scale_by_adam(**optimizer_config.lamb_kwargs))
    stages.append(
        add_weight_decay(
            optimizer_config.weight_decay, exclude_names=excluded_names))
    stages.append(optax.scale_by_trust_ratio())
  else:
    raise ValueError(f'Undefined optimizer {kind}')

  # Scale by the (negative) learning rate.
  stages.append(optax.scale_by_schedule(lr_schedule))
  stages.append(optax.scale(-1))

  return optax.chain(*stages)
Esempio n. 18
0
 def build_optimizer(lr,
                     momentum,
                     steps_per_epoch,
                     n_epochs,
                     nesterov,
                     warmup_epochs=5):
     """Build momentum SGD with a warm-up / cosine learning-rate multiplier.

     The multiplier applied after ``optax.sgd`` is the element-wise minimum
     of a linear ramp from 0 to 1 over the warm-up epochs and a cosine decay
     from 1 over the whole run.
     """
     total_steps = n_epochs * steps_per_epoch
     warmup_steps = warmup_epochs * steps_per_epoch

     cosine_factor = optax.cosine_decay_schedule(1,
                                                 decay_steps=total_steps,
                                                 alpha=1e-10)
     warmup_factor = optax.polynomial_schedule(
         init_value=0.0,
         end_value=1.0,
         power=1,
         transition_steps=warmup_steps,
     )

     def lr_multiplier(step):
         # Whichever factor is smaller at this step wins.
         return jnp.minimum(cosine_factor(step), warmup_factor(step))

     base_sgd = optax.sgd(lr, momentum, nesterov=nesterov)
     return optax.chain(base_sgd, optax.scale_by_schedule(lr_multiplier))
Esempio n. 19
0
def make_optimizer(lr_schedule, momentum_decay):
    """Return momentum SGD: accumulate, scale by ``lr_schedule``, negate."""
    stages = (
        optax.trace(decay=momentum_decay, nesterov=False),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )
    return optax.chain(*stages)
Esempio n. 20
0
    total_steps = params["total_steps"]

    pe = params["pe"]
    assert pe in ["fixed", "rotary", "t5"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps), clip_by_global_norm(1),
        optax.scale_by_adam(), additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(
            util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)))

    params["optimizer"] = opt

    start = time.time()
    tpu_size = jax.device_count()
    if tpu_size < cores_per_replica:
        msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
        raise ValueError(msg)
    print(f"jax devices: {tpu_size}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    # pick initial ckpt - based on tuning vs train from scratch
Esempio n. 21
0
def main(_):
    """Run the job selected by ``FLAGS.job_mode``.

    Supported modes are 'train', 'eval', 'sample' and 'retrieve'; any other
    value falls through without doing anything.
    """
    # Create the dataset.
    tokenizer = utils.init_tokenizer(FLAGS.dataset)
    graph_tokenizer = utils.init_graph_tokenizer()
    dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
    has_graph = FLAGS.model_type == 'graph2text'
    local_devices = jax.local_devices()
    num_gpus = min(FLAGS.num_gpus, len(local_devices))

    def open_dataset(**options):
        """Build a dataset with the arguments shared by all modes; iterate it."""
        dataset = dataset_class(tokenizer=tokenizer,
                                graph_tokenizer=graph_tokenizer,
                                version=FLAGS.graph_data_version,
                                debug=FLAGS.debug,
                                **options)
        return iter(dataset)

    def single_device_updater():
        """Checkpointing updater without optimizer, on the first device."""
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        return CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    job_mode = FLAGS.job_mode

    if job_mode == 'train':
        train_iter = open_dataset(batch_size=FLAGS.train_batch_size,
                                  subset='train',
                                  timesteps=FLAGS.train_timesteps,
                                  shuffle_data=True,
                                  repeat=True)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.train_memory_size)
        lr_schedule = functools.partial(utils.schedule,
                                        lr_schedule=FLAGS.lr_schedule,
                                        init_lr=FLAGS.init_lr,
                                        min_lr_ratio=FLAGS.min_lr_ratio,
                                        max_steps=FLAGS.max_steps)
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.grad_clip),
            optax.scale_by_adam(),
            optax.scale_by_schedule(lr_schedule),
            optax.scale(-1))
        # Wrapped so that up to 5 consecutive non-finite updates are tolerated.
        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
        updater = Updater(loss_fn,
                          optimizer,
                          devices=local_devices[:num_gpus],
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _train(updater, train_iter, num_gpus)
    elif job_mode == 'eval':
        eval_iter = open_dataset(batch_size=FLAGS.eval_batch_size,
                                 subset=FLAGS.eval_subset,
                                 timesteps=FLAGS.eval_timesteps,
                                 shuffle_data=False,
                                 repeat=False)
        _eval(single_device_updater(), eval_iter)
    elif job_mode == 'sample':
        eval_iter = open_dataset(batch_size=1,
                                 subset=FLAGS.eval_subset,
                                 timesteps=FLAGS.sample_length,
                                 shuffle_data=False,
                                 repeat=True)
        _sample(eval_iter, tokenizer, local_devices[:num_gpus])
    elif job_mode == 'retrieve':
        eval_iter = open_dataset(batch_size=1,
                                 subset=FLAGS.eval_subset,
                                 timesteps=FLAGS.eval_timesteps,
                                 shuffle_data=False,
                                 repeat=False,
                                 graph_retrieval_dataset=True)
        _retrieve(single_device_updater(), eval_iter)
Esempio n. 22
0
def make_adam_optimizer(lr_schedule, b1=0.9, b2=0.999, eps=1e-8):
    """Build Adam scaling followed by multiplication with ``lr_schedule``.

    No sign flip is appended: per the original author's note the optimizer
    maximizes a log-probability instead of minimizing a loss.
    """
    adam_scaling = optax.scale_by_adam(b1=b1, b2=b2, eps=eps)
    apply_learning_rate = optax.scale_by_schedule(lr_schedule)
    return optax.chain(adam_scaling, apply_learning_rate)
Esempio n. 23
0
def make_sgd_optimizer(lr_schedule, momentum_decay):
    """Build momentum SGD followed by multiplication with ``lr_schedule``.

    No sign flip is appended: per the original author's note the optimizer
    maximizes a log-probability instead of minimizing a loss.
    """
    momentum = optax.trace(decay=momentum_decay, nesterov=False)
    apply_learning_rate = optax.scale_by_schedule(lr_schedule)
    return optax.chain(momentum, apply_learning_rate)
Esempio n. 24
0
def _scale_by_learning_rate(learning_rate, flip_sign=True):
    """Return a transformation scaling updates by the (signed) learning rate.

    ``learning_rate`` may be a constant or a callable schedule; with
    ``flip_sign`` (the default) the factor is negated.
    """
    sign = -1 if flip_sign else 1
    if not callable(learning_rate):
        return optax.scale(sign * learning_rate)

    def signed_schedule(count):
        return sign * learning_rate(count)

    return optax.scale_by_schedule(signed_schedule)
Esempio n. 25
0
def main(_):
    """Train the network and periodically log images and checkpoints.

    All settings are read from ``FLAGS``. Scalars and images are written to
    ``FLAGS.output_dir`` through a tensorboard summary writer, and the
    ``[params, state, sn_state]`` triple is pickled there every 5000 steps
    and once more at the end.
    """
    # Make the network
    model = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    # Spectral normalisation of the parameter tree is only set up when a
    # positive value is requested.
    if FLAGS.spectral_norm > 0:
        sn_fn = hk.transform_with_state(
            lambda x: SNParamsTree(ignore_regex='[^?!.]*b$|[^?!.]*offset$',
                                   val=FLAGS.spectral_norm)(x))

    # Initialisation
    # Adam at the base learning rate, additionally scaled by `lr_schedule`
    # (defined elsewhere in the module).
    optimizer = optax.chain(optax.adam(learning_rate=FLAGS.learning_rate),
                            optax.scale_by_schedule(lr_schedule))

    # Fixed seed for every random key drawn below.
    rng_seq = hk.PRNGSequence(42)

    # With the Gaussian prior the network input gets a second channel (see
    # the concatenation in `score_fn`).
    if FLAGS.gaussian_prior:
        last_dim = 2
    else:
        last_dim = 1

    params, state = model.init(next(rng_seq),
                               jnp.zeros((1, FLAGS.map_size, FLAGS.map_size,
                                          last_dim)),
                               jnp.zeros((1, 1, 1, 1)),
                               is_training=True)

    opt_state = optimizer.init(params)

    if FLAGS.spectral_norm > 0:
        _, sn_state = sn_fn.init(next(rng_seq), params)
    else:
        # Placeholder so that `update` has a uniform signature.
        sn_state = 0.

    # If the Gaussian prior is used, load the theoretical power spectrum
    pixel_size = jnp.pi * FLAGS.resolution / 180. / 60.  #rad/pixel
    if FLAGS.gaussian_prior:
        ps_data = onp.load(FLAGS.gaussian_path).astype('float32')
        ell = jnp.array(ps_data[0, :])
        # massivenu: channel 4
        ps_halofit = jnp.array(ps_data[1, :] /
                               pixel_size**2)  # normalisation by pixel size
        # convert to pixel units of our simple power spectrum calculator
        kell = ell / 2 / jnp.pi * 360 * pixel_size / FLAGS.map_size
        # Interpolate the Power Spectrum in Fourier Space
        power_map = jnp.array(
            make_power_map(ps_halofit, FLAGS.map_size, kps=kell))

    def score_fn(params, state, batch, is_training=True):
        """Apply the model; return (batch, output, new state, gaussian score)."""
        if FLAGS.gaussian_prior:
            # If requested, first compute the Gaussian prior
            gaussian_score = gaussian_prior_score(batch['y'][..., 0],
                                                  batch['s'][...,
                                                             0], power_map)
            gaussian_score = jnp.expand_dims(gaussian_score, axis=-1)
            # The network sees the input plus |s|^2 times the Gaussian score
            # as an extra channel.
            net_input = jnp.concatenate(
                [batch['y'],
                 jnp.abs(batch['s'])**2 * gaussian_score], axis=-1)
            res, state = model.apply(params,
                                     state,
                                     net_input,
                                     batch['s'],
                                     is_training=is_training)
        else:
            res, state = model.apply(params,
                                     state,
                                     batch['y'],
                                     batch['s'],
                                     is_training=is_training)
            # No prior: the Gaussian contribution is zero.
            gaussian_score = jnp.zeros_like(res)
        return batch, res, state, gaussian_score

    # Training loss
    def loss_fn(params, state, rng_key, batch):
        # `rng_key` is accepted but not used by the loss.
        _, res, state, gaussian_score = score_fn(params, state, batch)
        # Mean squared value of u + s * (network output + Gaussian score).
        loss = jnp.mean((batch['u'] + batch['s'] * (res + gaussian_score))**2)
        return loss, state

    @jax.jit
    def update(params, state, sn_state, rng_key, opt_state, batch):
        """Run one optimisation step and return the updated quantities."""
        (loss, state), grads = jax.value_and_grad(loss_fn,
                                                  has_aux=True)(params, state,
                                                                rng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        # Spectral normalisation is applied to the freshly updated parameters.
        if FLAGS.spectral_norm > 0:
            new_params, new_sn_state = sn_fn.apply(None, sn_state, None,
                                                   new_params)
        else:
            new_sn_state = sn_state
        return loss, new_params, state, new_sn_state, new_opt_state

    train = load_dataset(FLAGS.dataset, FLAGS.batch_size, FLAGS.map_size,
                         FLAGS.noise_dist_std, FLAGS.train_split)

    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)

    print('training begins')
    for step in range(FLAGS.training_steps):
        loss, params, state, sn_state, opt_state = update(
            params, state, sn_state, next(rng_seq), opt_state, next(train))

        # Scalar loss logging every 50 steps.
        if step % 50 == 0:
            summary_writer.scalar('train_loss', loss, step)
            print(step, loss)

        # Image logging every 500 steps, on a fresh batch in inference mode.
        if step % 500 == 0:
            # Running denoiser on a batch of images
            batch, res, _, gs = score_fn(params,
                                         state,
                                         next(train),
                                         is_training=False)
            # Images are clipped to [0, 0.1] and multiplied by 10 for display.
            summary_writer.image(
                'score/target',
                onp.clip(batch['x'][0, :, :, 0], 0, 0.1) * 10., step)
            summary_writer.image(
                'score/input',
                onp.clip(batch['y'][0, :, :, 0], 0, 0.1) * 10., step)
            summary_writer.image('score/score',
                                 res[0, :, :, 0] + gs[0, :, :, 0], step)
            # Denoised image: input + s^2 * (network output + Gaussian score).
            summary_writer.image(
                'score/denoised',
                onp.clip(
                    batch['y'][0, :, :, 0] + batch['s'][0, :, :, 0]**2 *
                    (res[0, :, :, 0] + gs[0, :, :, 0]), 0, 0.1) * 10., step)
            # Same, using the Gaussian score only.
            summary_writer.image(
                'score/gaussian_denoised',
                onp.clip(
                    batch['y'][0, :, :, 0] +
                    batch['s'][0, :, :, 0]**2 * gs[0, :, :, 0], 0, 0.1) * 10.,
                step)
            print(step)

        # Checkpoint every 5000 steps (this includes step 0).
        if step % 5000 == 0:
            with open(FLAGS.output_dir + '/model-%d.pckl' % step,
                      'wb') as file:
                pickle.dump([params, state, sn_state], file)

    summary_writer.flush()

    # Final checkpoint after the last training step.
    with open(FLAGS.output_dir + '/model-final.pckl', 'wb') as file:
        pickle.dump([params, state, sn_state], file)
Esempio n. 26
0
def solve_sdp_dual(verif_instance,
                   key=None,
                   opt=None,
                   num_steps=10000,
                   verbose=False,
                   eval_every=1000,
                   use_exact_eig_eval=True,
                   use_exact_eig_train=False,
                   n_iter_lanczos=30,
                   scl=-1.0,
                   lr_init=1e-3,
                   steps_per_anneal=100,
                   anneal_factor=1.0,
                   num_anneals=3,
                   opt_name='adam',
                   gd_momentum=0.9,
                   add_diagnostic_stats=False,
                   opt_multiplier_fn=None,
                   init_dual_vars=None,
                   init_opt_state=None,
                   opt_dual_vars=None,
                   kappa_reg_weight=None,
                   kappa_zero_after=None,
                   device_type=None,
                   save_best_k=1,
                   include_opt_state=False):
    # pylint: disable=g-doc-return-or-yield, g-doc-args
    """Compute verified lower bound via dual of SDP relaxation.

    NOTE: This method exposes many hyperparameter options, and the method
    signature is subject to change. We suggest using
    ``solve_sdp_dual_simple`` instead if you need a stable interface.

    Args:
      verif_instance: a `utils.SdpDualVerifInstance`; its `dual_shapes` and
        `dual_types` attributes are read here.
      key: PRNG key forwarded to `dual_fun`. Defaults to `PRNGKey(0)`.
      opt: optimizer exposing `init` / `update` (optax style). When None, one
        is built from `opt_name`, `gd_momentum` and the learning-rate
        arguments below.
      num_steps: number of optimization steps.
      verbose: if True, print the train and eval loss at every evaluation.
      eval_every: evaluate the loss (with `loss_scl=-1`) every this many steps.
      use_exact_eig_eval: `exact` flag passed to `dual_fun` for evaluation.
      use_exact_eig_train: `exact` flag passed to `dual_fun` for training.
      n_iter_lanczos: passed to `dual_fun` as `n_iter`.
      scl: passed to `dual_fun` as `scl` for the training loss.
      lr_init: initial learning rate of the built-in schedule.
      steps_per_anneal: either a number (anneal boundaries are its multiples)
        or a sequence of interval lengths (boundaries are its cumulative sum).
      anneal_factor: the learning rate is multiplied by this factor once per
        anneal boundary passed, at most `num_anneals` times.
      num_anneals: maximum number of anneals.
      opt_name: name of the optimizer constructor looked up on `optax`.
      gd_momentum: momentum, only used when `opt_name == 'sgd'`.
      add_diagnostic_stats: if True, record extra per-step diagnostics in the
        returned `info` (dual vars, gradients, updates, eigenvector
        similarities, ...).
      opt_multiplier_fn: optional function of a dual-variable path returning a
        per-variable multiplier applied to the updates.
      init_dual_vars: optional initial dual variables (replaces the default
        initialization from `init_duals_ibp`).
      init_opt_state: optional initial optimizer state.
      opt_dual_vars: optional reference dual variables; when given, the
        per-variable distance to them is logged at every step.
      kappa_reg_weight: if not None and >= 0, this value is subtracted at every
        step from every entry of the last dual variable except index 0 along
        its second axis.
      kappa_zero_after: if not None and >= 0, after this step index those same
        entries are zeroed at every step.
      device_type: one of None, 'cpu', 'gpu'; used as the `jax.jit` backend.
      save_best_k: number of lowest-loss `(loss, dual_vars)` pairs to keep.
      include_opt_state: if True, also return the final optimizer state.

    Returns:
      `(final_loss, info)`, or `(final_loss, info, opt_state)` when
      `include_opt_state` is True. `info` is a `defaultdict(list)` holding the
      evaluation history ('steps', 'loss_vals'), 'loss_log', 'store_best',
      'final_dual_vars', 'final_opt_state', 'final_loss' and any diagnostics.
    """
    # NB: Whereas the rest of the code in this library is fairly top-down
    # readable, avoids excessive `if` statements, tries to make the code look
    # like the formalism, etc, this is not the case for this method.
    # This is essentially the outer loop, and includes all the debugging/logging/
    # optimization tricks we need to get/debug good results.
    #
    # NB: Time profiling: On toy VerifInstances, JIT compilation dominates time
    # cost: JIT compilation takes ~12s, then we do ~3000 steps/sec.
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
    assert isinstance(verif_instance,
                      utils.SdpDualVerifInstance), 'invalid type'
    key = key if key is not None else jax.random.PRNGKey(0)

    dual_vars = jax.tree_map(lambda s: None if s is None else jnp.zeros(s),
                             verif_instance.dual_shapes)
    dual_vars = init_duals_ibp(verif_instance, dual_vars)

    if init_dual_vars is not None:
        # Casting, here for Colab. Essentially same as `dual_vars = init_dual_vars`
        dual_vars = utils.structure_like(init_dual_vars, dual_vars)
    if opt_dual_vars is not None:
        opt_dual_vars = utils.structure_like(opt_dual_vars, dual_vars)

    # Create optimizer
    if opt is None:
        if isinstance(steps_per_anneal, (float, int)):
            anneal_steps = [
                steps_per_anneal * (i + 1) for i in range(num_anneals)
            ]
        else:
            anneal_steps = np.cumsum(steps_per_anneal)
        anneal_steps = jnp.array(anneal_steps)

        def lr_schedule(t):
            # Number of anneal boundaries already passed, capped at num_anneals.
            cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
            return lr_init * jnp.float_power(anneal_factor, cur_epoch)

        opt_class = getattr(optax, opt_name)
        # The base optimizer uses a unit step size; the actual learning rate
        # is applied by the schedule chained after it.
        base_opt = (opt_class(1., momentum=gd_momentum)
                    if opt_name == 'sgd' else opt_class(1.))
        if opt_multiplier_fn:
            # NB: Interface very specific to tree.map_structure_with_path
            # Example: opt_multiplier_fn=lambda path: 0.1 if 'lam' in path else 1.0
            opt_multipliers = tree.map_structure_with_path(
                lambda path, v: opt_multiplier_fn(path), dual_vars)
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule),
                              utils.scale_by_variable_opt(opt_multipliers))
        else:
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))

    # Define loss function
    def loss(dual_vars, loss_scl=scl, exact=use_exact_eig_train):
        return _loss(dual_vars, loss_scl, exact)

    # `loss_scl` and `exact` are static so each combination compiles once.
    @functools.partial(jax.jit, static_argnums=(1, 2), backend=device_type)
    def _loss(dual_var, loss_scl, exact):
        loss_val, step_info = dual_fun(verif_instance,
                                       dual_var,
                                       key,
                                       n_iter=n_iter_lanczos,
                                       exact=exact,
                                       scl=loss_scl,
                                       include_info=True)
        step_info['loss_val'] = loss_val
        return loss_val, step_info

    # Define a compiled update step
    grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

    @functools.partial(jax.jit, backend=device_type)
    def grad_step(params, opt_state):
        g, info = grad(params)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        info['g'] = g
        info['updates'] = updates
        return new_params, new_opt_state, info

    # Optimize parameters in a loop
    opt_state = opt.init(dual_vars)
    if init_opt_state:
        opt_state = utils.structure_like(init_opt_state, opt_state)
    info = collections.defaultdict(list)
    loss_log = []
    store_best = []
    recent_eig_vecs = collections.deque(maxlen=10)
    last_H = None

    # Main loop
    for i in range(num_steps):
        # `loss_val` returned by grad_step is evaluated at the parameters
        # *before* the update, so it is paired with `dual_vars_prev` below.
        dual_vars_prev = dual_vars
        dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
        loss_val = step_info['loss_val']
        print(f'Iter {i}: Loss {loss_val}')
        if add_diagnostic_stats:
            info['dual_vars'].append(dual_vars_prev)
            eig_vec = step_info['eig_vec']
            cosine_sims = []
            for prev_eig_vec in recent_eig_vecs:
                # Cosine similarity: dot product over the product of norms.
                denom = (jnp.linalg.norm(eig_vec) *
                         jnp.linalg.norm(prev_eig_vec))
                eig_sim = jnp.sum(prev_eig_vec * eig_vec) / denom
                cosine_sims.append(abs(float(eig_sim)))
            info['c_lambda'].append(float(step_info['c_lambda']))
            info['past_10_cosine_sims'].append(np.array(cosine_sims))
            info['g'].append(step_info['g'])
            info['updates'].append(step_info['updates'])
            if use_exact_eig_train:
                # The info is for -H, so to get smallest for H, take negative of max
                eig_vals = -step_info['eig_info'][0][-1:-20:-1]
                cur_H = step_info['eig_info'][2]
                diff_H = 0 if last_H is None else np.linalg.norm(cur_H -
                                                                 last_H)
                last_H = cur_H
                info['diff_H'].append(float(diff_H))
                info['smallest_20_eig_vals'].append(eig_vals)
            recent_eig_vecs.appendleft(eig_vec)

        loss_log.append(loss_val)
        # `store_best` is kept sorted by ascending loss, so its last entry is
        # the worst of the kept ones and is the one to be replaced.
        if len(store_best) < save_best_k:
            store_best.append((loss_val, dual_vars_prev))
            store_best.sort(key=lambda x: x[0])
        elif store_best and loss_val < store_best[-1][0]:
            # The `store_best` guard keeps save_best_k == 0 from indexing an
            # empty list.
            store_best[-1] = (loss_val, dual_vars_prev)
            store_best.sort(key=lambda x: x[0])

        # Regularization of kappa
        if kappa_reg_weight is not None and kappa_reg_weight >= 0:
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            mask = jnp.ones_like(onehot) - onehot
            dual_vars[-1] -= mask * kappa_reg_weight
        if (kappa_zero_after is not None and kappa_zero_after >= 0
                and i > kappa_zero_after):
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            dual_vars[-1] *= onehot

        dual_vars = project_duals(dual_vars, verif_instance.dual_types)

        if opt_dual_vars:
            distance_to_opt = jax.tree_multimap(
                lambda x, y: jnp.linalg.norm(x - y), dual_vars, opt_dual_vars)
            info['distance_to_opt'].append(distance_to_opt)

        if i % eval_every == 0:
            dual_val, _ = loss(dual_vars,
                               loss_scl=-1,
                               exact=use_exact_eig_eval)
            info['steps'].append(i)
            info['loss_vals'].append(float(dual_val))
            if verbose:
                print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

    final_loss = float(
        loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)[0])
    info['final_dual_vars'] = dual_vars
    info['final_opt_state'] = opt_state
    info['final_loss'] = final_loss
    info['loss_log'] = loss_log
    info['store_best'] = store_best
    if include_opt_state:
        return final_loss, info, opt_state
    else:
        return final_loss, info
    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(optax.scale(1 / gradient_accumulation_steps),
                      clip_by_global_norm(1), optax.scale_by_adam(),
                      optax.additive_weight_decay(0), optax.scale(-1),
                      optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)))

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    ckpt_step = meta["checkpoints"][-1]
    print(f"using checkpoint {ckpt_step}")
Esempio n. 28
0
def scale_by_learning_rate(learning_rate: ScalarOrSchedule):
    """Return a transformation that scales updates by the negated learning rate.

    A callable `learning_rate` is treated as a schedule of the step count and
    its value is negated at every step; any other value is negated once and
    used as a fixed scale factor.
    """
    if not callable(learning_rate):
        return optax.scale(-learning_rate)

    def negated_schedule(count):
        return -learning_rate(count)

    return optax.scale_by_schedule(negated_schedule)
Esempio n. 29
0
def create_optimizer(config):
    """Builds the chained optax transformation described by `config`.

    The chain is assembled in this order: optional gradient clipping, the
    base optimizer together with the learning-rate schedule (the schedule
    comes after Adam but before SGD), optional decoupled weight decay, and
    finally one freezing transformation per configured regex.

    Raises:
      ValueError: if the deprecated 'gradient_clip' field is present.
      NotImplementedError: if `config.optimizer` is neither "adam" nor "sgd".
    """
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")

    # Clipping is either by global norm or by absolute value, never both.
    clip_by_norm = "gradient_norm_clip" in config
    clip_by_value = "gradient_value_clip" in config
    assert not (clip_by_norm and clip_by_value), (
        "Gradient clipping by norm and by value are exclusive.")

    transforms = []
    if clip_by_norm:
        transforms.append(
            optax.clip_by_global_norm(config.gradient_norm_clip))
    if clip_by_value:
        transforms.append(optax.clip(config.gradient_value_clip))

    # Learning-rate schedule, optionally followed by per-parameter scaling.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))
    lr_transforms = [optax.scale_by_schedule(schedule_fn)]

    # `scaling_learning_rate_by_regex` is a list of (regex, multiplier) pairs.
    for regex, multiplier in config.get("scaling_learning_rate_by_regex", []):
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        lr_transforms.append(
            utils.scale_selected_parameters(regex, multiplier))
    schedule_optimizer = optax.chain(*lr_transforms)

    optimizer_name = config.optimizer.lower()
    if optimizer_name == "adam":
        transforms.extend(
            [optax.adam(config.learning_rate), schedule_optimizer])
    elif optimizer_name == "sgd":
        transforms.extend([
            schedule_optimizer,
            optax.sgd(config.learning_rate, momentum=config.momentum),
        ])
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        transforms.append(
            utils.decoupled_weight_decay(decay=config.weight_decay,
                                         step_size_fn=schedule_fn))

    # A single regex string is accepted as shorthand for a one-element list;
    # a missing or None field means nothing is frozen.
    freeze_patterns = config.get("freeze_weights_regex", []) or []
    if isinstance(freeze_patterns, str):
        freeze_patterns = [freeze_patterns]
    transforms.extend(utils.freeze(pattern) for pattern in freeze_patterns)

    return optax.chain(*transforms)
Esempio n. 30
0
 def update(
         self, gradient: Weights, state: GenericGradientState,
         parameters: Optional[Weights]
 ) -> Tuple[Weights, GenericGradientState]:
     """Apply the schedule-based scaling to `gradient`.

     Builds a `scale_by_schedule` transformation from `self.step_size_fn`,
     runs its `update` on the gradient with the unwrapped `state.data`, and
     passes the unpacked result to `GenericGradientState.wrap`.
     """
     transformation = scale_by_schedule(self.step_size_fn)
     update_result = transformation.update(gradient, state.data, parameters)
     return GenericGradientState.wrap(*update_result)