Example No. 1
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: d4pg_networks.D4PGNetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        del environment_spec, replay_client

        policy_optimizer = optax.adam(self._config.learning_rate)
        critic_optimizer = optax.adam(self._config.learning_rate)

        if self._config.clipping:
            policy_optimizer = optax.chain(optax.clip_by_global_norm(40.),
                                           policy_optimizer)
            critic_optimizer = optax.chain(optax.clip_by_global_norm(40.),
                                           critic_optimizer)

        # The learner updates the parameters (and initializes them).
        return learning.D4PGLearner(
            policy_network=networks.policy_network,
            critic_network=networks.critic_network,
            random_key=random_key,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=self._config.clipping,
            discount=self._config.discount,
            target_update_period=self._config.target_update_period,
            iterator=dataset,
            counter=counter,
            logger=logger_fn('learner'),
            num_sgd_steps_per_step=self._config.num_sgd_steps_per_step)
Example No. 2
def make_optimizer(momentum=True, schedule_fn=lambda x: -1e-3):
    """SGD with momentum and a fixed lr.

    The default schedule returns a constant -1e-3; the minus sign is what
    makes the scaled update perform gradient descent.
    """
    if momentum:
        return optax.chain(
            optax.trace(decay=0.9, nesterov=False),  # momentum
            optax.scale_by_schedule(schedule_fn))
    else:
        return optax.chain(
            optax.scale_by_schedule(schedule_fn))
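
A minimal usage sketch of the helper above (assuming optax is imported and `params` and `grads` are existing parameter and gradient pytrees):

opt = make_optimizer(momentum=True)
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)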
Example No. 3
    def test_correctness(self):
        """Testing correctness via independent implementation."""
        def ema(decay, debias=True):
            def init_fn(params):
                del params
                return {'w': jnp.zeros((2, )), 'count': 0}

            def update_fn(updates, state, params=None):
                del params
                state['count'] += 1
                state['w'] = ((1 - decay) * updates['w'] + decay * state['w'])
                if debias:
                    update = {'w': state['w'] / (1 - decay**state['count'])}
                else:
                    update = {'w': state['w']}
                return update, state

            return optax.GradientTransformation(init_fn, update_fn)

        decay = 0.7
        learning_rate = 0.01
        true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate))
        ks_ema = transform_chain(
            ['first_moment_ema'],
            [{'decay': decay, 'debias': True}],
            learning_rate=learning_rate)
        targets = _optimizer_loop(true_ema)
        results = _optimizer_loop(ks_ema)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
Example No. 4
  def make_learner(
      self,
      random_key: networks_lib.PRNGKey,
      networks: impala_networks.IMPALANetworks,
      dataset: Iterator[reverb.ReplaySample],
      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
      replay_client: Optional[reverb.Client] = None,
      counter: Optional[counting.Counter] = None,
  ) -> core.Learner:
    del environment_spec, replay_client

    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(
            self._config.learning_rate,
            b1=self._config.adam_momentum_decay,
            b2=self._config.adam_variance_decay,
            eps=self._config.adam_eps,
            eps_root=self._config.adam_eps_root))

    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
        counter=counter,
        logger=logger_fn('learner'),
    )
Example No. 5
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: td3_networks.TD3Networks,
        dataset: Iterator[reverb.ReplaySample],
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:

        critic_optimizer = optax.adam(self._config.critic_learning_rate)
        twin_critic_optimizer = optax.adam(self._config.critic_learning_rate)
        policy_optimizer = optax.adam(self._config.policy_learning_rate)

        if self._config.policy_gradient_clipping is not None:
            policy_optimizer = optax.chain(
                optax.clip_by_global_norm(
                    self._config.policy_gradient_clipping), policy_optimizer)

        return learning.TD3Learner(
            networks=networks,
            random_key=random_key,
            discount=self._config.discount,
            target_sigma=self._config.target_sigma,
            noise_clip=self._config.noise_clip,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            twin_critic_optimizer=twin_critic_optimizer,
            num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
            bc_alpha=self._config.bc_alpha,
            iterator=dataset,
            logger=self._logger_fn(),
            counter=counter)
Example No. 6
    def __init__(self, model: SwarmModel,
                 optimizer: optax.GradientTransformation, loss_scale: float,
                 dataloader: Callable, precision: NetworkPrecision):
        self.model = model
        self.optimizer = optax.chain(optax.scale(1 / loss_scale), optimizer)
        self.dataloader = dataloader
        self.minibatches = 1
        self.loss_scale = loss_scale

        assert ray.is_initialized()  # needs a valid ray cluster to start

        example = self.dataloader()
        self.embedding = EmbeddingLayer.options(max_concurrency=8).remote(
            example["obs"], self.model.vocab, self.model.d_model,
            self.optimizer, precision)
        self.embedding.run.remote()

        x, _ = self.embedding.embed_forward.remote(example["obs"])

        self.proj = ProjLayer.options(max_concurrency=8).remote(
            x, self.model.vocab, self.model.d_model, self.optimizer,
            self.loss_scale, precision)
        self.proj.run.remote()

        self.layers = []
        for i in range(model.rev_layers):
            self.layers.append(
                ReversibleLayer.options(max_concurrency=8).remote(
                    self.model.rev_init, i, x, self.optimizer, precision))

        for layer in self.layers:
            layer.run.remote()

        self.all_layers = [self.embedding] + self.layers + [self.proj]
Example No. 7
    def build_optimizer(self,
                        clip=15.0,
                        lr=5e-4,
                        warmup=2000,
                        cosine_decay_steps=None,
                        optimizer_name="adabelief") -> GradientTransformation:
        chain = []
        if optimizer_name == "adabelief":
            chain.append(util.scale_by_belief())
        elif optimizer_name == "adam":
            chain.append(optax.scale_by_adam())
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Make sure to use the negative learning rate so that we minimize
        if warmup and warmup > 0:
            warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                      warmup=warmup,
                                      lr_decay=1.0,
                                      lr=-lr)
            chain.append(optax.scale_by_schedule(warmup_schedule))
        else:
            chain.append(optax.scale(-lr))

        if cosine_decay_steps and cosine_decay_steps > 0:
            cosine_lr = optax.cosine_decay_schedule(
                init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
            chain.append(optax.scale_by_schedule(cosine_lr))

        # Clipping is appended last, so it bounds the final (already scaled)
        # update element-wise rather than the raw gradients.
        if clip and clip > 0:
            chain.append(optax.clip(clip))

        return optax.chain(*chain)
Example No. 8
def main():
    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                  dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value),
                            optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
Example No. 9
def sgd_momentum(learning_rate_fn: optax.Schedule,
                 momentum: float = 0.,
                 nesterov: bool = False) -> optax.GradientTransformation:
  return optax.chain(
      optax.trace(decay=momentum, nesterov=nesterov),
      optax.scale_by_schedule(learning_rate_fn),
      optax.scale(-1.))
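
A minimal usage sketch of sgd_momentum (the schedule and values are illustrative; `params` and `grads` are assumed to be existing pytrees):

lr_schedule = optax.cosine_decay_schedule(init_value=0.1, decay_steps=10_000)
opt = sgd_momentum(lr_schedule, momentum=0.9, nesterov=True)
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)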
Example No. 10
def make_optimizer():
  """SGD with nesterov momentum and a custom lr schedule."""
  return optax.chain(
      optax.trace(
          decay=FLAGS.optimizer_momentum,
          nesterov=FLAGS.optimizer_use_nesterov),
      optax.scale_by_schedule(lr_schedule), optax.scale(-1))
Example No. 11
def test_batch_overfit(train_dataset):
    vocab_size, d_model, num_heads, num_layers = 100, 32, 8, 1
    dropout_rate, grad_clip_value, learning_rate = 0.01, 0.25, 2e-2
    max_iter = 100

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                  dropout_rate, max_iter)

    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value),
                            optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    for step in range(100):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

    assert metrics['loss'] < 0.1
    assert metrics['step'] == 99
Example No. 12
def optimizer(hyperparameters):
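    # optax.chain returns a GradientTransformation, a NamedTuple of (init,
    # update) functions, so it can be unpacked directly into the two callables.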
    opt_init_fn, opt_update_fn = optax.chain(
        optax.scale_by_adam(b1=1.0 - hyperparameters.one_minus_beta_1,
                            b2=0.999,
                            eps=hyperparameters.epsilon),
        optax.scale(-hyperparameters.learning_rate))
    return opt_init_fn, opt_update_fn
Example No. 13
  def test_graph_conditioned_transformer_learns(self):
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
      return models.Graph2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=embed_dim,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
          gnn_embed_dim=embed_dim,
          gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
      input_mask = (labels != 0).astype(jnp.float32)
      return model_fn(vocab_size, embed_dim).loss(
          graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
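    # Positional argument 6 is max_graph_size, which is marked static for jit.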
    apply = jax.jit(apply, static_argnums=6)

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(500):
      (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
          params, state, next(rng), graphs, x, y, max_graph_size)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info(
            'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
Example No. 14
def _run_attack(
    max_objective_fn: Callable[[Tensor, PRNGKey], Tensor],
    projection_fn: Callable[[Tensor], Tensor],
    x_init: Tensor,
    prng_key: PRNGKey,
    num_steps: int,
    learning_rate: float,
):
    """Run attack."""

    opt = optax.chain(
        optax.scale(-1),  # maximization
        optax.adam(learning_rate))
    grad_fn = jax.grad(max_objective_fn)

    def body_fn(it, inputs):
        del it  # unused
        x, prng_in, opt_state = inputs
        prng_out, prng_used = jax.random.split(prng_in)
        grad_x = grad_fn(x, prng_used)
        updates, opt_state = opt.update(grad_x, opt_state, x)
        x = optax.apply_updates(x, updates)
        x = projection_fn(x)
        return x, prng_out, opt_state

    opt_state = opt.init(x_init)
    init_state = (x_init, prng_key, opt_state)
    x, prng_final, _ = jax.lax.fori_loop(0, num_steps, body_fn, init_state)

    return max_objective_fn(x, prng_final)
Example No. 15
def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(FLAGS.grad_clip_value),
                            optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
Example No. 16
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = reverb_replay.server

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.adam(config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(
            jax.random.PRNGKey(config.seed))
        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: networks_lib.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(config.epsilon).sample(
                key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=max(config.batch_size, config.min_replay_size),
            observations_per_step=config.batch_size /
            config.samples_per_insert,
        )
Example No. 17
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    # rng

    rng_key = random.PRNGKey(seed)

    # data

    dataset = PairTextSpectrogramDataset(data_folder)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    collate_fn=pair_text_spectrogram_dataset_collate_fn,
                    drop_last=True,
                    shuffle=True)

    # model

    model = CLAP(text_vocab=text_vocab,
                 text_dim=text_dim,
                 text_depth=text_depth,
                 text_heads=text_heads,
                 audio_dim=audio_dim,
                 audio_depth=audio_depth,
                 audio_heads=audio_heads)

    # optimizer

    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1,
                                                     params)
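    # The mask above is True for parameters with ndim != 1, so weight decay is
    # applied to weight matrices but skips biases and other 1-D parameters.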

    optim = chain(clip_by_global_norm(max_norm), scale_by_adam(eps=1e-4),
                  add_decayed_weights(weight_decay, exclude_bias),
                  scale(-learning_rate))

    # init

    audio, audio_mask, text, text_mask = next(iter(dl))

    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad

    @jit
    @value_and_grad
    def loss_fn(params, text, audio, text_mask, audio_mask):
        return model.apply(params, text, audio, text_mask, audio_mask)

    # train loop

    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in dl:
            loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')
Example No. 18
def create_train_state(rng, config: ml_collections.ConfigDict, model):
    """Create initial training state."""
    params = get_initial_params(rng, model)
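    # Note: `optax.additive_weight_decay` is the older name for what newer
    # optax versions expose as `optax.add_decayed_weights`.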
    tx = optax.chain(
        optax.sgd(learning_rate=config.learning_rate,
                  momentum=config.momentum),
        optax.additive_weight_decay(weight_decay=config.weight_decay))
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state
Example No. 19
def kitchen_sink(chains: List[optax.GradientTransformation],
                 scales: jnp.ndarray = None,
                 combinator: Union[Callable[[Any, Any], Any], str] = 'sum',
                 combinator_args: Dict[str, float] = None,
                 learning_rate: float = None) -> optax.GradientTransformation:
  """Runs a list of GradientTransformations in parallel and combines them.

  Args:
    chains: list of optax.GradientTransformation objects (typically from
      transform_chain).
    scales: a (len(chains),)-shaped jnp.ndarray of per-chain scales.
    combinator: a combinator that reduces a list of identical pytrees.
    combinator_args: a dictionary of keyword arguments to the combinator func.
    learning_rate: learning rate that gets injected.

  Returns:
    optax.GradientTransformation
  """
  if isinstance(combinator, str):
    combinator = _combinators.get(combinator, _sum_combinator)
  combinator_args = combinator_args or {}

  if scales is None:
    scales = jnp.ones(len(chains))

  chains = [
      optax.chain(chain, optax.scale(scale))
      for chain, scale in zip(chains, scales)
  ]

  def init_fn(params):
    return [chain.init(params) for chain in chains]

  def update_fn(updates, state, params=None):
    result = [chain.update(updates, chain_state, params)
              for chain, chain_state in zip(chains, state)]
    new_updates, new_state = list(zip(*result))
    return combinator(*new_updates, **combinator_args), new_state

  transform = optax.GradientTransformation(init_fn, update_fn)

  if learning_rate is not None:
    transform = optax.chain(transform, scale_by_learning_rate(learning_rate))

  return transform
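
A minimal usage sketch of kitchen_sink (decays and scales are illustrative; `params` and `grads` are assumed pytrees, and 'first_moment_ema' is the transform name used elsewhere in these examples):

chains = [
    transform_chain(['first_moment_ema'], [{'decay': 0.9, 'debias': True}]),
    transform_chain(['first_moment_ema'], [{'decay': 0.99, 'debias': True}]),
]
opt = kitchen_sink(chains, scales=jnp.array([0.5, 0.5]), learning_rate=1e-3)
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)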
Example No. 20
    def __init__(
        self,
        *optimizer: optax.GradientTransformation,
        lr_schedule: tp.Optional[LRScheduler] = None,
        steps_per_epoch: tp.Union[int, jnp.ndarray, np.ndarray, None] = None,
        **kwargs,
    ):
        r"""
        Arguments:
            optimizer: An optax `GradientTransformation` object, if more than one is passed via `*args` then they are
                grouped using `optax.chain`.
            lr_schedule: A optional callable of the form `def lr_schedule(step: int, epoch: Optional[int]) -> float` that
                returns the learning rate schedule at each time step. If `steps_per_epoch` is given then epoch is calculated,
                else epoch is None.
            steps_per_epoch: The number of steps to in an epoch, needed to caculate `epoch` from `step`.
        """

        if len(optimizer) == 0:
            raise ValueError("Must pass at least 1 optimizer, got 0")

        elif lr_schedule is not None:
            # Keep a reference to the original schedule before `lr_schedule`
            # is re-assigned below.
            base_schedule = lr_schedule

            def lr_schedule_(step: jnp.ndarray) -> jnp.ndarray:
                epoch: tp.Any = (step // steps_per_epoch
                                 if steps_per_epoch is not None else None)

                return base_schedule(step, epoch)

            optimizer = optax.chain(
                *optimizer,
                optax.scale_by_schedule(lr_schedule_),
            )

            lr_schedule = lr_schedule_

        elif len(optimizer) == 1:
            optimizer = optimizer[0]
        else:
            optimizer = optax.chain(*optimizer)

        self.optimizer = optimizer
        self.lr_schedule = lr_schedule
Example No. 21
  def test_regularized_training(self):
    """Test that adding regularization penalty to the training loss works."""
    np.random.seed(0)
    # Set up the problem of recovering w given x and
    #   y = x . w + noise
    # with the a priori assumption that w is sparse. There are fewer examples
    # than dimensions (x is a wide matrix), so the problem is underdetermined
    # without the sparsity assumption.
    num_examples, num_dim = 8, 10
    x = np.random.randn(num_examples, num_dim).astype(np.float32)
    true_w = np.zeros((num_dim, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

    # Get the least squares estimate for w. It isn't very accurate.
    least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Get a better estimate by solving the L1 regularized problem
    #  argmin_w ||x . w - y||_2^2 + c ||w||_1.
    w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
    def model_fun(batch):
      x = batch['x']
      return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

    model = hk_util.transform(model_fun)

    def loss_fun(params, batch):
      """Training loss with L1 regularization penalty term."""
      y_predicted, penalties = model.apply(params, None, batch)
      return hk_util.l2_loss(y_predicted - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)
    optimizer = optax.chain(  # Gradient descent with decreasing learning rate.
        optax.trace(decay=0.0, nesterov=False),
        optax.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
      grads = jax.grad(loss_fun)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state

    for _ in range(1000):
      params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
Example No. 22
def transform_chain(
    elements: List[str],
    hps: List[Dict[str, float]] = None,
    masks: List[Any] = None,
    learning_rate: float = None) -> optax.GradientTransformation:
  """Utility function for chaining GradientTransforms based on string names.

  Args:
    elements: list of transform strings.
    hps: list of dicts of args for each transform.
    masks: list of masks for each transform.
    learning_rate: learning rate that gets injected.

  Returns:
    optax.GradientTransform
  """

  hps = hps or [{}] * len(elements)
  masks = masks or [None] * len(elements)

  if len(hps) != len(elements):
    raise ValueError('Number of hps must equal number of elements.')

  if len(masks) != len(elements):
    raise ValueError('Number of masks must equal number of elements.')

  transforms = [_transformations[el](**hp) for el, hp in zip(elements, hps)]

  for i, (transform, mask) in enumerate(zip(transforms, masks)):
    if mask is not None:
      transforms[i] = optax.masked(transform, mask)

  if learning_rate is not None:
    transforms += [scale_by_learning_rate(learning_rate)]

  init_fn, update_fn = optax.chain(*transforms)

  # NOTE(dsuo): We use plain dicts internally due to this issue
  # https://github.com/deepmind/optax/issues/160.
  def wrapped_init_fn(params):
    return init_fn(flax.core.unfreeze(params))

  def wrapped_update_fn(updates, state, params=None):
    new_updates, state = update_fn(
        flax.core.unfreeze(updates), state,
        None if params is None else flax.core.unfreeze(params))

    if isinstance(updates, flax.core.FrozenDict):
      new_updates = flax.core.freeze(new_updates)

    return new_updates, state

  return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn)
Example No. 23
    def get(self) -> optax.GradientTransformation:
        if "adam" in self.optimizer:
            opt = optax.adam(self.base_learning_rate)
        elif "sgd" == self.optimizer and self.lr_schedule == "linear":
            lr_schedule = warm_up_polynomial_schedule(
                base_learning_rate=self.base_learning_rate,
                end_learning_rate=self.final_decay_factor *
                self.base_learning_rate,
                decay_steps=(self.n_batches *
                             (self.epochs - self.lr_warmup_epochs)),
                warmup_steps=self.n_batches * self.lr_warmup_epochs,
                decay_power=1.0,
            )
            momentum = 1 - self.one_minus_momentum
            opt = optax.chain(
                optax.trace(decay=momentum, nesterov=True),
                optax.scale_by_schedule(lr_schedule),
                optax.scale(-1),
            )
        elif "sgd" in self.optimizer and self.lr_schedule == "step":
            lr_decay_epochs = [
                (int(start_epoch_str) * self.epochs) // DEFAULT_NUM_EPOCHS
                for start_epoch_str in self.lr_decay_epochs
            ]
            lr_schedule = warm_up_piecewise_constant_schedule(
                steps_per_epoch=self.n_batches,
                base_learning_rate=self.base_learning_rate,
                decay_ratio=self.lr_decay_ratio,
                decay_epochs=lr_decay_epochs,
                warmup_epochs=self.lr_warmup_epochs,
            )

            momentum = 1 - self.one_minus_momentum
            opt = optax.chain(
                optax.trace(decay=momentum, nesterov=True),
                optax.scale_by_schedule(lr_schedule),
                optax.scale(-1),
            )
        else:
            raise ValueError("No optimizer specified.")
        return opt
Example No. 24
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: ppo_networks.PPONetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        del environment_spec, replay_client

        if callable(self._config.learning_rate):
            optimizer = optax.chain(
                optax.clip_by_global_norm(self._config.max_gradient_norm),
                optax.scale_by_adam(eps=self._config.adam_epsilon),
                optax.scale_by_schedule(self._config.learning_rate),
                optax.scale(-1))
        else:
            optimizer = optax.chain(
                optax.clip_by_global_norm(self._config.max_gradient_norm),
                optax.scale_by_adam(eps=self._config.adam_epsilon),
                optax.scale(-self._config.learning_rate))

        return learning.PPOLearner(
            ppo_networks=networks,
            iterator=dataset,
            discount=self._config.discount,
            entropy_cost=self._config.entropy_cost,
            value_cost=self._config.value_cost,
            max_abs_reward=self._config.max_abs_reward,
            ppo_clipping_epsilon=self._config.ppo_clipping_epsilon,
            clip_value=self._config.clip_value,
            gae_lambda=self._config.gae_lambda,
            counter=counter,
            random_key=random_key,
            optimizer=optimizer,
            num_epochs=self._config.num_epochs,
            num_minibatches=self._config.num_minibatches,
            logger=logger_fn('learner'),
        )
Example No. 25
    def _create_jax_optimizer(self):
        import optax
        process = []
        if isinstance(self.learning_rate, LearningRateSchedule):
            scheduler = self.learning_rate._create_jax_schedule()
            process.append(optax.scale_by_schedule(scheduler))
            last_process = optax.scale(-1.0)
        else:
            lr = self.learning_rate
            last_process = optax.scale(-1.0 * lr)
        process.append(last_process)
        return optax.chain(*process)
Example No. 26
def main(argv):
    del argv

    learning_rate = 1e-2
    batch_size = 64
    input_size = 8
    n_training_steps = 100

    # Random number generator sequence.
    key_seq = hk.PRNGSequence(1729)

    # A simple Linear function.
    def forward_pass(x):
        return hk.Linear(10)(x)

    network = hk.without_apply_rng(hk.transform(forward_pass))

    # Some arbitrary loss.
    def mean_square_loss(params, x):
        output = network.apply(params, x)
        loss = jnp.sum(output**2)
        return loss

    # Construct a simple Adam optimiser using the transforms in optax.
    # You could also just use the `optax.adam` alias, but we show here how
    # to do so manually so that you may construct your own `custom` optimiser.
    opt_init, opt_update = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-learning_rate))

    # Initialise the model's parameters and the optimiser's state.
    # The `state` of an optimiser contains all statistics used by the
    # stateful transformations in the `chain` (in this case just `scale_by_adam`).
    params = network.init(next(key_seq), jnp.zeros([1, input_size]))
    opt_state = opt_init(params)

    # Minimise the loss.
    for step in range(n_training_steps):
        # Get input. Learn to minimize the input to 0.
        data = jax.random.normal(next(key_seq), [batch_size, input_size])
        # Compute gradient and loss.
        loss, grad = jax.value_and_grad(mean_square_loss)(params, data)
        print(f'Loss[{step}] = {loss}')
        # Transform the gradients using the optimiser.
        updates, opt_state = opt_update(grad, opt_state, params)
        # Update parameters.
        params = optax.apply_updates(params, updates)
Example No. 27
def create_train_state(config, rng, learning_rate_fn, example_batch):
    """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    learning_rate_fn: learning rate function
    example_batch: for model intialization

  Returns:
    The initialized TrainState with the optimizer.
  """
    model, variables = create_model(config, rng, example_batch)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)

    optimizer = optax.adamw(learning_rate=learning_rate_fn,
                            b1=0.9,
                            b2=.98,
                            eps=1e-9,
                            weight_decay=config.train.weight_decay)

    if config.train.grad_max_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(config.train.grad_max_norm),
                         optimizer)
    elif config.train.grad_max_val > 1:
        tx = optax.chain(optax.clip(config.train.grad_max_val), optimizer)
    else:
        tx = optimizer

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=tx,
    )
    return model, state
Example No. 28
    def optimizer(self, learning_rate):
        """Construct optimizer."""
        clip = optax.clip_by_global_norm(
            self.config.optimizer.gradient_clip_norm)
        optimizer = getattr(optax, self.config.optimizer.name)(
            learning_rate,
            **self.config.optimizer.args,
        )
        optim_step = optax.chain(clip, optimizer)
        optim_step = optimizers.maybe_skip_gradient_update(
            optim_step,
            self.config.optimizer.gradient_skip_norm,
        )

        return optim_step
Example No. 29
    def _create_jax_optimizer(self):
        import optax
        process = []
        if isinstance(self.learning_rate, LearningRateSchedule):
            lr = self.learning_rate.initial_rate
            last_process = optax.scale(-1.0)
        else:
            lr = self.learning_rate
            last_process = optax.scale(-1.0 * lr)

        process.append(
            optax.scale_by_rss(
                initial_accumulator_value=self.initial_accumulator_value,
                eps=self.epsilon))
        process.append(last_process)
        return optax.chain(*process)
Example No. 30
  def test_bow_transformer_learns(self):
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 2, 3, 0, 0],
                     [1, 2, 4, 5, 6, 0],
                     [1, 3, 3, 5, 4, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def model_fn():
      return models.Bow2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[])

    def loss_fn(bow, inputs, labels):
      mask = (labels != 0).astype(jnp.float32)
      return model_fn().loss(bow, inputs, labels, mask=mask)

    init_fn, apply_fn = hk.transform_with_state(loss_fn)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), bow, x, y)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    value_and_grad = jax.jit(jax.value_and_grad(apply, has_aux=True))

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(800):
      (loss, model_state), grad = value_and_grad(
          params, state, next(key), bow, x, y)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info('Step %d, %r', i + 1,
                     {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 0.1)