Example #1
  def test_graph_conditioned_transformer_learns(self):
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
      return models.Graph2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=embed_dim,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
          gnn_embed_dim=embed_dim,
          gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
      input_mask = (labels != 0).astype(jnp.float32)
      return model_fn(vocab_size, embed_dim).loss(
          graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    apply = jax.jit(apply, static_argnums=6)

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(500):
      (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
          params, state, next(rng), graphs, x, y, max_graph_size)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info(
            'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
Example #2
    def build_optimizer(self,
                        clip=15.0,
                        lr=5e-4,
                        warmup=2000,
                        cosine_decay_steps=None,
                        optimizer_name="adabelief") -> GradientTransformation:
        chain = []
        if optimizer_name == "adabelief":
            chain.append(util.scale_by_belief())
        elif optimizer_name == "adam":
            chain.append(optax.scale_by_adam())
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Make sure to use the negative learning rate so that we minimize
        if warmup and warmup > 0:
            warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                      warmup=warmup,
                                      lr_decay=1.0,
                                      lr=-lr)
            chain.append(optax.scale_by_schedule(warmup_schedule))
        else:
            chain.append(optax.scale(-lr))

        if cosine_decay_steps and cosine_decay_steps > 0:
            cosine_lr = optax.cosine_decay_schedule(
                init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
            chain.append(optax.scale_by_schedule(cosine_lr))

        if clip and clip > 0:
            chain.append(optax.clip(clip))

        return optax.chain(*chain)
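The function above returns a plain optax GradientTransformation, so it plugs into the usual init/update/apply_updates loop. A minimal, hedged sketch (the chain and the params/grads pytrees below are illustrative stand-ins, not part of the original code):

import jax.numpy as jnp
import optax

# Stand-in for the transformation returned by build_optimizer().
optimizer = optax.chain(optax.scale_by_adam(), optax.scale(-5e-4))
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3, 3), 0.1), "b": jnp.full((3,), 0.1)}

opt_state = optimizer.init(params)                       # optimizer statistics
updates, opt_state = optimizer.update(grads, opt_state)  # gradients -> updates
params = optax.apply_updates(params, updates)            # minus sign already folded in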
Example #3
def optimizer(hyperparameters):
    opt_init_fn, opt_update_fn = optax.chain(
        optax.scale_by_adam(b1=1.0 - hyperparameters.one_minus_beta_1,
                            b2=0.999,
                            eps=hyperparameters.epsilon),
        optax.scale(-hyperparameters.learning_rate))
    return opt_init_fn, opt_update_fn
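The hyperparameters object only needs the two attributes read above. A hedged sketch using a SimpleNamespace stand-in (the values are made up for illustration):

from types import SimpleNamespace

import jax.numpy as jnp
import optax

hps = SimpleNamespace(one_minus_beta_1=0.1, epsilon=1e-8, learning_rate=1e-3)
opt_init_fn, opt_update_fn = optimizer(hps)

params = {"w": jnp.zeros((4,))}
opt_state = opt_init_fn(params)
updates, opt_state = opt_update_fn({"w": jnp.ones((4,))}, opt_state, params)
params = optax.apply_updates(params, updates)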
Example #4
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    # rng

    rng_key = random.PRNGKey(seed)

    # data

    dataset = PairTextSpectrogramDataset(data_folder)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    collate_fn=pair_text_spectrogram_dataset_collate_fn,
                    drop_last=True,
                    shuffle=True)

    # model

    model = CLAP(text_vocab=text_vocab,
                 text_dim=text_dim,
                 text_depth=text_depth,
                 text_heads=text_heads,
                 audio_dim=audio_dim,
                 audio_depth=audio_depth,
                 audio_heads=audio_heads)

    # optimizer

    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1,
                                                     params)

    optim = chain(clip_by_global_norm(max_norm), scale_by_adam(eps=1e-4),
                  add_decayed_weights(weight_decay, exclude_bias),
                  scale(-learning_rate))

    # init

    audio, audio_mask, text, text_mask = next(iter(dl))

    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad

    @jit
    @value_and_grad
    def loss_fn(params, text, audio, text_mask, audio_mask):
        return model.apply(params, text, audio, text_mask, audio_mask)

    # train loop

    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in dl:
            loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')
Example #5
def make_optimizer(optimizer_config, lr_schedule):
  """Construct the optax optimizer with given LR schedule."""
  if (optimizer_config.get('decay_pos_embs') is None or
      optimizer_config.decay_pos_embs):
    # Decay learned position embeddings by default.
    weight_decay_exclude_names = ['b']
  else:
    weight_decay_exclude_names = ['pos_embs', 'b']

  optax_chain = []
  if optimizer_config.max_norm > 0:
    optax_chain.append(
        optax.clip_by_global_norm(optimizer_config.max_norm))

  if optimizer_config.optimizer == 'adam':
    # See: https://arxiv.org/abs/1412.6980
    optax_chain.extend([
        optax.scale_by_adam(**optimizer_config.adam_kwargs),
        add_weight_decay(
            optimizer_config.weight_decay,
            exclude_names=weight_decay_exclude_names)
    ])
  elif optimizer_config.optimizer == 'lamb':
    # See: https://arxiv.org/abs/1904.00962
    optax_chain.extend([
        optax.scale_by_adam(**optimizer_config.lamb_kwargs),
        add_weight_decay(
            optimizer_config.weight_decay,
            exclude_names=weight_decay_exclude_names),
        optax.scale_by_trust_ratio()
    ])
  else:
    raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')

  # Scale by the (negative) learning rate.
  optax_chain.extend([
      optax.scale_by_schedule(lr_schedule),
      optax.scale(-1),
  ])

  return optax.chain(*optax_chain)
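A note on the scale_by_schedule / scale(-1) tail used above: the schedule supplies the positive learning rate and the final scale(-1) turns the scaled update into a descent step. The two chains in this sketch produce identical updates (the cosine schedule is only an example, not part of the original config):

import optax

lr_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)

# Schedule followed by a sign flip, as in make_optimizer above.
opt_a = optax.chain(optax.scale_by_adam(),
                    optax.scale_by_schedule(lr_schedule),
                    optax.scale(-1))

# Folding the minus sign into the schedule gives the same updates.
opt_b = optax.chain(optax.scale_by_adam(),
                    optax.scale_by_schedule(lambda count: -lr_schedule(count)))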
Example #6
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: ppo_networks.PPONetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        del environment_spec, replay_client

        if callable(self._config.learning_rate):
            optimizer = optax.chain(
                optax.clip_by_global_norm(self._config.max_gradient_norm),
                optax.scale_by_adam(eps=self._config.adam_epsilon),
                optax.scale_by_schedule(self._config.learning_rate),
                optax.scale(-1))
        else:
            optimizer = optax.chain(
                optax.clip_by_global_norm(self._config.max_gradient_norm),
                optax.scale_by_adam(eps=self._config.adam_epsilon),
                optax.scale(-self._config.learning_rate))

        return learning.PPOLearner(
            ppo_networks=networks,
            iterator=dataset,
            discount=self._config.discount,
            entropy_cost=self._config.entropy_cost,
            value_cost=self._config.value_cost,
            max_abs_reward=self._config.max_abs_reward,
            ppo_clipping_epsilon=self._config.ppo_clipping_epsilon,
            clip_value=self._config.clip_value,
            gae_lambda=self._config.gae_lambda,
            counter=counter,
            random_key=random_key,
            optimizer=optimizer,
            num_epochs=self._config.num_epochs,
            num_minibatches=self._config.num_minibatches,
            logger=logger_fn('learner'),
        )
Example #7
def main(argv):
    del argv

    learning_rate = 1e-2
    batch_size = 64
    input_size = 8
    n_training_steps = 100

    # Random number generator sequence.
    key_seq = hk.PRNGSequence(1729)

    # A simple Linear function.
    def forward_pass(x):
        return hk.Linear(10)(x)

    network = hk.without_apply_rng(hk.transform(forward_pass))

    # Some arbitrary loss.
    def mean_square_loss(params, x):
        output = network.apply(params, x)
        loss = jnp.sum(output**2)
        return loss

    # Construct a simple Adam optimiser using the transforms in optax.
    # You could also just use the `optax.adam` alias, but we show here how
    # to do so manually so that you may construct your own `custom` optimiser.
    opt_init, opt_update = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-learning_rate))

    # Initialise the model's parameters and the optimiser's state.
    # The `state` of an optimiser contains all statistics used by the
    # stateful transformations in the `chain` (in this case just `scale_by_adam`).
    params = network.init(next(key_seq), jnp.zeros([1, input_size]))
    opt_state = opt_init(params)

    # Minimise the loss.
    for step in range(n_training_steps):
        # Get input. Learn to minimize the input to 0.
        data = jax.random.normal(next(key_seq), [batch_size, input_size])
        # Compute gradient and loss.
        loss, grad = jax.value_and_grad(mean_square_loss)(params, data)
        print(f'Loss[{step}] = {loss}')
        # Transform the gradients using the optimiser.
        updates, opt_state = opt_update(grad, opt_state, params)
        # Update parameters.
        params = optax.apply_updates(params, updates)
Example #8
    def _create_jax_optimizer(self):
        import optax
        process = []
        # Apply Adam first: it normalises the gradient scale, so a
        # learning-rate schedule applied before it would be cancelled out.
        process.append(
            optax.scale_by_adam(b1=self.beta1, b2=self.beta2,
                                eps=self.epsilon))
        if isinstance(self.learning_rate, LearningRateSchedule):
            scheduler = self.learning_rate._create_jax_schedule()
            process.append(optax.scale_by_schedule(scheduler))
            process.append(optax.scale(-1.0))
        else:
            process.append(optax.scale(-1.0 * self.learning_rate))
        return optax.chain(*process)
Example #9
  def test_bow_transformer_learns(self):
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 2, 3, 0, 0],
                     [1, 2, 4, 5, 6, 0],
                     [1, 3, 3, 5, 4, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def model_fn():
      return models.Bow2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[])

    def loss_fn(bow, inputs, labels):
      mask = (labels != 0).astype(jnp.float32)
      return model_fn().loss(bow, inputs, labels, mask=mask)

    init_fn, apply_fn = hk.transform_with_state(loss_fn)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), bow, x, y)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    value_and_grad = jax.jit(jax.value_and_grad(apply, has_aux=True))

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(800):
      (loss, model_state), grad = value_and_grad(
          params, state, next(key), bow, x, y)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info('Step %d, %r', i + 1,
                     {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 0.1)
Example #10
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_params
  del model_state
  del rng
  params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
                                   workload.param_shapes)
  opt_init_fn, opt_update_fn = optax.chain(
      optax.scale_by_adam(
          b1=1.0 - hyperparameters.one_minus_beta_1,
          b2=0.999,
          eps=hyperparameters.epsilon),
      optax.scale(-hyperparameters.learning_rate))
  return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn
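A brief note on the zeros-like initialisation above: the Adam statistics depend only on the parameter shapes (zero first and second moments), so building the optimizer state from a zeros pytree is equivalent to building it from real parameter values. A small illustrative check, not part of the original workload code:

import jax.numpy as jnp
import optax

tx = optax.scale_by_adam()
state = tx.init({'w': jnp.zeros((3, 2))})
print(state.count, state.mu, state.nu)  # count=0 and zero moments of shape (3, 2)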
Example #11
def build_model(params, tpu_name, region, preemptible, version=1):
    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    cores_per_replica = params["cores_per_replica"]
    tpu_size = params["tpu_size"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    assert tpu_size in [8, 32, 128, 256, 512]

    create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible)
    assert wait_til(tpu_name, region, {'state': 'READY', 'health': 'HEALTHY'})

    conns = get_connection(tpu_name, region)

    assert len(conns) * 8 == tpu_size, "wrong size TPU for config"

    head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
    address = head_info['redis_address']

    with multiprocessing.pool.ThreadPool(processes=len(conns)) as p:
        p.map(functools.partial(start_ray, address=address, version=version), conns)

    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr))
    )

    params["optimizer"] = opt

    if version == 2:
        model_fn = functools.partial(CausalTransformerV2, params)
    elif version == 1:
        model_fn = functools.partial(CausalTransformer, params)
    else:
        raise Exception(f"Version {version} does not exist")

    t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica), len(conns), model_fn, version=version)
    return t
Example #12
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):

    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))

    params = model.init(rng,
                        jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state
Example #13
def get_optimizer(config):
    warm_up_poly = optax.polynomial_schedule(
        init_value=1 / config['warmup_iter'],
        end_value=1,
        power=1,
        transition_steps=config['warmup_iter'])
    exp_decay = optax.exponential_decay(
        init_value=config['adam_lr'],
        transition_steps=config['decay_steps'],
        decay_rate=config['lr_decay_rate'],
        transition_begin=0)  #config['warmup_iter'])
    opt = optax.chain(
        # clip_by_global_norm(max_norm),
        optax.scale_by_adam(b1=config['adam_beta_1'],
                            b2=config['adam_beta_2'],
                            eps=config['adam_eps']),
        optax.scale_by_schedule(warm_up_poly),
        optax.scale_by_schedule(exp_decay),
        optax.scale(-1))
    return opt
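A hedged sketch of exercising get_optimizer end to end with hypothetical config values (the keys match the ones read above; the numbers are made up):

import jax.numpy as jnp
import optax

config = {'warmup_iter': 100, 'adam_lr': 1e-3, 'decay_steps': 1000,
          'lr_decay_rate': 0.5, 'adam_beta_1': 0.9, 'adam_beta_2': 0.999,
          'adam_eps': 1e-8}
opt = get_optimizer(config)

params = {'w': jnp.zeros((2, 2))}
opt_state = opt.init(params)
updates, opt_state = opt.update({'w': jnp.ones((2, 2))}, opt_state)
params = optax.apply_updates(params, updates)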
Example #14
        def amsgrad():
            adam = optax.scale_by_adam()

            def init_fn(params):
                return adam.init(params)

            def update_fn(updates, state, params=None):
                prev_nu = state.nu
                _, state = adam.update(updates, state, params)
                curr_nu = state.nu
                nu_hat = jax.tree_multimap(jnp.maximum, curr_nu, prev_nu)
                updates = jax.tree_multimap(
                    lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu,
                    nu_hat)

                return updates, optax.ScaleByAdamState(count=state.count,
                                                       mu=state.mu,
                                                       nu=nu_hat)

            return optax.GradientTransformation(init_fn, update_fn)
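A hedged sketch of composing the AMSGrad transformation above with a learning-rate scale and taking one step. It assumes amsgrad is reachable at module scope; note that jax.tree_multimap used above has since been folded into jax.tree_util.tree_map in newer JAX releases:

import jax.numpy as jnp
import optax

opt = optax.chain(amsgrad(), optax.scale(-1e-3))

params = {'w': jnp.ones((3,))}
opt_state = opt.init(params)
updates, opt_state = opt.update({'w': jnp.full((3,), 0.5)}, opt_state, params)
params = optax.apply_updates(params, updates)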
Example #15
    def __init__(self,
                 learning_rate_fn: Union[float, int, optax.Schedule],
                 normalize_fn: Optional[NormalizeFn] = None,
                 beta1: float = .9,
                 beta2: float = .999,
                 epsilon: float = 1e-9):
        # Accept schedules, as well as scalar values.
        if isinstance(learning_rate_fn, (float, int)):
            lr = float(learning_rate_fn)
            learning_rate_fn = lambda _: lr
        # Normalization.
        def update_fn(updates, state, params=None):
            del params
            updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
            return updates, state

        gradient_transformation = optax.chain(
            optax.GradientTransformation(lambda _: optax.EmptyState(),
                                         update_fn),
            optax.scale_by_adam(b1=beta1, b2=beta2, eps=epsilon),
            optax.scale_by_schedule(learning_rate_fn), optax.scale(-1.))
        super(Adam, self).__init__(gradient_transformation)
Example #16
def main(argv):
    del argv

    learning_rate = 1e-2
    n_training_steps = 100

    # Random number generator sequence.
    rng = jax.random.PRNGKey(0)
    rng1, rng2 = jax.random.split(rng)

    # Create a one linear layer instance.
    model = nn.Dense(features=5)

    # Initialise the parameters.
    params = model.init(rng2, jax.random.normal(rng1, (10, )))

    # Set problem dimensions.
    nsamples = 20
    xdim = 10
    ydim = 5

    # Generate random ground truth w and b.
    w = jax.random.normal(rng1, (xdim, ydim))
    b = jax.random.normal(rng2, (ydim, ))

    # Generate samples with additional noise.
    ksample, knoise = jax.random.split(rng1)
    x_samples = jax.random.normal(ksample, (nsamples, xdim))
    y_samples = jnp.dot(x_samples, w) + b
    y_samples += 0.1 * jax.random.normal(knoise, (nsamples, ydim))

    # Define an MSE loss function.
    def make_mse_func(x_batched, y_batched):
        def mse(params):
            # Define the squared loss for a single (x, y) pair.
            def squared_error(x, y):
                pred = model.apply(params, x)
                return jnp.inner(y - pred, y - pred) / 2.0

            # Vectorise the squared error and compute the average of the loss.
            return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched),
                            axis=0)

        return jax.jit(mse)  # `jit` the result.

    # Instantiate the sampled loss.
    loss = make_mse_func(x_samples, y_samples)

    # Construct a simple Adam optimiser using the transforms in optax.
    # You could also just use the `optax.adam` alias, but we show here how
    # to do so manually so that you may construct your own `custom` optimiser.
    tx = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-learning_rate))

    # Create optimiser state.
    opt_state = tx.init(params)
    # Compute the gradient of the loss function.
    loss_grad_fn = jax.value_and_grad(loss)

    # Minimise the loss.
    for step in range(n_training_steps):
        # Compute gradient of the loss.
        loss_val, grads = loss_grad_fn(params)
        # Update the optimiser state, create an update to the params.
        updates, opt_state = tx.update(grads, opt_state)
        # Update the parameters.
        params = optax.apply_updates(params, updates)
        print(f'Loss[{step}] = {loss_val}')
Example #17
    scaler = StandardScaler()
    scaler.fit(X_train)
    X_train_s = scaler.transform(X_train)
    X_test_s = scaler.transform(X_test)

    X_train_s = torch.tensor(X_train_s, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32)

    train_dataloader = DataLoader(TensorDataset(X_train_s, y_train),
                                  batch_size=batch_size,
                                  shuffle=True)

    learning_rate = 0.001
    optimizer = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-learning_rate),
    )

    logit_d = Classifier(num_layers=3, hidden_dim=128, use_residual=True)
    params, opt_state = init_fn(X_train_s.shape, jax.random.PRNGKey(seed),
                                logit_d, optimizer)

    train_step = get_train_step(loss, optimizer)

    print("Test accuracy: {:.3f}".format(
        jax.numpy.mean((jax.nn.sigmoid(
            logit_d.apply({"params": params}, np.array(X_test_s))) > 0.5
                        ).flatten() == y_test)))
Example #18
def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                learning_rate=1e-4,
                grad_clip=1.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
    """Trains GraphModel using tensorflow.

  Args:
    train_file_pattern: pattern matching the files with the training data.
    test_file_pattern: pattern matching the files with the test data.
    max_files_to_load: the maximum number of train and test files to load.
      If None, all files will be loaded.
    n_epochs: the number of passes through the training dataset (epochs).
    time_index: the time index (0-9) of the target mobilities.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value.
    measurement_store_interval: number of steps between storing objective values
      (loss and correlation).
    checkpoint_path: ignored by this implementation.
  """
    if checkpoint_path:
        logging.warning('The checkpoint_path argument is ignored.')
    random.seed(42)
    np.random.seed(42)
    # Loads train and test dataset.
    dataset_kwargs = dict(time_index=time_index,
                          max_files_to_load=max_files_to_load)
    logging.info('Load training data')
    training_data = load_data(train_file_pattern, **dataset_kwargs)
    logging.info('Load test data')
    test_data = load_data(test_file_pattern, **dataset_kwargs)
    logging.info('Finished loading data')

    network = hk.without_apply_rng(hk.transform(network_definition))
    params = network.init(jax.random.PRNGKey(42), training_data[0][0])

    opt_init, opt_update = optax.chain(optax.clip_by_global_norm(grad_clip),
                                       optax.scale_by_adam(0.9, 0.999, 1e-8),
                                       optax.scale(-learning_rate))
    opt_state = opt_init(params)

    network_apply = jax.jit(network.apply)

    @jax.jit
    def loss_fn(params, graph, targets, mask):
        decoded_nodes = network_apply(params, graph) * mask
        return (jnp.sum((decoded_nodes - targets)**2 * mask) / jnp.sum(mask))

    @jax.jit
    def update(params, opt_state, graph, targets, mask):
        loss, grads = jax.value_and_grad(loss_fn)(params, graph, targets, mask)
        updates, opt_state = opt_update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state, loss

    train_stats = []
    i = 0
    logging.info('Start training')
    for epoch in range(n_epochs):
        logging.info('Start epoch %r', epoch)
        random.shuffle(training_data)
        for graph, targets, mask in training_data:
            graph = apply_random_rotation(graph)
            params, opt_state, loss = update(params, opt_state, graph, targets,
                                             mask)
            train_stats.append(loss)

            if (i + 1) % measurement_store_interval == 0:
                logging.info('Start evaluation run')
                test_stats = []
                for test_graph, test_targets, test_mask in test_data:
                    predictions = network_apply(params, test_graph)
                    test_stats.append(
                        np.corrcoef(predictions[test_mask == 1],
                                    test_targets[test_mask == 1])[0, 1])
                logging.info('Train loss %r', np.mean(train_stats))
                logging.info('Test correlation %r', np.mean(test_stats))
                train_stats = []
            i += 1
Example #19
    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(optax.scale(1 / gradient_accumulation_steps),
                      clip_by_global_norm(1), optax.scale_by_adam(),
                      optax.additive_weight_decay(0), optax.scale(-1),
                      optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)))

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)
Example #20
def init(key, X, lr):
    params, state = forward.init(key, X, True)
    optimizer = optax.chain(optax.scale_by_adam(),
                            optax.add_decayed_weights(0.03), optax.scale(-lr))
    opt_state = optimizer.init(params)
    return params, state, opt_state, optimizer
Example #21
    keep_every = params["keep_every"]
    eval_tasks = params["eval_harness_tasks"]
    total_steps = params["total_steps"]

    pe = params["pe"]
    assert pe in ["fixed", "rotary", "t5"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps), clip_by_global_norm(1),
        optax.scale_by_adam(), additive_weight_decay(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(
            util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)))

    params["optimizer"] = opt

    start = time.time()
    tpu_size = jax.device_count()
    if tpu_size < cores_per_replica:
        msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
        raise ValueError(msg)
    print(f"jax devices: {tpu_size}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
Example #22
    def test_node_classification(self):
        # If node has more than 2 neighbors --> class 1, otherwise class 0.
        # Graph structure:
        # 1         4
        # | \     / |
        # |  0 - 3  |
        # | /     \ |
        # 2         5

        edges = np.array([
            [0, 1],
            [1, 2],
            [2, 0],
            [0, 3],
            [3, 4],
            [4, 5],
            [5, 3],
        ],
                         dtype=np.int32)

        n_node = edges.max() + 1
        n_edge = edges.shape[0]
        g = jraph.GraphsTuple(senders=edges[:, 0],
                              receivers=edges[:, 1],
                              edges=np.ones((edges.shape[0], 1),
                                            dtype=np.float32),
                              nodes=np.ones((n_node, 1), dtype=np.float32),
                              n_node=np.array([n_node], dtype=np.int32),
                              n_edge=np.array([n_edge], dtype=np.int32),
                              globals=None)
        g = gn.add_reverse_edges(g)
        targets = np.array([1, 0, 0, 1, 0, 0], dtype=np.int32)
        n_classes = 2

        def forward(graph, targets):
            model = gn.SimpleGraphNet(num_layers=5, layer_norm=False)
            graph = model(graph)
            nodes = graph.nodes
            logits = hk.Linear(n_classes)(nodes)
            pred = logits.argmax(axis=-1)
            accuracy = (pred == targets).mean()
            targets = jax.nn.one_hot(targets, n_classes, dtype=jnp.float32)
            return -jnp.mean(
                jnp.sum(jax.nn.log_softmax(logits, axis=-1) * targets,
                        axis=-1)), accuracy

        init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
        rng = hk.PRNGSequence(0)
        params = init_fn(next(rng), g, targets)

        optimizer = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))
        opt_state = optimizer.init(params)
        apply_fn = jax.jit(apply_fn)
        for i in range(500):
            (loss, acc), grad = jax.value_and_grad(apply_fn,
                                                   has_aux=True)(params, g,
                                                                 targets)
            updates, opt_state = optimizer.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            if (i + 1) % 100 == 0:
                logging.info('Step %d, loss %.8f, accuracy %.4f', i + 1, loss,
                             acc)
        self.assertLess(loss, 0.01)
        self.assertEqual(acc, 1.0)
Example #23
  def __init__(self,
               player_id,
               state_representation_size,
               num_actions,
               hidden_layers_sizes,
               reservoir_buffer_capacity,
               anticipatory_param,
               batch_size=128,
               rl_learning_rate=0.01,
               sl_learning_rate=0.01,
               min_buffer_size_to_learn=1000,
               learn_every=64,
               optimizer_str="sgd",
               **kwargs):
    """Initialize the `NFSP` agent."""
    self.player_id = player_id
    self._num_actions = num_actions
    self._layer_sizes = hidden_layers_sizes
    self._batch_size = batch_size
    self._learn_every = learn_every
    self._anticipatory_param = anticipatory_param
    self._min_buffer_size_to_learn = min_buffer_size_to_learn

    self._reservoir_buffer = ReservoirBuffer(reservoir_buffer_capacity)
    self._prev_timestep = None
    self._prev_action = None

    # Step counter to keep track of learning.
    self._step_counter = 0

    # Inner RL agent
    kwargs.update({
        "batch_size": batch_size,
        "learning_rate": rl_learning_rate,
        "learn_every": learn_every,
        "min_buffer_size_to_learn": min_buffer_size_to_learn,
        "optimizer_str": optimizer_str,
    })
    self._rl_agent = dqn.DQN(player_id, state_representation_size,
                             num_actions, hidden_layers_sizes, **kwargs)

    # Keep track of the last training loss achieved in an update step.
    self._last_rl_loss_value = lambda: self._rl_agent.loss
    self._last_sl_loss_value = None

    # Average policy network.
    def network(x):
      mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
      return mlp(x)

    self.hk_avg_network = hk.without_apply_rng(hk.transform(network))

    def avg_network_policy(param, info_state):
      action_values = self.hk_avg_network.apply(param, info_state)
      action_probs = jax.nn.softmax(action_values, axis=1)
      return action_values, action_probs

    self._avg_network_policy = jax.jit(avg_network_policy)

    rng = jax.random.PRNGKey(42)
    x = jnp.ones([1, state_representation_size])
    self.params_avg_network = self.hk_avg_network.init(rng, x)
    self.params_avg_network = jax.device_put(self.params_avg_network)

    self._savers = [
        ("q_network", self._rl_agent.params_q_network),
        ("avg_network", self.params_avg_network)
    ]

    if optimizer_str == "adam":
      opt_init, opt_update = optax.chain(
          optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
          optax.scale(sl_learning_rate))
    elif optimizer_str == "sgd":
      opt_init, opt_update = optax.sgd(sl_learning_rate)
    else:
      raise ValueError("Not implemented. Choose from ['adam', 'sgd'].")
    self._opt_update_fn = self._get_update_func(opt_update)
    self._opt_state = opt_init(self.params_avg_network)
    self._loss_and_grad = jax.value_and_grad(self._loss_avg, has_aux=False)

    self._sample_episode_policy()
    self._jit_update = jax.jit(self.get_update())
Example #24
def make_adam_optimizer(lr_schedule, b1=0.9, b2=0.999, eps=1e-8):
    """Make Adam optimizer."""
    # Maximize log-prob instead of minimizing loss
    return optax.chain(optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
                       optax.scale_by_schedule(lr_schedule))
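Because the chain above has no optax.scale(-1), apply_updates moves the parameters in the direction of the gradient, i.e. it performs gradient ascent, so the schedule should return positive step sizes. A hedged toy check of that sign convention (the objective and schedule are made up):

import jax
import jax.numpy as jnp
import optax

log_prob = lambda p: -jnp.sum((p - 3.0) ** 2)   # concave toy objective, max at p = 3
opt = make_adam_optimizer(lambda count: 1e-1)   # constant step-size schedule
p = jnp.zeros(())
opt_state = opt.init(p)
for _ in range(200):
    g = jax.grad(log_prob)(p)
    updates, opt_state = opt.update(g, opt_state)
    p = optax.apply_updates(p, updates)         # p climbs towards 3.0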
Example #25
def main(_):
    # Create the dataset.
    tokenizer = utils.init_tokenizer(FLAGS.dataset)
    graph_tokenizer = utils.init_graph_tokenizer()
    dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
    has_graph = FLAGS.model_type == 'graph2text'
    local_devices = jax.local_devices()
    num_gpus = min(FLAGS.num_gpus, len(local_devices))

    if FLAGS.job_mode == 'train':
        train_dataset = dataset_class(tokenizer=tokenizer,
                                      graph_tokenizer=graph_tokenizer,
                                      batch_size=FLAGS.train_batch_size,
                                      subset='train',
                                      timesteps=FLAGS.train_timesteps,
                                      version=FLAGS.graph_data_version,
                                      shuffle_data=True,
                                      repeat=True,
                                      debug=FLAGS.debug)
        train_iter = iter(train_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.train_memory_size)
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.grad_clip), optax.scale_by_adam(),
            optax.scale_by_schedule(
                functools.partial(utils.schedule,
                                  lr_schedule=FLAGS.lr_schedule,
                                  init_lr=FLAGS.init_lr,
                                  min_lr_ratio=FLAGS.min_lr_ratio,
                                  max_steps=FLAGS.max_steps)), optax.scale(-1))
        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
        updater = Updater(loss_fn,
                          optimizer,
                          devices=local_devices[:num_gpus],
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _train(updater, train_iter, num_gpus)
    elif FLAGS.job_mode == 'eval':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=FLAGS.eval_batch_size,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _eval(updater, eval_iter)
    elif FLAGS.job_mode == 'sample':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.sample_length,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        _sample(eval_iter, tokenizer, local_devices[:num_gpus])
    elif FLAGS.job_mode == 'retrieve':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     graph_retrieval_dataset=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _retrieve(updater, eval_iter)
Example #26
def main(unused_argv):
    validate_flags()
    logging.info(get_config())

    seed = FLAGS.seed if FLAGS.seed is not None else np.random.randint(
        0, 0x7fffffff)
    rng = np.random.RandomState(seed)
    key_seq = hk.PRNGSequence(rng.randint(0, 0x7fffffff))

    activation = jnp.tanh

    eta = 0.3
    assert eta < 4 / (5 + FLAGS.p)  # needed for IGS stability

    if FLAGS.dagger:
        intermediate_policy = training_utils.dagger_policy_with_expert
        final_policy = training_utils.dagger_final_policy
    else:
        intermediate_policy = training_utils.mixed_policy_with_expert
        final_policy = training_utils.final_policy

    # make dynamics and expert
    dynamics, expert_policy = problem_instance_utils.make_dynamics_and_expert(
        next(key_seq), FLAGS.state_dim, FLAGS.p, eta, activation)

    policy_net = training_utils.make_policy_net(64, FLAGS.state_dim,
                                                activation)

    opt_init, opt_update = optax.chain(
        # Set the parameters of Adam. Note the learning_rate is not here.
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # Put a minus sign to *minimise* the loss.
        optax.scale(-FLAGS.learning_rate))

    aggregate_states, aggregate_actions = [], []
    policy_params = []  # accumulate all the networks trained at each epoch
    for shift_epoch in tqdm.trange(FLAGS.n_shift_epochs):
        logging.info('starting epoch %d', shift_epoch)

        start_time = time.time()
        if shift_epoch == 0:
            epoch_rollout_policy = expert_policy
        else:
            epoch_rollout_policy = functools.partial(intermediate_policy,
                                                     policy_net, expert_policy,
                                                     policy_params,
                                                     FLAGS.alpha)

        x0s_epoch = problem_instance_utils.sample_initial_conditions(
            next(key_seq), FLAGS.n_trajs_per_epoch, FLAGS.state_dim)
        Xs_epoch, Us_epoch = jax.vmap(problem_instance_utils.rollout_policy,
                                      in_axes=(None, None, 0,
                                               None))(dynamics,
                                                      epoch_rollout_policy,
                                                      x0s_epoch, FLAGS.horizon)
        Us_expert_labels = jax.vmap(lambda traj: jax.vmap(expert_policy)
                                    (traj[:-1]))(Xs_epoch)

        logging.info('rolling out %d trajectories took %f seconds',
                     FLAGS.n_trajs_per_epoch,
                     time.time() - start_time)

        # compute goal error
        logging.info('goal error: %s',
                     stats(np.linalg.norm(Xs_epoch[:, -1, :], axis=1)))
        # compute imitation error
        logging.info(
            'imitation error: %s',
            stats(
                np.sum(np.linalg.norm(Us_epoch - Us_expert_labels, axis=2),
                       axis=1)))

        # format for training
        epoch_train_states = Xs_epoch[:, :-1, :].reshape(
            (-1, Xs_epoch.shape[-1]))
        epoch_train_actions = Us_expert_labels.reshape(
            (-1, Us_expert_labels.shape[-1]))

        # aggregate the accumulated data
        if FLAGS.aggregate_data:
            aggregate_states.append(epoch_train_states)
            aggregate_actions.append(epoch_train_actions)
            epoch_train_states = np.concatenate(aggregate_states, axis=0)
            epoch_train_actions = np.concatenate(aggregate_actions, axis=0)

        logging.info('epoch_train_states.shape: %s', epoch_train_states.shape)
        logging.info('epoch_train_actions.shape: %s',
                     epoch_train_actions.shape)
        assert epoch_train_states.shape[0] == epoch_train_actions.shape[0]
        assert epoch_train_states.shape[1] == FLAGS.state_dim
        assert epoch_train_actions.shape[1] == FLAGS.state_dim

        # initial parameters for training
        if shift_epoch == 0:
            params = policy_net.init(next(key_seq), epoch_train_states[0])
            trust_region_params = jax_utils.pytree_zeros_like(params)
        else:
            assert len(policy_params) >= 1
            params = policy_params[-1]
            trust_region_params = params

        if FLAGS.igs_constraint_lam > 0.0:
            if shift_epoch == FLAGS.n_shift_epochs - 1:

                def policy_fn(policy_network, this_policy_params, x):
                    return final_policy(policy_network,
                                        policy_params + [this_policy_params],
                                        FLAGS.alpha, x)
            else:

                def policy_fn(policy_network, this_policy_params, x):
                    return intermediate_policy(
                        policy_network, expert_policy,
                        policy_params + [this_policy_params], FLAGS.alpha, x)

            def igs_loss(x, y, fx, fy):
                # want |fx - fy| - |x - y| <= 0
                ineq = jnp.abs(fx - fy) - jnp.abs(x - y)
                return FLAGS.igs_constraint_lam * jnp.maximum(ineq, 0)

            igs_constraint_args = (dynamics, igs_loss, policy_fn)
        else:
            igs_constraint_args = None

        start_time = time.time()
        params, _, last_epoch_losses = training_utils.train_policy_network(
            policy_net,
            opt_update, epoch_train_states, epoch_train_actions, params,
            opt_init(params), trust_region_params, 0.0, igs_constraint_args,
            FLAGS.n_train_epochs, FLAGS.batch_size, 0.0, 1000, rng,
            FLAGS.verbose_learner)
        policy_params.append(params)
        logging.info(
            'shift_epoch=%d, last_epoch_losses=%s, '
            'avg_last_epoch_losses=%s', shift_epoch, last_epoch_losses,
            last_epoch_losses / len(epoch_train_states))
        logging.info('train_policy_network at epoch %d took %f seconds',
                     shift_epoch,
                     time.time() - start_time)

    logging.info('running final episodes')

    x0s_final_test = problem_instance_utils.sample_initial_conditions(
        next(key_seq), FLAGS.n_trajs_final_eval, FLAGS.state_dim)
    Xs_final_test_shift, Us_final_test_shift = jax.vmap(
        problem_instance_utils.rollout_policy,
        in_axes=(None, None, 0,
                 None))(dynamics,
                        functools.partial(final_policy, policy_net,
                                          policy_params, FLAGS.alpha),
                        x0s_final_test, FLAGS.horizon)
    Us_expert_final_test_shift = jax.vmap(lambda traj: jax.vmap(expert_policy)
                                          (traj[:-1]))(Xs_final_test_shift)

    Xs_final_test_exp, _ = jax.vmap(problem_instance_utils.rollout_policy,
                                    in_axes=(None, None, 0,
                                             None))(dynamics, expert_policy,
                                                    x0s_final_test,
                                                    FLAGS.horizon)

    final_test_shift = np.linalg.norm(Xs_final_test_shift[:, -1, :], axis=1)
    final_test_exp = np.linalg.norm(Xs_final_test_exp[:, -1, :], axis=1)
    final_test_delta_goal_error = np.linalg.norm(
        Xs_final_test_shift[:, -1, :] - Xs_final_test_exp[:, -1, :], axis=1)
    final_imitation_error = np.sum(np.linalg.norm(Us_final_test_shift -
                                                  Us_expert_final_test_shift,
                                                  axis=2),
                                   axis=1)

    logging.info('final shift goal error: %s', stats(final_test_shift))
    logging.info('expert goal error: %s', stats(final_test_exp))
    logging.info('final delta goal error: %s',
                 stats(final_test_delta_goal_error))
    logging.info('final_imitation_error: %s', stats(final_imitation_error))

    if FLAGS.metrics_outfile is not None:
        with open(FLAGS.metrics_outfile, 'wb') as fp:
            pickle.dump(
                {
                    'final_test_shift': final_test_shift,
                    'final_test_exp': final_test_exp,
                    'final_test_delta_goal_error': final_test_delta_goal_error,
                    'final_imitation_error': final_imitation_error,
                }, fp)
    if FLAGS.config_outfile is not None:
        with open(FLAGS.config_outfile, 'wb') as fp:
            pickle.dump(get_config(), fp)
    if FLAGS.params_outfile is not None:
        with open(FLAGS.params_outfile, 'wb') as fp:
            pickle.dump(
                {
                    'mixing_weight': FLAGS.alpha,
                    'dagger': False,
                    'policy_params': policy_params
                }, fp)
Example #27
    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
Example #28
    def __init__(self,
                 player_id,
                 state_representation_size,
                 num_actions,
                 hidden_layers_sizes=128,
                 replay_buffer_capacity=10000,
                 batch_size=128,
                 replay_buffer_class=ReplayBuffer,
                 learning_rate=0.01,
                 update_target_network_every=1000,
                 learn_every=10,
                 discount_factor=1.0,
                 min_buffer_size_to_learn=1000,
                 epsilon_start=1.0,
                 epsilon_end=0.1,
                 epsilon_decay_duration=int(1e6),
                 optimizer_str="sgd",
                 loss_str="mse",
                 huber_loss_parameter=1.0):
        """Initialize the DQN agent."""

        # This call to locals() is used to store every argument used to initialize
        # the class instance, so it can be copied with no hyperparameter change.
        self._kwargs = locals()

        self.player_id = player_id
        self._num_actions = num_actions
        if isinstance(hidden_layers_sizes, int):
            hidden_layers_sizes = [hidden_layers_sizes]
        self._layer_sizes = hidden_layers_sizes
        self._batch_size = batch_size
        self._update_target_network_every = update_target_network_every
        self._learn_every = learn_every
        self._min_buffer_size_to_learn = min_buffer_size_to_learn
        self._discount_factor = discount_factor
        self.huber_loss_parameter = huber_loss_parameter

        self._epsilon_start = epsilon_start
        self._epsilon_end = epsilon_end
        self._epsilon_decay_duration = epsilon_decay_duration

        # TODO(author6) Allow for optional replay buffer config.
        if not isinstance(replay_buffer_capacity, int):
            raise ValueError("Replay buffer capacity not an integer.")
        self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
        self._prev_timestep = None
        self._prev_action = None

        # Step counter to keep track of learning, eps decay and target network.
        self._step_counter = 0

        # Keep track of the last training loss achieved in an update step.
        self._last_loss_value = None

        # Create the Q-network instances

        def network(x):
            mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
            return mlp(x)

        self.hk_network = hk.without_apply_rng(hk.transform(network))
        self.hk_network_apply = jax.jit(self.hk_network.apply)

        rng = jax.random.PRNGKey(42)
        x = jnp.ones([1, state_representation_size])
        self.params_q_network = self.hk_network.init(rng, x)
        self.params_target_q_network = self.hk_network.init(rng, x)

        if loss_str == "mse":
            self.loss_func = lambda x: jnp.mean(x**2)
        elif loss_str == "huber":
            # pylint: disable=g-long-lambda
            self.loss_func = lambda x: jnp.mean(
                rlax.huber_loss(x, self.huber_loss_parameter))
        else:
            raise ValueError("Not implemented, choose from 'mse', 'huber'.")
        if optimizer_str == "adam":
            opt_init, opt_update = optax.chain(
                optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
                optax.scale(learning_rate))
        elif optimizer_str == "sgd":
            opt_init, opt_update = optax.sgd(learning_rate)
        else:
            raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")
        self._opt_update_fn = self._get_update_func(opt_update)
        self._opt_state = opt_init(self.params_q_network)
        self._loss_and_grad = jax.value_and_grad(self._loss, has_aux=False)
        self._jit_update = jax.jit(self.get_update())
Example #29
def train(
    train_data_iterator: Iterator[types.LabelledData],
    test_data_iterator: Iterator[types.LabelledData],
    elbo_fun: hk.Transformed,
    learning_rate: float,
    checkpoint_dir: str,
    checkpoint_filename: str,
    checkpoint_every: int,
    test_every: int,
    iterations: int,
    rng_seed: int,
    test_functions: Optional[Sequence[Callable[[Mapping[str, jnp.ndarray]],
                                               Tuple[str, float]]]] = None,
    extra_checkpoint_info: Optional[Mapping[str, Any]] = None):
  """Train VAE with given data iterator and elbo definition.

  Args:
   train_data_iterator: Iterator of batched training data.
   test_data_iterator: Iterator of batched testing data.
   elbo_fun: Haiku transformed function returning the ELBO.
   learning_rate: Learning rate to be used with optimizer.
   checkpoint_dir: Path of the checkpoint directory.
   checkpoint_filename: Filename of the checkpoint.
   checkpoint_every: Checkpoint every N iterations.
   test_every: Test and log results every N iterations.
   iterations: Number of training iterations to perform.
   rng_seed: Seed for random number generator.
   test_functions: Iterable of test functions; each takes the current params
    and returns a (name, value) pair to print at test and log time.
   extra_checkpoint_info: Extra info to put inside saved checkpoint.
  """
  rng_seq = hk.PRNGSequence(jax.random.PRNGKey(rng_seed))

  opt_init, opt_update = optax.chain(
      # Set the parameters of Adam. Note the learning_rate is not here.
      optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
      # Put a minus sign to *minimise* the loss.
      optax.scale(-learning_rate))

  @jax.jit
  def loss(params, key, data):
    elbo_outputs = elbo_fun.apply(params, key, data)
    return -jnp.mean(elbo_outputs.elbo)

  @jax.jit
  def loss_test(params, key, data):
    elbo_output = elbo_fun.apply(params, key, data)
    return (-jnp.mean(elbo_output.elbo), jnp.mean(elbo_output.data_fidelity),
            jnp.mean(elbo_output.kl))

  @jax.jit
  def update_step(params, key, data, opt_state):
    grads = jax.grad(loss, has_aux=False)(params, key, data)
    updates, opt_state = opt_update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  exp_checkpointer = checkpointer.Checkpointer(
      checkpoint_dir, checkpoint_filename)
  experiment_data = exp_checkpointer.load_checkpoint()

  if experiment_data is not None:
    start = experiment_data['step']
    params = experiment_data['experiment_state']
    opt_state = experiment_data['opt_state']
  else:
    start = 0
    params = elbo_fun.init(
        next(rng_seq), next(train_data_iterator).data)
    opt_state = opt_init(params)

  for step in range(start, iterations, 1):
    if step % test_every == 0:
      test_loss, ll, kl = loss_test(params, next(rng_seq),
                                    next(test_data_iterator).data)
      output_message = (f'Step {step} elbo {-test_loss:0.2f} LL {ll:0.2f} '
                        f'KL {kl:0.2f}')
      if test_functions:
        for test_function in test_functions:
          name, test_output = test_function(params)
          output_message += f' {name}: {test_output:0.2f}'
      print(output_message)

    params, opt_state = update_step(params, next(rng_seq),
                                    next(train_data_iterator).data, opt_state)

    if step % checkpoint_every == 0:
      exp_checkpointer.save_checkpoint(
          params, opt_state, step, extra_checkpoint_info)
Example #30
 def update(
         self, gradient: Weights, state: GenericGradientState,
         parameters: Optional[Weights]
 ) -> Tuple[Weights, GenericGradientState]:
     return GenericGradientState.wrap(*scale_by_adam(
         **asdict(self)).update(gradient, state.data, parameters))