Example #1
def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optix.chain(optix.clip_by_global_norm(FLAGS.grad_clip_value),
                            optix.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
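The Updater and CheckpointingUpdater classes are defined elsewhere in the example. As a rough orientation, the step they wrap usually reduces to the standard optix pattern below (a minimal sketch, assuming loss_fn takes (params, rng, data) and ignoring the RNG and step bookkeeping the real Updater carries in its state):

import jax
from jax.experimental import optix

def update(loss_fn, optimizer, params, opt_state, rng, data):
    # One optimization step: gradient, optix transform, additive update.
    loss, grads = jax.value_and_grad(loss_fn)(params, rng, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state, {'loss': loss}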
Example #2
  def test_regularized_training(self):
    """Test that adding regularization penalty to the training loss works."""
    np.random.seed(0)
    # Set up the problem of recovering w given x and
    #   y = x . w + noise
    # with the a priori assumption that w is sparse. There are fewer examples
    # than dimensions (x is a wide matrix), so the problem is underdetermined
    # without the sparsity assumption.
    num_examples, num_dim = 8, 10
    x = np.random.randn(num_examples, num_dim).astype(np.float32)
    true_w = np.zeros((num_dim, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

    # Get the least squares estimate for w. It isn't very accurate.
    least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Get a better estimate by solving the L1 regularized problem
    #  argmin_w ||x . w - y||_2^2 + c ||w||_1.
    w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
    def model_fun(batch):
      x = batch['x']
      return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

    model = hk.transform(model_fun)

    def loss_fun(params, batch):
      """Training loss with L1 regularization penalty term."""
      y_predicted, penalties = model.apply(params, batch)
      return hk_util.l2_loss(y_predicted - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)
    optimizer = optix.chain(  # Gradient descent with decreasing learning rate.
        optix.trace(decay=0.0, nesterov=False),
        optix.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
      grads = jax.grad(loss_fun)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optix.apply_updates(params, updates)
      return new_params, opt_state

    for _ in range(1000):
      params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
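hk_util here is a project-local helper module whose Linear returns its regularization penalties alongside the output. If a module did not do that, the same L1 term could be computed directly over the parameter tree; a minimal sketch (the sum-of-absolute-values convention is an assumption, not necessarily what hk_util.l1_loss uses):

import jax.numpy as jnp
from jax import tree_util

def l1_penalty(params, weight=4.0):
    # Sum of absolute values over every leaf of the parameter tree.
    return weight * sum(jnp.sum(jnp.abs(p)) for p in tree_util.tree_leaves(params))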
Example #3
File: optix_test.py  Project: yxd886/jax
    def test_apply_every(self):
        # Apply the accumulated SGD update once every k steps.
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # experimental/optix.py sgd
        optix_sgd_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state_sgd = sgd.init(optix_sgd_params)

        # experimental/optix.py sgd apply every
        optix_sgd_apply_every_params = self.init_params
        sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                      optix.trace(decay=0, nesterov=False),
                                      optix.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optix_sgd_apply_every_params)
        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optix_sgd_params = optix.apply_updates(optix_sgd_params,
                                                   updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every)
            optix_sgd_apply_every_params = optix.apply_updates(
                optix_sgd_apply_every_params, updates_sgd_apply_every)
            if i % k == k - 1:
                # Check equivalence.
                for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                                tree_leaves(optix_sgd_params)):
                    np.testing.assert_allclose(x, y, atol=1e-6, rtol=100)
            else:
                # Check the update is zero.
                for x, y in zip(tree_leaves(updates_sgd_apply_every),
                                tree_leaves(zero_update)):
                    np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
Example #4
    def test_graph_network_learning(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        R_key, dr0_key, params_key = random.split(key, 3)

        d, _ = space.free()

        R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
        dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)
        E_gt = vmap(
          lambda R, dr0: \
          np.sum((space.distance(space.map_product(d)(R, R)) - dr0) ** 2))

        cutoff = 0.2

        init_fn, energy_fn = energy.graph_network(d, cutoff)
        params = init_fn(params_key, R[0])

        @jit
        def loss(params, R):
            return np.mean((vmap(energy_fn,
                                 (None, 0))(params, R) - E_gt(R, dr0))**2)

        opt = optix.chain(optix.clip_by_global_norm(1.0), optix.adam(1e-4))

        @jit
        def update(params, opt_state, R):
            updates, opt_state = opt.update(grad(loss)(params, R), opt_state)
            return optix.apply_updates(params, updates), opt_state

        opt_state = opt.init(params)

        l0 = loss(params, R)
        for _ in range(4):
            params, opt_state = update(params, opt_state, R)

        assert loss(params, R) < l0 * 0.95
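The chained optimizer first rescales the whole gradient tree whenever its global L2 norm exceeds 1.0, then applies Adam to the clipped gradients. The clipping semantics as a standalone sketch (not the library code):

import jax.numpy as jnp
from jax import tree_util

def clip_by_global_norm_sketch(grads, max_norm):
    # Global L2 norm across every leaf of the gradient tree.
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in tree_util.tree_leaves(grads)))
    scale = jnp.minimum(1.0, max_norm / global_norm)
    return tree_util.tree_map(lambda g: g * scale, grads)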
Example #5
def make_optimizer():
    """SGD with nesterov momentum and a custom lr schedule."""
    return optix.chain(
        optix.trace(decay=FLAGS.optimizer_momentum,
                    nesterov=FLAGS.optimizer_use_nesterov),
        optix.scale_by_schedule(lr_schedule), optix.scale(-1))
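lr_schedule and the FLAGS are defined by the surrounding training script. scale_by_schedule multiplies the momentum-filtered updates by schedule(step), and the trailing scale(-1) flips them into descent steps. A hypothetical schedule with the expected signature (step index in, learning rate out) might look like:

import jax.numpy as jnp

def lr_schedule(step):
    # Hypothetical: linear warmup for 1000 steps, then inverse-sqrt decay.
    warmup = 1000
    return 0.1 * jnp.minimum((step + 1) / warmup, jnp.sqrt(warmup / (step + 1)))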
Example #6
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Remove batch dimensions.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
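hk.PRNGSequence(seed) gives the learner and the actor a shared stream of PRNG keys; each next() call yields a fresh, independent key:

import haiku as hk

rng = hk.PRNGSequence(0)
key_a = next(rng)  # a fresh JAX PRNG key
key_b = next(rng)  # independent of key_a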
Example #7
def make_optimizer(lr_schedule, momentum_decay):
  return optix.chain(optix.trace(decay=momentum_decay, nesterov=False),
                     optix.scale_by_schedule(lr_schedule),
                     optix.scale(-1))
Example #8
        # images=train_images,
        # labels=train_labels,
        tasks=tasks,
        num_tasks=num_tasks_per_step,
        num_samples=num_inner_samples,
        shuffle=True,
    )

    outer_loop_sampler = partial(
        random_samples,
        # images=flatten(train_images, 1),
        # labels=flatten(train_labels, 1),
        num_samples=num_outer_samples,
    )

    inner_opt = optix.chain(optix.sgd(args.inner_lr))
    inner_loop_fn = make_inner_loop_fn(loss_acc_fn, inner_opt.update)
    outer_loop_loss_fn = make_outer_loop_loss_fn(loss_acc_fn, inner_opt.init,
                                                 inner_loop_fn)

    rng, rng_net = split(rng)
    out_shape, params = net_init(rng_net, (-1, size, size, 1))

    rln_params, pln_params = (
        params[:args.num_rln_layers],
        params[args.num_rln_layers:],
    )

    outer_opt_init, outer_opt_update, outer_get_params = optimizers.adam(
        step_size=args.outer_lr)
    outer_opt_state = outer_opt_init((rln_params, pln_params))
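This excerpt begins mid-statement (the construction of the inner-loop sampler is cut off above) and relies on helpers from the surrounding meta-learning script. As a rough sketch of the kind of inner loop make_inner_loop_fn might build from loss_acc_fn and inner_opt.update (the names and the (loss, accuracy) return convention are assumptions):

import jax
from jax.experimental import optix

def make_inner_loop_fn_sketch(loss_acc_fn, opt_update, num_steps=5):
    def inner_loop(params, opt_state, batch):
        # Take a few SGD steps on the task batch, returning adapted params.
        for _ in range(num_steps):
            grads = jax.grad(lambda p: loss_acc_fn(p, batch)[0])(params)
            updates, opt_state = opt_update(grads, opt_state)
            params = optix.apply_updates(params, updates)
        return params
    return inner_loop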
Example #9
def make_optimizer(lr_schedule, momentum_decay):
    """Make SGD optimizer with momentum."""
    # Maximize log-prob instead of minimizing loss
    return optix.chain(optix.trace(decay=momentum_decay, nesterov=False),
                       optix.scale_by_schedule(lr_schedule))