Example 1
def _scale_by_learning_rate(learning_rate, flip_sign=True):
    m = -1 if flip_sign else 1
    if callable(learning_rate):
        return optax.scale_by_schedule(lambda count: m * learning_rate(count))
    return optax.scale(m * learning_rate)
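A minimal usage sketch (assuming only that jax and optax are installed; the toy loss and parameter names are illustrative) showing how this sign-flipping helper is typically chained into a full optimizer and applied with optax.apply_updates:

import jax
import jax.numpy as jnp
import optax

# Adam-like optimizer built from the helper defined above.
optimizer = optax.chain(
    optax.scale_by_adam(),
    _scale_by_learning_rate(1e-3),  # flips the sign, so apply_updates minimises
)

params = {"w": jnp.ones(3)}
opt_state = optimizer.init(params)

loss = lambda p: jnp.sum(p["w"] ** 2)
grads = jax.grad(loss)(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)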
Example 2
    def __init__(self,
                 player_id,
                 state_representation_size,
                 num_actions,
                 hidden_layers_sizes=128,
                 replay_buffer_capacity=10000,
                 batch_size=128,
                 replay_buffer_class=ReplayBuffer,
                 learning_rate=0.01,
                 update_target_network_every=1000,
                 learn_every=10,
                 discount_factor=1.0,
                 min_buffer_size_to_learn=1000,
                 epsilon_start=1.0,
                 epsilon_end=0.1,
                 epsilon_decay_duration=int(1e6),
                 optimizer_str="sgd",
                 loss_str="mse",
                 huber_loss_parameter=1.0):
        """Initialize the DQN agent."""

        # This call to locals() is used to store every argument used to initialize
        # the class instance, so it can be copied with no hyperparameter change.
        self._kwargs = locals()

        self.player_id = player_id
        self._num_actions = num_actions
        if isinstance(hidden_layers_sizes, int):
            hidden_layers_sizes = [hidden_layers_sizes]
        self._layer_sizes = hidden_layers_sizes
        self._batch_size = batch_size
        self._update_target_network_every = update_target_network_every
        self._learn_every = learn_every
        self._min_buffer_size_to_learn = min_buffer_size_to_learn
        self._discount_factor = discount_factor
        self.huber_loss_parameter = huber_loss_parameter

        self._epsilon_start = epsilon_start
        self._epsilon_end = epsilon_end
        self._epsilon_decay_duration = epsilon_decay_duration

        # TODO(author6) Allow for optional replay buffer config.
        if not isinstance(replay_buffer_capacity, int):
            raise ValueError("Replay buffer capacity not an integer.")
        self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
        self._prev_timestep = None
        self._prev_action = None

        # Step counter to keep track of learning, eps decay and target network.
        self._step_counter = 0

        # Keep track of the last training loss achieved in an update step.
        self._last_loss_value = None

        # Create the Q-network instances

        def network(x):
            mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
            return mlp(x)

        self.hk_network = hk.without_apply_rng(hk.transform(network))
        self.hk_network_apply = jax.jit(self.hk_network.apply)

        rng = jax.random.PRNGKey(42)
        x = jnp.ones([1, state_representation_size])
        self.params_q_network = self.hk_network.init(rng, x)
        self.params_target_q_network = self.hk_network.init(rng, x)

        if loss_str == "mse":
            self.loss_func = lambda x: jnp.mean(x**2)
        elif loss_str == "huber":
            # pylint: disable=g-long-lambda
            self.loss_func = lambda x: jnp.mean(
                rlax.huber_loss(x, self.huber_loss_parameter))
        else:
            raise ValueError("Not implemented, choose from 'mse', 'huber'.")
        if optimizer_str == "adam":
            opt_init, opt_update = optax.chain(
                optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
                optax.scale(learning_rate))
        elif optimizer_str == "sgd":
            opt_init, opt_update = optax.sgd(learning_rate)
        else:
            raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")
        self._opt_update_fn = self._get_update_func(opt_update)
        self._opt_state = opt_init(self.params_q_network)
        self._loss_and_grad = jax.value_and_grad(self._loss, has_aux=False)
        self._jit_update = jax.jit(self.get_update())
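    The `_get_update_func` helper referenced above is not shown in this snippet. As a rough illustration (an assumption about its shape, not the verbatim implementation), such a helper conventionally wraps the optax update into a single step:

    def _get_update_func(self, opt_update):
        # Illustrative sketch: turn an optax update fn into a
        # (params, opt_state, gradient) -> (new_params, new_opt_state) step.
        def update(params, opt_state, gradient):
            updates, opt_state = opt_update(gradient, opt_state)
            new_params = optax.apply_updates(params, updates)
            return new_params, opt_state
        return update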
Example 3
def scale_by_learning_rate(learning_rate: ScalarOrSchedule):
    if callable(learning_rate):
        return optax.scale_by_schedule(lambda count: -learning_rate(count))
    return optax.scale(-learning_rate)
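The same helper accepts either a constant or a callable schedule; a brief illustration (the exponential-decay schedule is simply one standard optax schedule, used here as an example):

import optax

# Constant learning rate -> plain negative scaling.
tx_const = scale_by_learning_rate(1e-3)

# Callable schedule -> per-step negative scaling that follows the schedule.
schedule = optax.exponential_decay(init_value=1e-3,
                                   transition_steps=1000,
                                   decay_rate=0.9)
tx_sched = scale_by_learning_rate(schedule)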
Example 4
    X_train_s = scaler.transform(X_train)
    X_test_s = scaler.transform(X_test)

    X_train_s = torch.tensor(X_train_s, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32)

    train_dataloader = DataLoader(TensorDataset(X_train_s, y_train),
                                  batch_size=batch_size,
                                  shuffle=True)

    learning_rate = 0.001
    optimizer = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-learning_rate),
    )

    logit_d = Classifier(num_layers=3, hidden_dim=128, use_residual=True)
    params, opt_state = init_fn(X_train_s.shape, jax.random.PRNGKey(seed),
                                logit_d, optimizer)

    train_step = get_train_step(loss, optimizer)

    print("Test accuracy: {:.3f}".format(
        jax.numpy.mean((jax.nn.sigmoid(
            logit_d.apply({"params": params}, np.array(X_test_s))) > 0.5
                        ).flatten() == y_test)))

    iterator = tqdm(range(nsteps))
    for _ in iterator:
Example 5
    def update(
            self, gradient: Weights, state: GenericGradientState,
            parameters: Optional[Weights]
    ) -> Tuple[Weights, GenericGradientState]:
        return GenericGradientState.wrap(*scale(
            **asdict(self)).update(gradient, state.data, parameters))
Example 6
    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)
Example 7
def main(unused_argv):
    validate_flags()
    logging.info(get_config())

    seed = FLAGS.seed if FLAGS.seed is not None else np.random.randint(
        0, 0x7fffffff)
    rng = np.random.RandomState(seed)
    key_seq = hk.PRNGSequence(rng.randint(0, 0x7fffffff))

    activation = jnp.tanh

    eta = 0.3
    assert eta < 4 / (5 + FLAGS.p)  # needed for IGS stability

    if FLAGS.dagger:
        intermediate_policy = training_utils.dagger_policy_with_expert
        final_policy = training_utils.dagger_final_policy
    else:
        intermediate_policy = training_utils.mixed_policy_with_expert
        final_policy = training_utils.final_policy

    # make dynamics and expert
    dynamics, expert_policy = problem_instance_utils.make_dynamics_and_expert(
        next(key_seq), FLAGS.state_dim, FLAGS.p, eta, activation)

    policy_net = training_utils.make_policy_net(64, FLAGS.state_dim,
                                                activation)

    opt_init, opt_update = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-FLAGS.learning_rate))

    aggregate_states, aggregate_actions = [], []
    policy_params = []  # accumulate all the networks trained at each epoch
    for shift_epoch in tqdm.trange(FLAGS.n_shift_epochs):
        logging.info('starting epoch %d', shift_epoch)

        start_time = time.time()
        if shift_epoch == 0:
            epoch_rollout_policy = expert_policy
        else:
            epoch_rollout_policy = functools.partial(intermediate_policy,
                                                     policy_net, expert_policy,
                                                     policy_params,
                                                     FLAGS.alpha)

        x0s_epoch = problem_instance_utils.sample_initial_conditions(
            next(key_seq), FLAGS.n_trajs_per_epoch, FLAGS.state_dim)
        Xs_epoch, Us_epoch = jax.vmap(problem_instance_utils.rollout_policy,
                                      in_axes=(None, None, 0,
                                               None))(dynamics,
                                                      epoch_rollout_policy,
                                                      x0s_epoch, FLAGS.horizon)
        Us_expert_labels = jax.vmap(lambda traj: jax.vmap(expert_policy)
                                    (traj[:-1]))(Xs_epoch)

        logging.info('rolling out %d trajectories took %f seconds',
                     FLAGS.n_trajs_per_epoch,
                     time.time() - start_time)

        # compute goal error
        logging.info('goal error: %s',
                     stats(np.linalg.norm(Xs_epoch[:, -1, :], axis=1)))
        # compute imitation error
        logging.info(
            'imitation error: %s',
            stats(
                np.sum(np.linalg.norm(Us_epoch - Us_expert_labels, axis=2),
                       axis=1)))

        # format for training
        epoch_train_states = Xs_epoch[:, :-1, :].reshape(
            (-1, Xs_epoch.shape[-1]))
        epoch_train_actions = Us_expert_labels.reshape(
            (-1, Us_expert_labels.shape[-1]))

        # aggregate the accumulated data
        if FLAGS.aggregate_data:
            aggregate_states.append(epoch_train_states)
            aggregate_actions.append(epoch_train_actions)
            epoch_train_states = np.concatenate(aggregate_states, axis=0)
            epoch_train_actions = np.concatenate(aggregate_actions, axis=0)

        logging.info('epoch_train_states.shape: %s', epoch_train_states.shape)
        logging.info('epoch_train_actions.shape: %s',
                     epoch_train_actions.shape)
        assert epoch_train_states.shape[0] == epoch_train_actions.shape[0]
        assert epoch_train_states.shape[1] == FLAGS.state_dim
        assert epoch_train_actions.shape[1] == FLAGS.state_dim

        # initial parameters for training
        if shift_epoch == 0:
            params = policy_net.init(next(key_seq), epoch_train_states[0])
            trust_region_params = jax_utils.pytree_zeros_like(params)
        else:
            assert len(policy_params) >= 1
            params = policy_params[-1]
            trust_region_params = params

        if FLAGS.igs_constraint_lam > 0.0:
            if shift_epoch == FLAGS.n_shift_epochs - 1:

                def policy_fn(policy_network, this_policy_params, x):
                    return final_policy(policy_network,
                                        policy_params + [this_policy_params],
                                        FLAGS.alpha, x)
            else:

                def policy_fn(policy_network, this_policy_params, x):
                    return intermediate_policy(
                        policy_network, expert_policy,
                        policy_params + [this_policy_params], FLAGS.alpha, x)

            def igs_loss(x, y, fx, fy):
                # want |fx - fy| - |x - y| <= 0
                ineq = jnp.abs(fx - fy) - jnp.abs(x - y)
                return FLAGS.igs_constraint_lam * jnp.maximum(ineq, 0)

            igs_constraint_args = (dynamics, igs_loss, policy_fn)
        else:
            igs_constraint_args = None

        start_time = time.time()
        params, _, last_epoch_losses = training_utils.train_policy_network(
            policy_net,
            opt_update, epoch_train_states, epoch_train_actions, params,
            opt_init(params), trust_region_params, 0.0, igs_constraint_args,
            FLAGS.n_train_epochs, FLAGS.batch_size, 0.0, 1000, rng,
            FLAGS.verbose_learner)
        policy_params.append(params)
        logging.info(
            'shift_epoch=%d, last_epoch_losses=%s, '
            'avg_last_epoch_losses=%s', shift_epoch, last_epoch_losses,
            last_epoch_losses / len(epoch_train_states))
        logging.info('train_policy_network at epoch %d took %f seconds',
                     shift_epoch,
                     time.time() - start_time)

    logging.info('running final episodes')

    x0s_final_test = problem_instance_utils.sample_initial_conditions(
        next(key_seq), FLAGS.n_trajs_final_eval, FLAGS.state_dim)
    Xs_final_test_shift, Us_final_test_shift = jax.vmap(
        problem_instance_utils.rollout_policy,
        in_axes=(None, None, 0,
                 None))(dynamics,
                        functools.partial(final_policy, policy_net,
                                          policy_params, FLAGS.alpha),
                        x0s_final_test, FLAGS.horizon)
    Us_expert_final_test_shift = jax.vmap(lambda traj: jax.vmap(expert_policy)
                                          (traj[:-1]))(Xs_final_test_shift)

    Xs_final_test_exp, _ = jax.vmap(problem_instance_utils.rollout_policy,
                                    in_axes=(None, None, 0,
                                             None))(dynamics, expert_policy,
                                                    x0s_final_test,
                                                    FLAGS.horizon)

    final_test_shift = np.linalg.norm(Xs_final_test_shift[:, -1, :], axis=1)
    final_test_exp = np.linalg.norm(Xs_final_test_exp[:, -1, :], axis=1)
    final_test_delta_goal_error = np.linalg.norm(
        Xs_final_test_shift[:, -1, :] - Xs_final_test_exp[:, -1, :], axis=1)
    final_imitation_error = np.sum(np.linalg.norm(Us_final_test_shift -
                                                  Us_expert_final_test_shift,
                                                  axis=2),
                                   axis=1)

    logging.info('final shift goal error: %s', stats(final_test_shift))
    logging.info('expert goal error: %s', stats(final_test_exp))
    logging.info('final delta goal error: %s',
                 stats(final_test_delta_goal_error))
    logging.info('final_imitation_error: %s', stats(final_imitation_error))

    if FLAGS.metrics_outfile is not None:
        with open(FLAGS.metrics_outfile, 'wb') as fp:
            pickle.dump(
                {
                    'final_test_shift': final_test_shift,
                    'final_test_exp': final_test_exp,
                    'final_test_delta_goal_error': final_test_delta_goal_error,
                    'final_imitation_error': final_imitation_error,
                }, fp)
    if FLAGS.config_outfile is not None:
        with open(FLAGS.config_outfile, 'wb') as fp:
            pickle.dump(get_config(), fp)
    if FLAGS.params_outfile is not None:
        with open(FLAGS.params_outfile, 'wb') as fp:
            pickle.dump(
                {
                    'mixing_weight': FLAGS.alpha,
                    'dagger': False,
                    'policy_params': policy_params
                }, fp)
Example 8
def make_optimizer(lr_schedule, momentum_decay):
    return optax.chain(optax.trace(decay=momentum_decay, nesterov=False),
                       optax.scale_by_schedule(lr_schedule), optax.scale(-1))
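A small usage sketch (the linear-warmup schedule, toy parameters and loss below are illustrative assumptions, not part of the original code):

import jax
import jax.numpy as jnp
import optax

def lr_schedule(step):
    # Illustrative: linear warmup to 0.1 over the first 100 steps, then constant.
    return 0.1 * jnp.minimum(step / 100.0, 1.0)

optimizer = make_optimizer(lr_schedule, momentum_decay=0.9)

params = {"w": jnp.ones(4)}
opt_state = optimizer.init(params)
grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)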
Example 9
def main(_):
    # Create the dataset.
    tokenizer = utils.init_tokenizer(FLAGS.dataset)
    graph_tokenizer = utils.init_graph_tokenizer()
    dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
    has_graph = FLAGS.model_type == 'graph2text'
    local_devices = jax.local_devices()
    num_gpus = min(FLAGS.num_gpus, len(local_devices))

    if FLAGS.job_mode == 'train':
        train_dataset = dataset_class(tokenizer=tokenizer,
                                      graph_tokenizer=graph_tokenizer,
                                      batch_size=FLAGS.train_batch_size,
                                      subset='train',
                                      timesteps=FLAGS.train_timesteps,
                                      version=FLAGS.graph_data_version,
                                      shuffle_data=True,
                                      repeat=True,
                                      debug=FLAGS.debug)
        train_iter = iter(train_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.train_memory_size)
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.grad_clip), optax.scale_by_adam(),
            optax.scale_by_schedule(
                functools.partial(utils.schedule,
                                  lr_schedule=FLAGS.lr_schedule,
                                  init_lr=FLAGS.init_lr,
                                  min_lr_ratio=FLAGS.min_lr_ratio,
                                  max_steps=FLAGS.max_steps)), optax.scale(-1))
        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
        updater = Updater(loss_fn,
                          optimizer,
                          devices=local_devices[:num_gpus],
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _train(updater, train_iter, num_gpus)
    elif FLAGS.job_mode == 'eval':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=FLAGS.eval_batch_size,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _eval(updater, eval_iter)
    elif FLAGS.job_mode == 'sample':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.sample_length,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        _sample(eval_iter, tokenizer, local_devices[:num_gpus])
    elif FLAGS.job_mode == 'retrieve':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     graph_retrieval_dataset=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _retrieve(updater, eval_iter)
Example 10
if __name__ == "__main__":
    args = parse_args()
    params = json.load(open(args.config))

    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]

    params["optimizer"] = optax.chain(
        optax.scale(1),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
Example 11
    def test_node_classification(self):
        # If node has more than 2 neighbors --> class 1, otherwise class 0.
        # Graph structure:
        # 1         4
        # | \     / |
        # |  0 - 3  |
        # | /     \ |
        # 2         5

        edges = np.array([
            [0, 1],
            [1, 2],
            [2, 0],
            [0, 3],
            [3, 4],
            [4, 5],
            [5, 3],
        ],
                         dtype=np.int32)

        n_node = edges.max() + 1
        n_edge = edges.shape[0]
        g = jraph.GraphsTuple(senders=edges[:, 0],
                              receivers=edges[:, 1],
                              edges=np.ones((edges.shape[0], 1),
                                            dtype=np.float32),
                              nodes=np.ones((n_node, 1), dtype=np.float32),
                              n_node=np.array([n_node], dtype=np.int32),
                              n_edge=np.array([n_edge], dtype=np.int32),
                              globals=None)
        g = gn.add_reverse_edges(g)
        targets = np.array([1, 0, 0, 1, 0, 0], dtype=np.int32)
        n_classes = 2

        def forward(graph, targets):
            model = gn.SimpleGraphNet(num_layers=5, layer_norm=False)
            graph = model(graph)
            nodes = graph.nodes
            logits = hk.Linear(n_classes)(nodes)
            pred = logits.argmax(axis=-1)
            accuracy = (pred == targets).mean()
            targets = jax.nn.one_hot(targets, n_classes, dtype=jnp.float32)
            return -jnp.mean(
                jnp.sum(jax.nn.log_softmax(logits, axis=-1) * targets,
                        axis=-1)), accuracy

        init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
        rng = hk.PRNGSequence(0)
        params = init_fn(next(rng), g, targets)

        optimizer = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))
        opt_state = optimizer.init(params)
        apply_fn = jax.jit(apply_fn)
        for i in range(500):
            (loss, acc), grad = jax.value_and_grad(apply_fn,
                                                   has_aux=True)(params, g,
                                                                 targets)
            updates, opt_state = optimizer.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            if (i + 1) % 100 == 0:
                logging.info('Step %d, loss %.8f, accuracy %.4f', i + 1, loss,
                             acc)
        self.assertLess(loss, 0.01)
        self.assertEqual(acc, 1.0)
Example 12
def init(key, X, lr):
    params, state = forward.init(key, X, True)
    optimizer = optax.chain(optax.scale_by_adam(),
                            optax.add_decayed_weights(0.03), optax.scale(-lr))
    opt_state = optimizer.init(params)
    return params, state, opt_state, optimizer
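Because the chain above includes optax.add_decayed_weights, the optimizer's update call must also receive the current parameters. A minimal standalone sketch of one training step with the same chain (the loss, shapes and data are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

lr = 1e-3
optimizer = optax.chain(optax.scale_by_adam(),
                        optax.add_decayed_weights(0.03), optax.scale(-lr))

params = {"w": jnp.ones((4, 2))}
opt_state = optimizer.init(params)

loss = lambda p, x: jnp.mean((x @ p["w"]) ** 2)
x = jnp.ones((8, 4))
grads = jax.grad(loss)(params, x)
# add_decayed_weights needs the params argument to compute the decay term.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)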
Example 13
    ckpt_every = params["ckpt_every"]
    keep_every = params["keep_every"]
    eval_tasks = params["eval_harness_tasks"]
    total_steps = params["total_steps"]

    pe = params["pe"]
    assert pe in ["fixed", "rotary", "t5"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps), clip_by_global_norm(1),
        optax.scale_by_adam(), additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(
            util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)))

    params["optimizer"] = opt

    start = time.time()
    tpu_size = jax.device_count()
    if tpu_size < cores_per_replica:
        msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
        raise ValueError(msg)
    print(f"jax devices: {tpu_size}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")
Example 14
  def __init__(self,
               player_id,
               state_representation_size,
               num_actions,
               hidden_layers_sizes,
               reservoir_buffer_capacity,
               anticipatory_param,
               batch_size=128,
               rl_learning_rate=0.01,
               sl_learning_rate=0.01,
               min_buffer_size_to_learn=1000,
               learn_every=64,
               optimizer_str="sgd",
               **kwargs):
    """Initialize the `NFSP` agent."""
    self.player_id = player_id
    self._num_actions = num_actions
    self._layer_sizes = hidden_layers_sizes
    self._batch_size = batch_size
    self._learn_every = learn_every
    self._anticipatory_param = anticipatory_param
    self._min_buffer_size_to_learn = min_buffer_size_to_learn

    self._reservoir_buffer = ReservoirBuffer(reservoir_buffer_capacity)
    self._prev_timestep = None
    self._prev_action = None

    # Step counter to keep track of learning.
    self._step_counter = 0

    # Inner RL agent
    kwargs.update({
        "batch_size": batch_size,
        "learning_rate": rl_learning_rate,
        "learn_every": learn_every,
        "min_buffer_size_to_learn": min_buffer_size_to_learn,
        "optimizer_str": optimizer_str,
    })
    self._rl_agent = dqn.DQN(player_id, state_representation_size,
                             num_actions, hidden_layers_sizes, **kwargs)

    # Keep track of the last training loss achieved in an update step.
    self._last_rl_loss_value = lambda: self._rl_agent.loss
    self._last_sl_loss_value = None

    # Average policy network.
    def network(x):
      mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
      return mlp(x)

    self.hk_avg_network = hk.without_apply_rng(hk.transform(network))

    def avg_network_policy(param, info_state):
      action_values = self.hk_avg_network.apply(param, info_state)
      action_probs = jax.nn.softmax(action_values, axis=1)
      return action_values, action_probs

    self._avg_network_policy = jax.jit(avg_network_policy)

    rng = jax.random.PRNGKey(42)
    x = jnp.ones([1, state_representation_size])
    self.params_avg_network = self.hk_avg_network.init(rng, x)
    self.params_avg_network = jax.device_put(self.params_avg_network)

    self._savers = [
        ("q_network", self._rl_agent.params_q_network),
        ("avg_network", self.params_avg_network)
    ]

    if optimizer_str == "adam":
      opt_init, opt_update = optax.chain(
          optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
          optax.scale(sl_learning_rate))
    elif optimizer_str == "sgd":
      opt_init, opt_update = optax.sgd(sl_learning_rate)
    else:
      raise ValueError("Not implemented. Choose from ['adam', 'sgd'].")
    self._opt_update_fn = self._get_update_func(opt_update)
    self._opt_state = opt_init(self.params_avg_network)
    self._loss_and_grad = jax.value_and_grad(self._loss_avg, has_aux=False)

    self._sample_episode_policy()
    self._jit_update = jax.jit(self.get_update())
  "pe_rotary_dims": 64,
  "early_cast": True,
  "seq": 2048,
  "cores_per_replica": 1,  # only running on one GPU
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

devices = np.array([jax.devices()[0]]).reshape((1, 1))
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

network = CausalTransformer(params)

start = time.time()

# here we load a checkpoint which was written with 8 shards into 1 shard
network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica)

# move the state to CPU/system memory so it's not duplicated by xmap
network.state = jax.device_put(network.state, jax.devices("cpu")[0])
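A quick standalone illustration (toy parameter names) of why optax.scale(0) works as a stand-in optimizer for inference: its state is an empty ScaleState, so no per-parameter optimizer buffers are created when the checkpoint is loaded:

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((2, 2))}
state = optax.scale(0).init(params)
print(state)  # ScaleState() -- empty, so no optimizer buffers are allocated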
Example 16
def sgld(learning_rate: float = 1e-2, random_seed: int = 0):
    return optax.chain(
        optax.scale(-learning_rate),
        optax.add_noise(np.sqrt(2 * np.abs(learning_rate)), 0, random_seed),
    )
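An illustrative way to use the sgld() chain above: a few noisy gradient steps on a toy quadratic potential (the potential and step count are assumptions chosen only for demonstration):

import jax
import jax.numpy as jnp
import optax

optimizer = sgld(learning_rate=1e-2, random_seed=0)

params = jnp.ones(3)
opt_state = optimizer.init(params)
energy = lambda p: 0.5 * jnp.sum(p ** 2)

for _ in range(10):
    grads = jax.grad(energy)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)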
Example 17
  def optimize(self, non_convex_bound: nonconvex.NonConvexBound
               ) -> bound_propagation.Bound:
    # If we are going to actually perform optimization, define the function to
    # minimize (either the primal, or the negative of the dual),
    # its gradient and the projection function to use.
    if self._num_steps:
      if self._optimize_dual:
        def fun_to_opt(opt_vars, objectives):
          _, dual_vals = non_convex_bound.dual(opt_vars, objectives)
          return -jnp.sum(dual_vals)
      else:
        def fun_to_opt(opt_vars, objectives):
          final_acts = non_convex_bound.evaluate(opt_vars)
          obj = jnp.sum(final_acts * objectives)
          return obj
      grad_fun = jax.grad(fun_to_opt)
      proj_fun = lambda x: jnp.clip(x, 0., 1.)

      # Define the optimizer. Because we are minimizing the objective function,
      # we will scale the gradient by a negative step size.
      tx = optax.scale(-self._step_size)

    # Define the function to optimize a chunk of the nodes of the activation.
    def optimize_chunk(batch_index: int,
                       current_bounds: Tuple[Tensor, Tensor]
                       ) -> Tuple[Tensor, Tensor]:
      var_shapes, batch_objectives = _create_opt_problems(
          non_convex_bound, batch_index, self._max_parallel_nodes)

      var_set = {key: 0.5 * jnp.ones(shape)
                 for key, shape in var_shapes.items()}

      # Perform the optimization.
      if self._num_steps:
        state = tx.init(var_set)

        def opt_step(_, state_and_var):
          state, var_set = state_and_var
          grads = grad_fun(var_set, batch_objectives)
          updates, new_state = tx.update(grads, state, var_set)
          unc_var_set = optax.apply_updates(var_set, updates)
          new_var_set = jax.tree_map(proj_fun, unc_var_set)
          return new_state, new_var_set

        _, var_set = jax.lax.fori_loop(0, self._num_steps, opt_step,
                                       (state, var_set))

      # Compute the resulting bound and unpack it.
      _, dual_vals = non_convex_bound.dual(jax.lax.stop_gradient(var_set),
                                           batch_objectives)

      batch_lbs, batch_ubs = _unpack_opt_problem(dual_vals, batch_objectives)

      current_lbs, current_ubs = current_bounds

      lbs = batch_lbs + current_lbs
      ubs = batch_ubs + current_ubs

      return (lbs, ubs)

    return _chunked_optimization(non_convex_bound.shape,
                                 self._max_parallel_nodes,
                                 optimize_chunk)
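The projected-gradient pattern inside optimize_chunk can be reproduced standalone; a minimal sketch on a toy box-constrained problem (the objective, step size and loop length are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

step_size = 0.1
tx = optax.scale(-step_size)  # minimise by scaling gradients with a negative step size

fun = lambda x: jnp.sum((x - 2.0) ** 2)   # unconstrained minimum at 2, box-constrained minimum at 1
proj = lambda x: jnp.clip(x, 0., 1.)      # projection onto the box [0, 1]

x = 0.5 * jnp.ones(3)
state = tx.init(x)

def opt_step(_, state_and_x):
    state, x = state_and_x
    grads = jax.grad(fun)(x)
    updates, state = tx.update(grads, state)
    x = proj(optax.apply_updates(x, updates))
    return state, x

_, x = jax.lax.fori_loop(0, 100, opt_step, (state, x))  # x converges to ~1.0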
Example 18
def make_optimizer():
    """SGD with nesterov momentum and a custom lr schedule."""
    return optax.chain(
        optax.trace(decay=FLAGS.optimizer_momentum,
                    nesterov=FLAGS.optimizer_use_nesterov),
        optax.scale_by_schedule(lr_schedule), optax.scale(-1))
Example 19
    def init(self, parameters: Weights) -> GenericGradientState:
        return GenericGradientState(scale(**asdict(self)).init(parameters))
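Examples 5 and 19 are methods of a dataclass that hides optax.scale behind a generic gradient-state container. `GenericGradientState`, `Weights` and `scale` are not defined in those snippets; the following is a self-contained, hypothetical reconstruction of the pattern (all names and details here are assumptions, not the original library code), showing how the two methods fit together:

from dataclasses import dataclass, asdict
from typing import Any, Optional, Tuple

from optax import scale

Weights = Any  # placeholder type alias, assumed


@dataclass
class GenericGradientState:
    """Hypothetical thin wrapper around an underlying optax optimizer state."""
    data: Any

    @classmethod
    def wrap(cls, updates, state):
        return updates, cls(state)


@dataclass
class Scale:
    """Hypothetical transformation whose fields mirror optax.scale's arguments."""
    step_size: float

    def init(self, parameters: Weights) -> GenericGradientState:
        return GenericGradientState(scale(**asdict(self)).init(parameters))

    def update(
            self, gradient: Weights, state: GenericGradientState,
            parameters: Optional[Weights]
    ) -> Tuple[Weights, GenericGradientState]:
        return GenericGradientState.wrap(*scale(
            **asdict(self)).update(gradient, state.data, parameters))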