def get_loaders():
    """Builds a class-balanced train loader and plain SVHN/MNIST val loaders."""
    svhn, mnist = get_datasets(is_training=True)
    val_svhn, val_mnist = get_datasets(is_training=False)

    train_dataset = svhn if DATA == 'svhn' else mnist
    # Oversample rare classes so each training batch is roughly class-balanced.
    weights = make_weights_for_balanced_classes(train_dataset, num_classes=10)
    sampler = WeightedRandomSampler(weights, len(weights))
    train_loader = DataLoader(train_dataset,
                              BATCH_SIZE,
                              sampler=sampler,
                              pin_memory=True,
                              drop_last=True)

    val_svhn_loader = DataLoader(val_svhn,
                                 BATCH_SIZE,
                                 shuffle=False,
                                 drop_last=False)
    val_mnist_loader = DataLoader(val_mnist,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  drop_last=False)
    return train_loader, val_svhn_loader, val_mnist_loader
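make_weights_for_balanced_classes is not defined on this page; below is a minimal sketch, assuming a map-style dataset that yields (image, label) pairs with integer class labels (the project's actual helper may differ):

def make_weights_for_balanced_classes(dataset, num_classes):
    # Count how many samples fall into each class.
    counts = [0] * num_classes
    for _, label in dataset:
        counts[int(label)] += 1
    # Weight each sample by the inverse frequency of its class so that
    # WeightedRandomSampler draws every class roughly equally often.
    per_class = [len(dataset) / max(c, 1) for c in counts]
    return [per_class[int(label)] for _, label in dataset]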
Example #2
  def _get_datasets(self):
    config = default.get_config()
    config.per_device_batch_size = 1
    config.eval_per_device_batch_size = 2
    config.vocab_size = 32
    config.max_corpus_chars = 1000
    config.max_target_length = _TARGET_LENGTH
    config.max_eval_target_length = _EVAL_TARGET_LENGTH
    config.max_predict_length = _PREDICT_TARGET_LENGTH

    vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model')

    # Go two directories up to the root of the flax directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir) + '/.tfds/metadata'

    with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
      train_ds, eval_ds, predict_ds, _ = input_pipeline.get_datasets(
          n_devices=2, config=config, vocab_path=vocab_path)
    return train_ds, eval_ds, predict_ds
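A hedged usage sketch (the method name test_dataset_shapes is illustrative, not from the source), showing how a test in the same class might consume these datasets:

  def test_dataset_shapes(self):
    train_ds, eval_ds, predict_ds = self._get_datasets()
    train_batch = next(iter(train_ds))
    eval_batch = next(iter(eval_ds))
    # Batches should be padded to the configured target lengths.
    self.assertEqual(train_batch['inputs'].shape[-1], _TARGET_LENGTH)
    self.assertEqual(eval_batch['inputs'].shape[-1], _EVAL_TARGET_LENGTH)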
Example #3
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        # Keep tokens up to and including the first EOS, then detokenize.
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    def encode_strings(strs, max_len):
        tokenized_batch = np.zeros((len(strs), max_len), np.int32)
        for i, s in enumerate(strs):
            toks = encoder.tokenize(s).numpy()
            # Remove EOS token in prompt.
            tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
        return tokenized_batch

    tokenized_prompts = encode_strings([config.prompts],
                                       config.max_predict_length)

    logging.info("Initializing model, optimizer, and step functions.")
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    rng, inference_rng = random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.TransformerLM(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32))

    learning_rate_fn = create_learning_rate_schedule(
        learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

    optimizer = optax.adamw(learning_rate_fn,
                            b1=0.9,
                            b2=0.98,
                            eps=1e-9,
                            weight_decay=config.weight_decay)
    state = train_state.TrainState.create(apply_fn=m.apply,
                                          params=initial_variables["params"],
                                          tx=optimizer)
    # We access model params only from optimizer below.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        state = checkpoints.restore_checkpoint(workdir, state)
        # Grab last step.
        start_step = int(state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    state = jax_utils.replicate(state)

    # Compile multi-device versions of the train/eval/predict step functions.
    p_train_step = jax.pmap(functools.partial(
        train_step, config=train_config, learning_rate_fn=learning_rate_fn),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name="batch")

    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          temperature=config.sampling_temperature,
                          top_k=config.sampling_top_k),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token and max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update them afterwards
    # inside the main pmap'd training update, for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                state, metrics = p_train_step(state,
                                              batch,
                                              dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                                     a_max=1.0e4)
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        params=state.params,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    # (clipped) perplexity after averaging log-perplexities.
                    eval_results["perplexity"] = jnp.clip(
                        jnp.exp(eval_results["loss"]), a_max=1.0e4)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("generate_text"):
                    exemplars = generate_prediction(
                        p_pred_step=p_pred_step,
                        params=state.params,
                        tokenized_prompts=tokenized_prompts,
                        eos_id=eos_id,
                        inference_rng=inference_rng,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if (config.save_checkpoints and save_checkpoint
                    and jax.process_index() == 0):
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(workdir,
                                                jax_utils.unreplicate(state),
                                                step)
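create_learning_rate_schedule, used in Example #3 above, is not shown on this page; here is a minimal sketch of a warmup-plus-inverse-square-root schedule in its spirit (an assumption, not the example's exact code):

import jax.numpy as jnp

def create_learning_rate_schedule(learning_rate, warmup_steps):
    # Linear warmup followed by inverse-square-root decay, a common
    # Transformer schedule; optax optimizers accept such a callable.
    def schedule(step):
        warmup = jnp.minimum(1.0, step / warmup_steps)
        decay = jnp.sqrt(jnp.minimum(1.0, warmup_steps / jnp.maximum(step, 1)))
        return learning_rate * warmup * decay
    return schedule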
Example #4

def train_and_evaluate():
    """Trains a source-domain classifier with associative domain adaptation
    (walker/visit losses) toward the unlabeled target domain."""
    svhn, mnist = get_datasets(is_training=True)
    source_dataset = svhn if SOURCE_DATA == 'svhn' else mnist
    target_dataset = mnist if SOURCE_DATA == 'svhn' else svhn

    weights = make_weights_for_balanced_classes(source_dataset, num_classes=10)
    sampler = WeightedRandomSampler(weights, len(weights))
    source_loader = DataLoader(source_dataset,
                               BATCH_SIZE,
                               sampler=sampler,
                               pin_memory=True,
                               drop_last=True)
    target_loader = DataLoader(target_dataset,
                               BATCH_SIZE,
                               shuffle=True,
                               pin_memory=True,
                               drop_last=True)

    val_svhn, val_mnist = get_datasets(is_training=False)
    val_svhn_loader = DataLoader(val_svhn,
                                 BATCH_SIZE,
                                 shuffle=False,
                                 drop_last=False)
    val_mnist_loader = DataLoader(val_mnist,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  drop_last=False)
    print('\nsource dataset is', SOURCE_DATA, '\n')

    num_steps_per_epoch = math.floor(min(len(svhn), len(mnist)) / BATCH_SIZE)
    embedder = Network(image_size=(32, 32),
                       embedding_dim=EMBEDDING_DIM).to(DEVICE)
    classifier = nn.Linear(EMBEDDING_DIM, 10).to(DEVICE)
    model = nn.Sequential(embedder, classifier)
    model.train()

    optimizer = optim.Adam(lr=1e-3,
                           params=model.parameters(),
                           weight_decay=1e-3)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=num_steps_per_epoch * NUM_EPOCHS - DELAY,
                                  eta_min=1e-6)

    cross_entropy = nn.CrossEntropyLoss()
    association = WalkerVisitLosses()

    text = 'e:{0:2d}, i:{1:3d}, classification loss: {2:.3f}, ' +\
        'walker loss: {3:.3f}, visit loss: {4:.4f}, ' +\
        'total loss: {5:.3f}, lr: {6:.6f}'
    logs, val_logs = [], []
    i = 0  # iteration

    for e in range(NUM_EPOCHS):
        model.train()
        for (x_source, y_source), (x_target, _) in zip(source_loader,
                                                       target_loader):

            x_source = x_source.to(DEVICE)
            x_target = x_target.to(DEVICE)
            y_source = y_source.to(DEVICE)

            x = torch.cat([x_source, x_target], dim=0)
            embeddings = embedder(x)
            a, b = torch.split(embeddings, BATCH_SIZE, dim=0)
            logits = classifier(a)
            usual_loss = cross_entropy(logits, y_source)
            walker_loss, visit_loss = association(a, b, y_source)

            if i > DELAY:
                # Ramp in the association losses linearly over GROWTH_STEPS.
                growth = min((i - DELAY) / GROWTH_STEPS, 1.0)
                loss = usual_loss + growth * (BETA1 * walker_loss +
                                              BETA2 * visit_loss)
            else:
                loss = usual_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i > DELAY:
                scheduler.step()
            lr = scheduler.get_last_lr()[0]  # get_lr() is deprecated in recent PyTorch

            log = (e, i, usual_loss.item(), walker_loss.item(),
                   visit_loss.item(), loss.item(), lr)
            print(text.format(*log))
            logs.append(log)
            i += 1

        result1 = evaluate(model, cross_entropy, val_svhn_loader, DEVICE)
        result2 = evaluate(model, cross_entropy, val_mnist_loader, DEVICE)
        print('\nsvhn loss {0:.3f} and accuracy {1:.3f}'.format(*result1))
        print('mnist loss {0:.3f} and accuracy {1:.3f}\n'.format(*result2))
        val_logs.append((i, ) + result1 + result2)

    torch.save(model.state_dict(), SAVE_PATH)
    write_logs(logs, val_logs, LOGS_PATH)
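The evaluate helper is not shown here; a minimal sketch, assuming it returns a (loss, accuracy) tuple computed over a validation loader:

import torch

@torch.no_grad()
def evaluate(model, criterion, loader, device):
    model.eval()
    total_loss, num_correct, num_samples = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        # criterion averages over the batch, so re-weight by batch size.
        total_loss += criterion(logits, y).item() * x.size(0)
        num_correct += (logits.argmax(dim=1) == y).sum().item()
        num_samples += x.size(0)
    return total_loss / num_samples, num_correct / num_samples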
Example #5
# Convert the LRA ListOps splits into pickled lists of numpy examples.
import sys
sys.path.append("./long-range-arena/lra_benchmarks/listops/")
import input_pipeline
import numpy as np
import pickle

train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
    n_devices=1,
    task_name="basic",
    data_dir="./lra_release/lra_release/listops-1000/",
    batch_size=1,
    max_length=2000)

mapping = {"train": train_ds, "dev": eval_ds, "test": test_ds}
for component in mapping:
    ds_list = []
    for idx, inst in enumerate(iter(mapping[component])):
        ds_list.append({
            # Pad the 2000-token inputs with 48 zeros (to length 2048).
            "input_ids_0": np.concatenate(
                [inst["inputs"].numpy()[0], np.zeros(48, dtype=np.int32)]),
            "label": inst["targets"].numpy()[0],
        })
        if idx % 100 == 0:
            print(f"{idx}\t\t", end="\r")
    with open(f"listops.{component}.pickle", "wb") as f:
        pickle.dump(ds_list, f)
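A short usage sketch for reading one of the pickled splits back (same file naming as above):

import pickle

with open("listops.train.pickle", "rb") as f:
    data = pickle.load(f)
print(len(data), data[0]["input_ids_0"].shape, data[0]["label"])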