Example #1
0
def main(_):
    """Builds the model and optimizer from FLAGS and runs the training loop.

    Loads the dataset, wraps the forward function with `hk.transform`, builds
    a gradient-clipped Adam optimizer, initializes the (checkpointing) updater
    from the first batch, then performs `MAX_STEPS` updates. Metrics are
    logged every `LOG_EVERY` steps together with the measured throughput.

    Args:
      _: Ignored positional argument.
    """
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optix.chain(optix.clip_by_global_norm(FLAGS.grad_clip_value),
                            optix.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters. The first batch is consumed for initialization.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    # A monotonic clock is used because only elapsed intervals matter here and
    # it cannot jump when the system wall clock is adjusted.
    prev_time = time.monotonic()
    # Steps actually executed since the previous log line. Dividing by this
    # (rather than by LOG_EVERY) keeps the very first report, which covers a
    # single step, from overstating throughput.
    steps_since_log = 0
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        steps_since_log += 1
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            now = time.monotonic()
            elapsed = now - prev_time
            # Guard against a zero interval on coarse-resolution clocks.
            if elapsed > 0:
                steps_per_sec = steps_since_log / elapsed
            else:
                steps_per_sec = float('inf')
            prev_time = now
            steps_since_log = 0
            # Build the logged dict separately so `metrics` is not mutated.
            log_values = {k: float(v) for k, v in metrics.items()}
            log_values['steps_per_sec'] = float(steps_per_sec)
            logging.info(log_values)
Example #2
0
    def test_graph_network_learning(self, spatial_dimension, dtype):
        """Checks that four optimizer steps cut the graph-network loss by >5%.

        Random positions of shape (6, 3, spatial_dimension) and random target
        distances of shape (6, 3, 3) define a ground-truth energy; the graph
        network energy is fit to it with gradient-clipped Adam.
        """
        position_key, target_key, init_key = random.split(random.PRNGKey(0), 3)

        displacement, _ = space.free()

        positions = random.uniform(
            position_key, (6, 3, spatial_dimension), dtype=dtype)
        target_distances = random.uniform(target_key, (6, 3, 3), dtype=dtype)

        def single_target_energy(R, dr0):
            # Squared mismatch between all pairwise distances and the targets.
            pair_distances = space.distance(
                space.map_product(displacement)(R, R))
            return np.sum((pair_distances - dr0) ** 2)

        # Ground-truth energy, mapped over the leading axis of both inputs.
        target_energy = vmap(single_target_energy)

        init_fn, energy_fn = energy.graph_network(displacement, 0.2)
        params = init_fn(init_key, positions[0])

        @jit
        def loss(params, R):
            # Parameters are shared; only the positions are mapped over.
            predicted = vmap(energy_fn, (None, 0))(params, R)
            residual = predicted - target_energy(R, target_distances)
            return np.mean(residual ** 2)

        opt = optix.chain(optix.clip_by_global_norm(1.0), optix.adam(1e-4))

        @jit
        def update(params, opt_state, R):
            gradients = grad(loss)(params, R)
            updates, next_opt_state = opt.update(gradients, opt_state)
            return optix.apply_updates(params, updates), next_opt_state

        opt_state = opt.init(params)

        initial_loss = loss(params, positions)
        for _ in range(4):
            params, opt_state = update(params, opt_state, positions)

        assert loss(params, positions) < initial_loss * 0.95
Example #3
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):
        """Wires together the replay server, dataset, learner and actor.

        Args:
          environment_spec: Supplies `actions.num_values` and `observations`,
            and is passed to the dataset builder.
          network: Forwarded to both the learner and the actor.
          initial_state_fn: Transformed with Haiku to produce the 'core_state'
            entry of the extras spec; also forwarded to learner and actor.
          sequence_length: Sequence length given to the adder and the dataset.
          sequence_period: Passed to the adder as `period`.
          counter: Forwarded to the learner.
          logger: Forwarded to the learner exactly as given. `self._logger`
            falls back to `TerminalLogger('agent')` when this is falsy.
          discount: Forwarded to the learner.
          max_queue_size: `max_size` of the reverb queue table.
          batch_size: Dataset batch size; also the amount checked by
            `self._can_sample`.
          learning_rate: Learning rate handed to `optix.adam`.
          entropy_cost: Forwarded to the learner.
          baseline_cost: Forwarded to the learner.
          seed: Seeds the `hk.PRNGSequence` shared by learner and actor.
          max_abs_reward: Forwarded to the learner.
          max_gradient_norm: Threshold handed to `optix.clip_by_global_norm`.
        """
        action_count = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        # Replay: a single queue table behind a reverb server. The server is
        # created with port=None and its port is read back for the address.
        replay_queue = reverb.Table.queue(
            name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
        self._server = reverb.Server([replay_queue], port=None)
        self._can_sample = lambda: replay_queue.can_sample(batch_size)
        server_address = f'localhost:{self._server.port}'

        # Writer side: the adder that puts sequences into replay.
        sequence_adder = adders.SequenceAdder(
            sequence_length=sequence_length,
            period=sequence_period,
            client=reverb.Client(server_address))

        # Reader side: the dataset the learner consumes. Besides the
        # environment spec, each item carries a core state and logits.
        initial_core_state = hk.transform(initial_state_fn).apply(None)
        logits_spec = np.ones(shape=(action_count, ), dtype=np.float32)
        extras = {'core_state': initial_core_state, 'logits': logits_spec}
        replay_dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(server_address),
            environment_spec=environment_spec,
            extra_spec=extras,
            batch_size=batch_size,
            sequence_length=sequence_length)

        # One random key sequence, shared by the learner and the actor.
        key_sequence = hk.PRNGSequence(seed)

        # Adam preceded by global-norm gradient clipping.
        gradient_transform = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate))

        batch_iterator = replay_dataset.as_numpy_iterator()
        self._learner = learning.IMPALALearner(
            network=network,
            initial_state_fn=initial_state_fn,
            obs_spec=environment_spec.observations,
            iterator=batch_iterator,
            optimizer=gradient_transform,
            rng=key_sequence,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
            counter=counter,
            logger=logger)

        # The actor fetches the 'policy' variables from the learner.
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=key_sequence,
            adder=sequence_adder,
            variable_client=jax_variable_utils.VariableClient(
                self._learner, key='policy'))