Example #1
    def test_named(self, wait_jax_async_dispatch, mock_time):
        mock_time.return_value = 0
        hook = periodic_actions.ReportProgress(every_steps=1,
                                               every_secs=None,
                                               num_train_steps=10)

        def _wait():
            # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
            hook._executor.submit(lambda: None).result()

        self.assertFalse(hook(1))  # Never triggers on first execution.
        with hook.timed("test1", wait_jax_async_dispatch):
            _wait()
            mock_time.return_value = 1
        _wait()
        with hook.timed("test2", wait_jax_async_dispatch):
            _wait()
            mock_time.return_value = 2
        _wait()
        with hook.timed("test1", wait_jax_async_dispatch):
            _wait()
            mock_time.return_value = 3
        _wait()
        mock_time.return_value = 4
        with self.assertLogs(level="INFO") as logs:
            self.assertTrue(hook(2))
        self.assertEqual(logs.output, [
            "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m"
            " (0m : 50.0% test1, 25.0% test2)"
        ])
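
A note on scaffolding: the signature `test_named(self, wait_jax_async_dispatch, mock_time)` implies decorators that the excerpt does not show. A minimal sketch of that setup, assuming absl's parameterized test base and a patched clock (class and parameter values are illustrative, not taken from the original file):

from unittest import mock

from absl.testing import parameterized
from clu import periodic_actions


class ReportProgressTest(parameterized.TestCase):

    # `wait_jax_async_dispatch` toggles whether the timed block waits for JAX's
    # async dispatch; `time.time` is patched so steps/s and ETA are deterministic.
    @parameterized.parameters(False, True)
    @mock.patch("time.time")
    def test_named(self, wait_jax_async_dispatch, mock_time):
        ...  # Body as in the example above.
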
Example #2
    def test_named(self, wait_jax_async_dispatch, time_mock):
        time_mock.return_value = 0
        hook = periodic_actions.ReportProgress(every_steps=1,
                                               every_secs=None,
                                               num_train_steps=10)

        def _wait():
            # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
            hook._executor.submit(lambda: None).result()

        hook(1)
        with hook.timed("test1", wait_jax_async_dispatch):
            _wait()
            time_mock.return_value = 1
        _wait()
        with hook.timed("test2", wait_jax_async_dispatch):
            _wait()
            time_mock.return_value = 2
        _wait()
        with hook.timed("test1", wait_jax_async_dispatch):
            _wait()
            time_mock.return_value = 3
        _wait()
        time_mock.return_value = 4
        with self.assertLogs(level="INFO") as logs:
            hook(2)
        self.assertEqual(logs.output, [
            "INFO:absl:Setting work unit notes: 20.0% @2, 0.2 steps/s, ETA: 1 min"
            " (0 min : 50.0% test1, 25.0% test2)"
        ])
Example #3
 def test_called_every_step(self):
     hook = periodic_actions.ReportProgress(every_steps=3,
                                            num_train_steps=10)
     t = time.time()
     with self.assertRaisesRegex(
             ValueError, "PeriodicAction must be called after every step"):
         hook(1, t)
         hook(11, t)  # Raises exception.
Example #4
 def test_called_every_step(self):
     hook = periodic_actions.ReportProgress(every_steps=3,
                                            num_train_steps=10)
     t = time.time()
     with self.assertRaisesRegex(
             ValueError, "EveryNHook must be called after every step"):
         hook(1, t)
         # Skipping step 2.
         hook(11, t)
Example #5
 def test_without_num_train_steps(self):
     report = periodic_actions.ReportProgress(every_steps=2)
     t = time.time()
     with self.assertLogs(level="INFO") as logs:
         self.assertFalse(report(1, t))
         self.assertTrue(report(2, t + 0.12))
     # We did 1 step in 0.12s => 8.333 steps/s.
     self.assertEqual(logs.output,
                      ["INFO:absl:Setting work unit notes: 8.3 steps/s"])
Example #6
 def test_unknown_cardinality(self):
     report = periodic_actions.ReportProgress(
         every_steps=2, num_train_steps=tf.data.UNKNOWN_CARDINALITY)
     t = time.time()
     with self.assertLogs(level="INFO") as logs:
         self.assertFalse(report(1, t))
         self.assertTrue(report(2, t + 0.12))
     # We did 1 step in 0.12s => 8.333 steps/s.
     self.assertEqual(logs.output,
                      ["INFO:absl:Setting work unit notes: 8.3 steps/s"])
Example #7
 def test_every_steps(self):
     hook = periodic_actions.ReportProgress(every_steps=4,
                                            every_secs=None,
                                            num_train_steps=10)
     t = time.time()
     with self.assertLogs(level="INFO") as logs:
         hook(1, t)
         t += 0.11
         hook(2, t)
         t += 0.13
         hook(3, t)
         t += 0.12
         hook(4, t)
     # We did 1 step every 0.12s => 8.333 steps/s.
     self.assertEqual(logs.output, [
         "INFO:absl:Setting work unit notes: 40.0% @4, 8.3 steps/s, ETA: 0 min"
     ])
Example #8
 def test_every_secs(self):
     hook = periodic_actions.ReportProgress(every_steps=None,
                                            every_secs=0.3,
                                            num_train_steps=10)
     t = time.time()
     with self.assertLogs(level="INFO") as logs:
         self.assertFalse(hook(1, t))
         t += 0.11
         self.assertFalse(hook(2, t))
         t += 0.13
         self.assertFalse(hook(3, t))
         t += 0.12
         self.assertTrue(hook(4, t))
     # We did 1 step every 0.12s => 8.333 steps/s.
     self.assertEqual(logs.output, [
         "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
     ])
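
Each of the training-loop examples that follow drives the hook the same way: construct it once, then call it with the current step number on every iteration. A minimal usage sketch (hypothetical step count and reporting interval, not taken from the examples):

from clu import periodic_actions

report_progress = periodic_actions.ReportProgress(
    every_steps=100, every_secs=None, num_train_steps=10_000)
for step in range(1, 10_001):
    # ... one training step would run here ...
    report_progress(step)  # Decides internally whether to emit a progress note.
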
Example #9
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch_buffer_size = 6 x batch_size
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weight=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index()
                    == 0) and (step % config.train.checkpoint_every_steps == 0
                               or is_last_step):
                # Write final metrics to file
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = metric_update.compute()
                    for k, v in log_dict.items():
                        log_dict[k] = v.item()
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    state_to_save = jax.device_get(
                        jax.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)
Example #10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size,
        split_tokenizer=FLAGS.split_tokenizer)

    if FLAGS.aux_eval_dataset:
        aux_datasets = []
        aux_names = FLAGS.aux_eval_dataset.split(',')
        for name in aux_names:
            _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
                dataset_name=name,
                eval_dataset_name=None,
                shard_idx=jax.process_index(),
                shard_count=jax.process_count(),
                data_dir=FLAGS.data_dir,
                vocab_path=vocab_path,
                target_vocab_size=FLAGS.vocab_size,
                batch_size=FLAGS.batch_size,
                max_length=FLAGS.max_target_length,
                max_eval_length=FLAGS.max_eval_target_length,
                paracrawl_size=FLAGS.paracrawl_size,
                is_scores_path=FLAGS.is_scores_path,
                num_to_keep=FLAGS.data_selection_size,
                pseudo_path=FLAGS.pseudo_path,
                repeat_count=FLAGS.repeat_count,
                newscommentary_size=FLAGS.newscommentary_size)
            aux_datasets.append(aux_eval_ds)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # It's possible that this is supposed to be the per-device batch size.
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, just continue where
        # we left off.
        model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
        optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    flag_key = [
        k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
    ]
    if flag_key:
        flag_key = flag_key[0]
        local_flags = {
            f.name: f.value
            for f in FLAGS.flags_by_module_dict()[flag_key]
        }
        writer.write_hparams(local_flags)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_util.train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(train_util.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_util.initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(train_util.predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info('Starting training loop.')
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=FLAGS.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    total_steps = start_step + FLAGS.num_train_steps
    if FLAGS.eval_only:
        total_steps = start_step + 1
    best_eval_loss = 1000
    curr_eval_loss = 1000
    eval_loss_history = []
    last_eval_step = 0
    do_resample_data = False
    gradual_selection_size = FLAGS.data_selection_size
    dynamic_eval_freq = FLAGS.eval_frequency
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, total_steps):
            is_last_step = step == total_steps - 1

            # Resample training data for gradual FT
            if do_resample_data:
                # resample data
                do_resample_data = False
                gradual_selection_size *= .7
                dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)

                train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
                    dataset_name=FLAGS.dataset_name,
                    eval_dataset_name=FLAGS.eval_dataset_name,
                    shard_idx=jax.process_index(),
                    shard_count=jax.process_count(),
                    data_dir=FLAGS.data_dir,
                    vocab_path=vocab_path,
                    target_vocab_size=FLAGS.vocab_size,
                    batch_size=FLAGS.batch_size,
                    max_length=FLAGS.max_target_length,
                    max_eval_length=FLAGS.max_eval_target_length,
                    paracrawl_size=FLAGS.paracrawl_size,
                    is_scores_path=FLAGS.is_scores_path,
                    num_to_keep=int(gradual_selection_size),
                    pseudo_path=FLAGS.pseudo_path,
                    repeat_count=FLAGS.repeat_count,
                    newscommentary_size=FLAGS.newscommentary_size,
                    split_tokenizer=FLAGS.split_tokenizer)
                train_iter = iter(train_ds)

            # Shard data to devices and do a training step.
            if not FLAGS.eval_only:
                logging.info('Doing Training.')
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    try:
                        batch = common_utils.shard(
                            jax.tree_map(np.asarray, next(train_iter)))
                        optimizer, metrics = p_train_step(
                            optimizer, batch, dropout_rng=dropout_rngs)
                        train_metrics.append(metrics)
                    except StopIteration:
                        is_last_step = True

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
                if not FLAGS.eval_only:
                    with report_progress.timed('training_metrics'):
                        logging.info('Gathering training metrics.')
                        train_metrics = common_utils.get_metrics(train_metrics)
                        lr = train_metrics.pop('learning_rate').mean()
                        metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                        denominator = metrics_sums.pop('denominator')
                        summary = jax.tree_map(lambda x: x / denominator,
                                               metrics_sums)  # pylint: disable=cell-var-from-loop
                        summary['learning_rate'] = lr
                        summary = {'train_' + k: v for k, v in summary.items()}
                        writer.write_scalars(step, summary)
                        train_metrics = []

                if FLAGS.eval_only:
                    p_eval_per_pos_step = jax.pmap(functools.partial(
                        train_util.eval_per_pos_step, config=eval_config),
                                                   axis_name='batch')
                    # Get per example loss
                    loss_filename = FLAGS.model_dir + '/test_losses.csv'
                    train_util.write_per_example_losses(
                        p_eval_step=p_eval_per_pos_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=FLAGS.num_eval_steps,
                        loss_filename=loss_filename)
                else:
                    with report_progress.timed('eval'):
                        eval_results = train_util.evaluate(
                            p_eval_step=p_eval_step,
                            target=optimizer.target,
                            eval_ds=eval_ds,
                            num_eval_steps=FLAGS.num_eval_steps)
                        curr_eval_loss = eval_results['loss']
                        eval_loss_history.append(curr_eval_loss)
                        if len(eval_loss_history) > 1:
                            improvement_rate = 0.000004
                            orig_loss = eval_loss_history[-2]
                            true_improvement = orig_loss - curr_eval_loss
                            expected_improvement = (
                                step - last_eval_step) * improvement_rate
                            # percent_change = (orig_loss - curr_eval_loss) / orig_loss
                            # percent_change *= 100
                            if true_improvement < expected_improvement:  # percent_change<.1:
                                do_resample_data = True
                        last_eval_step = step
                        writer.write_scalars(
                            step,
                            {'eval_' + k: v
                             for k, v in eval_results.items()})

                if FLAGS.aux_eval_dataset:
                    for aux_i, aux_eval_ds in enumerate(aux_datasets):
                        with report_progress.timed('aux_eval'):
                            eval_results = train_util.evaluate(
                                p_eval_step=p_eval_step,
                                target=optimizer.target,
                                eval_ds=aux_eval_ds,
                                num_eval_steps=FLAGS.num_eval_steps)
                            writer.write_scalars(
                                step, {
                                    'aux' + str(aux_i) + '_eval_' + k: v
                                    for k, v in eval_results.items()
                                })

                if FLAGS.compute_bleu:
                    with report_progress.timed('translate_and_bleu'):
                        decode_file = FLAGS.model_dir + '/decodes.csv'
                        exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                            p_pred_step=p_pred_step,
                            p_init_cache=p_init_cache,
                            target=optimizer.target,
                            predict_ds=predict_ds,
                            decode_tokens=decode_tokens,
                            max_predict_length=FLAGS.max_predict_length,
                            num_eval_steps=FLAGS.num_eval_steps,
                            decode_file=decode_file if FLAGS.eval_only else '')
                        writer.write_scalars(step, {'bleu': bleu_score})
                        writer.write_texts(step, {'samples': exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0
                               or is_last_step)
            if FLAGS.save_checkpoints and save_checkpoint and jax.process_index(
            ) == 0:
                if curr_eval_loss < best_eval_loss:  # only save better checkpoints
                    best_eval_loss = curr_eval_loss
                    with report_progress.timed('checkpoint'):
                        checkpoints.save_checkpoint(
                            FLAGS.model_dir,
                            jax_utils.unreplicate(optimizer),
                            step,
                            keep=FLAGS.chkpts_to_keep,
                            overwrite=True)

            if is_last_step:
                break
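
Both WMT examples shard each host-level batch before handing it to the pmap-ed train step. A shape-only sketch of what `common_utils.shard` does (illustrative sizes, assuming the batch divides evenly across local devices):

import jax
import numpy as np
from flax.training import common_utils

batch = {'inputs': np.zeros((32, 128), np.int32)}
sharded = common_utils.shard(batch)
# The leading batch axis is split into [local_devices, batch // local_devices].
assert sharded['inputs'].shape == (jax.local_device_count(),
                                   32 // jax.local_device_count(), 128)
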
Example #11
    def train_and_evaluate(self, workdir):
        """Runs a training and evaluation loop.

    Args:
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest
        checkpoint.
    """

        tf.io.gfile.makedirs(workdir)
        config = self.config
        substeps = config.training.substeps

        # Learning rate schedule.
        num_train_steps = config.training.num_train_steps
        logging.info('num_train_steps=%d', num_train_steps)

        # Get train state
        state = self._train_state

        # Set up checkpointing of the model and the input pipeline.
        checkpoint_dir = os.path.join(workdir, 'checkpoints')
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        state = ckpt.restore_or_initialize(state)
        initial_step = int(state.step)

        # Distribute training.
        state = flax_utils.replicate(state)

        writer = metric_writers.create_default_writer(
            workdir, just_logging=jax.process_index() > 0)
        if initial_step == 0:
            writer.write_hparams(dict(config))

        logging.info('Starting training loop at step %d.', initial_step)
        hooks = []
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=num_train_steps, writer=writer)
        if jax.process_index() == 0:
            hooks += [
                report_progress,
                periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
            ]
        step = initial_step
        with metric_writers.ensure_flushes(writer):
            while step < num_train_steps:
                # `step` is a Python integer. `state.step` is a JAX integer on
                # the GPU/TPU devices.
                is_last_step = step + substeps >= num_train_steps

                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    inputs = jax.tree_map(np.asarray, next(self._train_iter))
                    state, outputs = self._update_func(state, inputs)

                # Quick indication that training is happening.
                logging.log_first_n(logging.INFO, 'Finished training step %d.',
                                    5, step)
                for h in hooks:
                    h(step)

                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
                if step % config.logs.log_loss_every_steps == 0 and not is_eval:

                    def avg_over_substeps(x):
                        assert x.shape[0] == substeps
                        return float(x.mean(axis=0))

                    # Extract scalars and images.
                    outputs = flax_utils.unreplicate(outputs)
                    outputs = jax.tree_map(avg_over_substeps, outputs)
                    scalars = outputs['scalars']
                    writer.write_scalars(step, scalars)

                if is_eval:
                    with report_progress.timed('eval_full'):
                        outputs = self._eval_epoch(params=state.ema_params)
                        outputs = flax_utils.unreplicate(outputs)
                        scalars = outputs['scalars']
                        writer.write_scalars(step, scalars)

                if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
                    with report_progress.timed('checkpoint'):
                        ckpt.save(flax_utils.unreplicate(state))

        logging.info('Finishing training at step %d', num_train_steps)
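
The state handling above replicates a pytree across local devices for pmap and un-replicates it again before checkpointing. A minimal sketch of that round trip (toy parameters, not the example's train state):

import flax
import jax
import numpy as np

params = {'w': np.ones((3,), np.float32)}
replicated = flax.jax_utils.replicate(params)      # Adds a leading device axis.
assert replicated['w'].shape == (jax.local_device_count(), 3)
restored = flax.jax_utils.unreplicate(replicated)  # Keeps the first replica.
assert restored['w'].shape == (3,)
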
Example #12
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
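
The periodic training-metrics block above sums the stacked per-step metrics and divides everything by a summed token count. A worked sketch with made-up numbers (keys mirror the example; the values are hypothetical):

import jax
import jax.numpy as jnp

metrics_sums = {"loss": jnp.asarray(250.0),
                "accuracy": jnp.asarray(400.0),
                "denominator": jnp.asarray(500.0)}
denominator = metrics_sums.pop("denominator")
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
# summary == {"loss": 0.5, "accuracy": 0.8}, i.e. per-token averages.
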
Example #13
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
    logging.info('Starting training at %s', workdir)
    tf.io.gfile.makedirs(workdir)
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
            json.dump(config.to_dict(), f, indent=2)
    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    train_ds, eval_ds = input_pipeline.create_datasets(config.dataset,
                                                       data_rng)
    train_iter = iter(train_ds)

    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.dataset.num_epochs
    logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
                 steps_per_epoch)
    learning_rate_fn = functools.partial(
        train_utils.get_learning_rate,
        base_learning_rate=config.learning_rate,
        num_train_steps=num_train_steps,
        schedule_type=config.learning_rate_schedule,
        warmup_proportion=config.warmup_proportion,
        step_boundaries=config.learning_rate_step_boundaries)

    # Initialize model.
    inputs = train_utils.get_init_inputs(train_ds)
    rng, model_rng = jax.random.split(rng)
    eval_config = models.TransformerConfig(**config.model.to_dict())
    train_config = eval_config.replace(deterministic=False)
    model = models.Model(eval_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        grad_clip=config.grad_clip),
                            axis_name='batch',
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name='batch')

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(train_utils.flatten_config(config))

    logging.info('Starting training loop at step %d.', initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(
                num_profile_steps=config.num_profile_steps, logdir=workdir)
        ]

    rng, train_rngs = jax.random.split(rng)
    train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
    train_rngs = jax.random.split(train_rngs, jax.local_device_count())

    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceContext('train', step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics = p_train_step(batch=batch,
                                              rng=train_rngs,
                                              state=state)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            if config.log_loss_every_steps > 0 and (
                    step % config.log_loss_every_steps == 0 or is_last_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                lr = train_metrics.pop('learning_rate').mean()
                train_summary = train_utils.metrics_summary(
                    train_metrics, 'train')
                train_summary['learning_rate'] = lr
                writer.write_scalars(step, train_summary)
                train_metrics = []

            if config.eval_every_steps > 0 and (step % config.eval_every_steps
                                                == 0 or is_last_step):
                with report_progress.timed('eval'):
                    eval_summary = evaluate(p_eval_step, state, eval_ds,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_summary)

            if config.checkpoint_every_steps > 0 and (
                    step % config.checkpoint_every_steps == 0 or is_last_step):
                with report_progress.timed('checkpoint'):
                    ckpt.save(flax_utils.unreplicate(state))
                logging.info('Checkpoint saved to %s', checkpoint_dir)

    logging.info('Finishing training at step %d', num_train_steps)
Example #14
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with  "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
  else:
    # ViT / Mixer papers
    filename = config.model.name

  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the update PRNG key across local devices.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
Example #15
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if config.batch_size % n_devices:
        raise ValueError(
            "Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update them afterwards
    # inside the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    metrics_all = []
    with metric_writers.ensure_flushes(writer):
        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            # Shard data to devices and do a training step.
            batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
            optimizer, metrics, dropout_rngs = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            metrics_all.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Save a checkpoint on one host after every checkpoint_freq steps.
            if (config.save_checkpoints and step % config.checkpoint_freq == 0
                    and step > 0 and jax.host_id() == 0):
                checkpoints.save_checkpoint(workdir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

            # Periodic metric handling.
            if step % config.eval_frequency != 0 and step > 0:
                continue

            # Training Metrics
            logging.info("Gathering training metrics.")
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop("learning_rate").mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop("denominator")
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary["learning_rate"] = lr
            summary = {"train_" + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            metrics_all = []

            # Eval Metrics
            logging.info("Gathering evaluation metrics.")
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop("denominator")
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
            writer.write_scalars(step, eval_summary)

            # Translation and BLEU Score.
            logging.info("Translating evaluation dataset.")
            t_inference_start = time.time()
            sources, references, predictions = [], [], []
            for pred_batch in predict_ds:
                pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch["inputs"].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                        pred_batch)
                pred_batch = common_utils.shard(pred_batch)
                cache = p_init_cache(pred_batch["inputs"])
                predicted = p_pred_step(pred_batch["inputs"], optimizer.target,
                                        cache, eos_id,
                                        config.max_predict_length)
                predicted = tohost(predicted)
                inputs = tohost(pred_batch["inputs"])
                targets = tohost(pred_batch["targets"])
                # Iterate through non-padding examples of batch.
                for i, s in enumerate(predicted[:cur_pred_batch_size]):
                    sources.append(decode_tokens(inputs[i]))
                    references.append(decode_tokens(targets[i]))
                    predictions.append(decode_tokens(s))
            logging.info(
                "Translation: %d predictions %d references %d sources.",
                len(predictions), len(references), len(sources))
            logging.info("Translation time: %.4f s step %d.",
                         time.time() - t_inference_start, step)

            # Calculate BLEU score for translated eval corpus against reference.
            bleu_matches = bleu.bleu_partial(references, predictions)
            all_bleu_matches = per_host_sum_pmap(bleu_matches)
            bleu_score = bleu.complete_bleu(*all_bleu_matches)
            # Save translation samples for tensorboard.
            exemplars = ""
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
            writer.write_scalars(step, {"bleu": bleu_score})
            writer.write_texts(step, {"samples": exemplars})
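Example #15 drives its periodic actions by building a hooks list (ReportProgress plus Profile on the first host) and calling each hook once per step. Below is a minimal sketch of just that pattern, assuming the `clu` library; `fake_train_step`, the `/tmp/workdir` path and the step count are placeholders:

import jax
from clu import metric_writers, periodic_actions

num_train_steps = 100
workdir = '/tmp/workdir'  # Placeholder path.
writer = metric_writers.create_default_writer(
    workdir, just_logging=jax.process_index() > 0)
report_progress = periodic_actions.ReportProgress(
    num_train_steps=num_train_steps, writer=writer)
hooks = [report_progress]
if jax.process_index() == 0:
  hooks.append(periodic_actions.Profile(logdir=workdir, num_profile_steps=5))


def fake_train_step(step):
  del step  # A real implementation would update the parameters here.


with metric_writers.ensure_flushes(writer):
  for step in range(1, num_train_steps + 1):
    fake_train_step(step)
    for hook in hooks:
      hook(step)  # ReportProgress logs progress; Profile traces a few steps.
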
Example #16
0
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        graph = jax.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
        config, num_training_nodes, compute_max_terms_per_node(config))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                training_epsilon = training_privacy_accountant(step + 1)
                if max_training_epsilon is not None and training_epsilon >= max_training_epsilon:
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
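Example #16 wraps evaluation and checkpointing in `report_progress.timed(...)` so the time spent in each phase is tracked separately. A minimal sketch of that timing pattern, assuming the `clu` library and using `time.sleep` as a stand-in for real work:

import time

from clu import metric_writers, periodic_actions

writer = metric_writers.create_default_writer('/tmp/workdir')
report_progress = periodic_actions.ReportProgress(
    num_train_steps=20, writer=writer)

for step in range(1, 21):
  time.sleep(0.01)  # Stand-in for a real training step.
  report_progress(step)
  if step % 5 == 0:
    with report_progress.timed('eval'):
      time.sleep(0.02)  # Time spent here is attributed to the 'eval' bucket.
    with report_progress.timed('checkpoint'):
      time.sleep(0.01)  # Likewise attributed to 'checkpoint'.
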
Example #17
0
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
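Example #17 registers a SIGTERM handler so that a preemption triggers a timed checkpoint save. A minimal sketch of that pattern, assuming the `clu` and `absl` libraries; `save_checkpoint` is a hypothetical stand-in for `chkpt_manager.save(state)`:

import signal

from absl import logging
from clu import periodic_actions

report_progress = periodic_actions.ReportProgress(num_train_steps=1000)


def save_checkpoint():
  logging.info('Pretending to save a checkpoint.')  # Hypothetical stand-in.


def handle_preemption(signal_number, _):
  logging.info('Received signal %d, saving checkpoint.', signal_number)
  with report_progress.timed('checkpointing'):
    save_checkpoint()
  logging.info('Finished saving checkpoint.')


signal.signal(signal.SIGTERM, handle_preemption)
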
Example #18
0
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
    """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
    rng, model_rng = jax.random.split(rng)
    input_shape = tuple(train_ds.element_spec["image"].shape[1:])
    model, init_params, init_state = create_model(module, input_shape,
                                                  model_rng)
    parameter_overview.log_parameter_overview(model.params)

    # Load pretrained model parameters and state. Ignore the step and the
    # optimizer state in the checkpoint.
    pretrained_path = config.get("pretrained_checkpoint", "")
    if pretrained_path:
        logging.info("Load pretrained weights from '%s'", pretrained_path)
        state_dict = checkpoint.load_state_dict(pretrained_path)
        flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                                  sep="/")
        model_state = state_dict["model_state"]

        # A prefix can be used to replace only a subpart of the network (e.g. the
        # encoder). Prepend the prefix (if any) to model parameters and states.
        prefix = config.get("pretrained_prefix", "")
        if prefix:
            flatten_model_params = utils.add_prefix_to_dict_keys(
                flatten_model_params, f"{prefix}/")
            model_state = utils.add_prefix_to_dict_keys(
                model_state, f"/{prefix}")

        # Merge the params/state from the checkpoint into the initial params/state.
        flatten_init_params = utils.flatten_dict(init_params, sep="/")
        flatten_init_params, ignored_params = utils.override_dict(
            flatten_init_params, flatten_model_params)
        init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
        init_state, _ = utils.override_dict(init_state, model_state)

        if ignored_params:
            logging.warning(
                "%d/%d parameters from the pretrained checkpoint "
                "were ignored: %s", len(ignored_params),
                len(flatten_init_params), ignored_params)

    optimizer_state = optimizer.init(init_params)

    state = TrainState(step=1,
                       model_params=init_params,
                       model_state=init_state,
                       optimizer_state=optimizer_state)  # type: ignore
    # Do not keep a copy of the initial model.
    del init_params, init_state, optimizer_state

    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    checkpoint_dir = os.path.join(workdir, "checkpoints")

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step)
    # Replicate our parameters.
    state = flax.jax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    step_timer = utils.StepTimer(batch_size=config.batch_size,
                                 initial_step=initial_step)

    # Write config to the summary files. This makes the hyperparameters available
    # in TensorBoard and makes comparison of runs with tensorboard/ easier.
    if initial_step == 1:
        writer.write_hparams(utils.flatten_dict(config.to_dict()))

    # Generate per-device PRNG keys for the training loop.
    rng, train_rng = jax.random.split(rng)
    train_rngs = jax.random.split(train_rng, jax.local_device_count())

    # Generate per-device PRNG keys for model evaluation.
    rng, eval_rng = jax.random.split(rng)
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    train_metrics = utils.Means()

    do_eval_only = config.get("do_eval_only", False)
    if do_eval_only:
        config.num_train_steps = 1

    debug_enabled = config.get("debug", False)
    previous_grads = grads = None
    previous_updates = updates = None
    previous_state = None
    for step in range(initial_step, config.num_train_steps + 1):
        is_last_step = step == config.num_train_steps
        if debug_enabled:
            previous_grads = grads
            previous_updates = updates
            previous_state = state

        # Skip the training if only do the eval.
        if not do_eval_only:
            # Use ._numpy() to avoid copy.
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            state, grads, updates, metrics, training_stats, train_rngs = train_step(
                state, batch, module, loss_fn, optimizer, train_metrics_dict,
                train_rngs)
            train_metrics.append(flax.jax_utils.unreplicate(metrics))

            # Update topk temperature with linearly decreasing schedule if enabled.
            if (config.get("linear_decrease_perturbed_sigma", False) and
                    config.get("selection_method", "") == "perturbed-topk"):

                model_state = state.model_state.as_dict()

                if "/PatchNet_0" in model_state:
                    net_str = "/PatchNet_0"
                else:
                    net_str = "/"

                progress = step / config.num_train_steps
                sigma_multiplier = 1. - progress
                previous_mult = model_state[net_str]["sigma_mutiplier"]
                sigma_multiplier = sigma_multiplier + jnp.zeros_like(
                    previous_mult)
                model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
                state = state.replace(model_state=nn.Collection(model_state))

            if debug_enabled:
                if utils.has_any_inf_or_nan(metrics):
                    # Save checkpoint
                    if previous_state:
                        ckpt.save(flax.jax_utils.unreplicate(previous_state))
                    ckpt.save(flax.jax_utils.unreplicate(state))

                    # Log gradients and updates.
                    if previous_grads or previous_updates:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=previous_grads,
                                                 updates=previous_updates)
                    write_gradient_histogram(writer,
                                             step + 1,
                                             grads=grads,
                                             updates=updates)

                    raise RuntimeError(
                        "A training metric took an invalid value: "
                        f"{metrics}.")

            logging.log_first_n(logging.INFO, "Finished training step %d.", 3,
                                step)
            report_progress(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                results = train_metrics.result()
                writer.write_scalars(step, results)
                writer.write_scalars(step, step_timer.get_and_reset(step))
                if utils.has_any_inf_or_nan(results):
                    raise ValueError(
                        "A training metric took an invalid value.")
                train_metrics.reset()

        if (step % config.checkpoint_every_steps == 0 or is_last_step):
            with step_timer.paused():
                ckpt.save(flax.jax_utils.unreplicate(state))

        # Evaluation
        if step % config.eval_every_steps == 0 or is_last_step:
            with step_timer.paused():
                eval_metrics, first_batch_stats, eval_rngs = evaluate(
                    state, module, eval_ds, eval_metrics_dict, eval_rngs)

            if jax.host_id() == 0:
                log_histograms = config.get("log_histograms", False)
                log_images = config.get("log_images", True)
                # Log the last gradients and updates histograms.
                if not do_eval_only:
                    write_stats_results(writer,
                                        step,
                                        training_stats,
                                        stats_aggregators,
                                        prefix="train/",
                                        log_images=log_images)
                    if log_histograms:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=grads,
                                                 updates=updates)

                write_stats_results(writer,
                                    step,
                                    first_batch_stats,
                                    stats_aggregators,
                                    prefix="eval/",
                                    log_images=log_images)

                # write patch representation histograms
                if (log_histograms and first_batch_stats
                        and "patch_representations" in first_batch_stats):
                    patch_representations = first_batch_stats[
                        "patch_representations"]
                    writer.write_histograms(
                        step, {"patch_representations": patch_representations})

                if eval_metrics:
                    writer.write_scalars(step, eval_metrics)

    writer.flush()
    return state
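The debug path above aborts when any training metric becomes NaN or infinite. A minimal sketch of such a check over a metrics pytree; this `has_any_inf_or_nan` is a hypothetical re-implementation, not the repo's own `utils` helper:

import jax
import jax.numpy as jnp


def has_any_inf_or_nan(tree):
  """Returns True if any leaf of the pytree contains NaN or +/-Inf."""
  return any(
      bool(jnp.any(~jnp.isfinite(leaf)))
      for leaf in jax.tree_util.tree_leaves(tree))


clean = {'loss': jnp.array([0.3, 0.1]), 'accuracy': jnp.array(0.9)}
broken = {'loss': jnp.array([0.3, jnp.nan])}
assert not has_any_inf_or_nan(clean)
assert has_any_inf_or_nan(broken)
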
Example #19
0
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  print('Using state', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
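Example #19 feeds the training loop from an epoch-wise indefinite iterator (`iterate_train`). A minimal sketch of that generator pattern, using a plain Python list as a hypothetical stand-in for the tf.data dataset:

from absl import logging


def iterate_train(train_ds):
  """Yields batches from `train_ds` indefinitely, logging each new epoch."""
  epoch = 1
  while True:
    logging.info('Starting epoch %d', epoch)
    for batch in train_ds:
      yield batch
    epoch += 1


# Hypothetical three-batch "dataset": the iterator wraps around after batch 3.
train_iter = iterate_train([{'x': 1}, {'x': 2}, {'x': 3}])
first_five = [next(train_iter) for _ in range(5)]
assert [b['x'] for b in first_five] == [1, 2, 3, 1, 2]
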
Example #20
0
def train_and_evaluate(config, workdir, strategy):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
    tf.io.gfile.makedirs(workdir)

    tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0),
                                                              2)
    tf.random.set_seed(tf_rng.numpy()[0])

    # Input pipeline.
    ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
        config, data_rng, strategy=strategy)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = (ds_info.splits["train"].num_examples //
                           config.global_batch_size * config.num_epochs)
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    # We treat the learning rate in the config as the learning rate for batch size
    # 256 but scale it according to our batch size.
    base_learning_rate = config.learning_rate * config.global_batch_size / 256.0

    # Initialize model.
    num_classes = ds_info.features["label"].num_classes

    if config.distill_teacher:
        do_distill = True
        teacher_file_list = (config.distill_teacher).split(",")
        teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                             config, strategy)
        distill_params = {}
        distill_params["alpha"] = config.distill_alpha
        distill_params["beta"] = config.distill_fd_beta
        distill_params["teacher_model"] = TeacherModel(teacher_models,
                                                       name="teacher")
    else:
        do_distill = False
        distill_params = None

    state = create_state(config, num_classes=num_classes, strategy=strategy)

    ckpt_manager = tf.train.CheckpointManager(checkpoint=state,
                                              directory=workdir,
                                              max_to_keep=5)

    if ckpt_manager.latest_checkpoint:
        state.restore(ckpt_manager.latest_checkpoint)
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")
    initial_step = state.global_step.numpy().item()

    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    writer = metric_writers.create_default_writer(workdir)
    writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            state.model.trainable = True

            # `step` is a Python integer. `global_step` is a TF variable on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            train_step(state, train_iter, config.weight_decay,
                       learning_rate_fn, do_distill, distill_params, strategy)

            state.train_metrics.update_state_lr(
                learning_rate_fn(state.global_step.numpy().item()))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            report_progress(step)

            if step == initial_step:
                parameter_overview.log_parameter_overview(state.model)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, state.train_metrics.result())
                state.train_metrics.reset_states()
                state.train_metrics.reset_lr()

            if step % config.eval_every_steps == 0 or is_last_step:
                state.model.trainable = False
                if config.dataset == "imagenet-lt":
                    evaluate(state, val_ds, state.val_metrics, strategy)
                    writer.write_scalars(step, state.val_metrics.result())
                    logging.info("Num val images %d",
                                 state.val_metrics.accuracy.count.numpy())

                evaluate(state, test_ds, state.test_metrics, strategy)
                writer.write_scalars(step, state.test_metrics.result())

                logging.info("Num test images %d",
                             state.test_metrics.accuracy.count.numpy())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                checkpoint_path = ckpt_manager.save(step)
                logging.info("Saved checkpoint %s", checkpoint_path)

    logging.info("Finishing training at step %d", step)
    logging.info("Saving the final weights")
    file_path = "%s/final_weights" % workdir
    state.model.save_weights(file_path, save_format="tf")
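Example #20 builds its schedule by partially applying a `get_learning_rate` function with a base rate, steps per epoch, epoch count and warmup epochs; the function body is not shown, so the sketch below is only a hypothetical linear-warmup plus cosine-decay schedule with the same signature:

import numpy as np


def get_learning_rate(step, *, base_learning_rate, steps_per_epoch, num_epochs,
                      warmup_epochs):
  """Hypothetical linear-warmup + cosine-decay schedule (assumed, not shown)."""
  warmup_steps = warmup_epochs * steps_per_epoch
  total_steps = num_epochs * steps_per_epoch
  if step < warmup_steps:
    return base_learning_rate * step / max(warmup_steps, 1)
  progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
  return base_learning_rate * 0.5 * (1.0 + np.cos(np.pi * progress))


# Mirrors the example's batch-size scaling: the base LR is for batch size 256.
base_learning_rate = 0.1 * 1024 / 256.0
lr = get_learning_rate(500, base_learning_rate=base_learning_rate,
                       steps_per_epoch=1000, num_epochs=90, warmup_epochs=5)
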
Example #21
0
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    splits = input_pipeline.create_datasets(config, data_rng)
    num_classes = splits.info.features["label"].num_classes
    train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = splits.train.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)
    # We treat the learning rate in the config as the learning rate for batch size
    # 32 but scale it according to our batch size.
    global_batch_size = config.per_device_batch_size * jax.device_count()
    base_learning_rate = config.learning_rate * global_batch_size / 32.0
    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = create_train_state(
        config,
        model_rng,
        input_shape=splits.train.element_spec["input"].shape[1:],
        num_classes=num_classes)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir,
                                          {"train_iter": train_iter},
                                          max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Count number of trainable parameters. This must be done before replicating
    # the state to avoid double-counting replicated parameters.
    param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

    # Distribute training over local devices.
    state = flax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model=model,
        learning_rate_fn=learning_rate_fn,
        weight_decay=config.weight_decay),
                            axis_name=_PMAP_AXIS_NAME)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))
        # Log the number of trainable params.
        writer.write_scalars(initial_step, {"param_count": param_count})

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics = None
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
            # devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics_update = p_train_step(state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            # When combining train and eval, we do not evaluate while training.
            if ((step % config.eval_every_steps == 0 or is_last_step)
                    and not config.combine_train_val_and_eval_on_test):
                with report_progress.timed("eval"):
                    eval_metrics = evaluate(model, state, splits.validation,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_metrics.compute())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                with report_progress.timed("checkpoint"):
                    ckpt.save(flax_utils.unreplicate(state))

            if is_last_step and config.combine_train_val_and_eval_on_test:
                # Evaluate a single time on the test set when requested.
                with report_progress.timed("test"):
                    test_metrics = evaluate(model, state, splits.test,
                                            config.num_eval_steps)
                writer.write_scalars(step, test_metrics.compute())

    logging.info("Finishing training at step %d", num_train_steps)
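Example #21 counts trainable parameters before replicating the state so that the replicated copies are not double-counted. A minimal sketch of that idiom with a hypothetical nested parameter dict:

import jax
import jax.numpy as jnp

# Hypothetical nested parameter tree standing in for state.optimizer.target.
params = {
    'dense': {'kernel': jnp.zeros((128, 10)), 'bias': jnp.zeros((10,))},
    'embedding': jnp.zeros((1000, 128)),
}
# Count on the unreplicated tree; after flax.jax_utils.replicate() every leaf
# gains a leading device axis, so counting afterwards would over-count.
param_count = sum(p.size for p in jax.tree_util.tree_leaves(params))
assert param_count == 128 * 10 + 10 + 1000 * 128
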
Example #22
0
def train(config: ml_collections.ConfigDict):
  """Run training."""

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use the same data if we load weights or a checkpoint.
  seed = config.seed + start_step
  if config.load_weights:
    seed = seed + hash(config.load_weights)

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that none of samples are dropped.')
  else:
    logging.warn('Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # compile multidevice versions of train/eval/predict step
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

          if config.get('save_samples_every_steps') is not None:
            with report_progress.timed('save_samples'):
              postprocessing_input = eval_auxiliary
              if config.get('save_first_batch_only', 'True'):
                postprocessing_input = [eval_auxiliary[0]]
              eval_processed = [
                  postprocessing_fn(batch, auxiliary_output)
                  for batch, auxiliary_output in postprocessing_input
              ]
              data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
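Example #22 assembles its optimizer as an optax chain: AdamW (optionally masked for weight decay), global-norm clipping, and `apply_if_finite` to skip non-finite updates. A minimal sketch with placeholder hyperparameters, assuming only the optax library:

import optax

learning_rate = 1e-3  # Placeholder hyperparameters.
grad_clip = 1.0
tx = optax.adamw(
    learning_rate=learning_rate, weight_decay=0.01, b1=0.9, b2=0.999, eps=1e-6)
# Chain a global-norm clip after the AdamW transform, as in the example above.
tx = optax.chain(tx, optax.clip_by_global_norm(grad_clip))
# Skip up to k consecutive non-finite updates instead of corrupting the params.
tx = optax.apply_if_finite(tx, max_consecutive_errors=5)
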
Example #23
0
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
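The preemption handling in this example can be factored into a reusable helper. A sketch under the same assumptions, where install_preemption_handler and save_state are hypothetical names standing in for chkpt_manager.save(state):

import signal

from absl import logging


def install_preemption_handler(report_progress, save_state):
    """Registers a SIGTERM handler that checkpoints before the job is killed."""

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        # Timing the save makes preemption checkpoints visible in progress notes.
        with report_progress.timed('checkpointing'):
            save_state()
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

Called once before the training loop, e.g. install_preemption_handler(report_progress, lambda: chkpt_manager.save(state)); like the closure in the example, the lambda captures the local state variable, so the most recent value is saved when SIGTERM arrives.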
Example #24
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Should try to load checkpoint (usually enabled, practical
        for debugging purposes to disable).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    if config.dataset.name.endswith('speech_commands09'):
        ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # Immediately create infinite iterators.
    it = jax.tree_map(util_fns.get_iterator, ds)

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    def tx_fn(lr):
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    start_step = None
    if try_checkpoint:
        state, start_step = checkpoint.restore_from_path(work_dir, state)
    start_step = start_step or 0

    # Use different rngs for train & eval.
    rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

    kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        **config.learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=config,
        learning_rate_fn=learning_rate_fn,
        model=model),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only broadcast axes are the input and output rng keys: the rng is the
    # first argument and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, 0, None))

    # Training length.
    logging.info('Training will start from step %d', start_step)

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Setup hooks.
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if is_first_host:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
        ]

    with metric_writers.ensure_flushes(writer):
        batch_metrics = []
        for step in range(start_step, config.num_train_steps):
            logging.log_first_n(logging.INFO, f'Train step: {step}', 5)
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                state, metrics, rng_train = p_train_step(
                    rng_train, next(it['train']), state)
            batch_metrics.append(metrics)

            # Cycle though hooks.
            for h in hooks:
                h(step)

            is_last_step = step == config.num_train_steps - 1

            if (step % config.log_every_steps == 0) or is_last_step:
                with report_progress.timed('training_metrics'):
                    ################### Process batch metrics ############################
                    batch_metrics = jax.device_get(
                        flax.jax_utils.unreplicate(batch_metrics))

                    if 't_batch' in metrics:
                        # TODO(agritsenko): Factor out into a separate function.
                        # This processes the loss per t; despite the two nested
                        # for-loops (counting the one inside kl_tracker), it does
                        # not meaningfully hurt timing performance.
                        batch_t = [
                            metrics['t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        batch_nelbo_per_t = [
                            metrics['nelbo_per_t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
                            kl_tracker.update(t, nelbo_per_t)

                    ################### Average batch metrics ############################
                    metrics = {
                        key: np.mean([m[key] for m in batch_metrics])
                        for key in batch_metrics[0] if 'batch' not in key
                    }

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             train_metrics=metrics)
                    batch_metrics = []

            if config.eval_every_steps and (
                (step % config.eval_every_steps == 0) or is_last_step):
                with report_progress.timed('eval'):
                    ####################### Run evaluation ###############################
                    metrics, rng_eval = eval_model(
                        p_eval_step, rng_eval, state, it['eval'],
                        (ds_metadata['eval']['num_batches'] *
                         config.get('num_eval_passes', 1)))

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             eval_metrics=metrics)

                # Track KL (unrelated to the eval, but nice to not do every step).
                kl_values = kl_tracker.get_kl_per_t()
                kl_history.append(np.array(kl_values))
                kl_history = kl_history[-50:]

            if config.sample_every_steps and (
                (step % config.sample_every_steps == 0) or is_last_step):
                with report_progress.timed('sample'):
                    ######################### Run sampling ###############################
                    chain = model.sample(jax.random.fold_in(rng_sample, step),
                                         state.ema_params,
                                         config.sample_batch_size,
                                         chain_out_size=config.get(
                                             'chain_out_size',
                                             model.num_stages))

                    if is_first_host:
                        chain = jax.device_get(chain)
                        long_sample = np.reshape(chain[-1],
                                                 (1, -1, 1)).astype(np.float32)
                        long_sample = (2. *
                                       long_sample) / config.num_classes - 1.
                        writer.write_audios(step, {'samples': long_sample},
                                            sample_rate=config.sample_rate)

            ######################### Checkpointing #################################
            if is_first_host and config.checkpoint_every_steps and (
                (step % config.checkpoint_every_steps == 0) or is_last_step):
                logging.info('Saving checkpoint: step %d', step)
                with report_progress.timed('checkpoint'):
                    checkpoint.save_checkpoint(
                        work_dir,
                        state=flax.jax_utils.unreplicate(state),
                        step=step)
                logging.info('Finished saving checkpoint: step %d', step)

        return state
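The in_axes/out_axes convention used for p_eval_step above (the rng key is broadcast in and returned unreplicated, while batch and state are sharded over devices) can be exercised in isolation. A small sketch; fake_eval_step and the shapes are placeholders, not part of the example:

import jax
import jax.numpy as jnp


def fake_eval_step(rng, batch, state):
    # rng arrives unreplicated (in_axes=None); batch and state are per-device.
    new_rng, _ = jax.random.split(rng)
    loss = jnp.mean(batch) + state
    return loss, state, new_rng


p_fake_eval_step = jax.pmap(
    fake_eval_step,
    axis_name='batch',
    in_axes=(None, 0, 0),   # Broadcast the rng, shard batch and state.
    out_axes=(0, 0, None),  # Return the new rng unreplicated as well.
)

n = jax.local_device_count()
rng = jax.random.PRNGKey(0)
batch = jnp.ones((n, 4))   # Leading axis is the device axis.
state = jnp.zeros((n,))
loss, state, rng = p_fake_eval_step(rng, batch, state)
print(loss.shape, rng.shape)  # (n,) and (2,): the key stays a single key.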
Example #25
File: main.py  Project: myagues/flax_nerf
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if (step % FLAGS.config.i_img == 0 and step > 0) or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
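The pre-crop sampling above draws pixel coordinates only from a central crop during the first precrop_iters steps. A short sketch of just that coordinate-grid construction; make_coords and the image sizes are placeholders, not part of the example:

import jax.numpy as jnp


def make_coords(img_h, img_w, precrop_frac=None):
    if precrop_frac is None:
        # Full image: one (row, col) pair per pixel.
        coords = jnp.meshgrid(jnp.arange(img_h), jnp.arange(img_w), indexing='ij')
    else:
        # Central crop of size (2*dh, 2*dw) around the image centre.
        dh = int(img_h // 2 * precrop_frac)
        dw = int(img_w // 2 * precrop_frac)
        coords = jnp.meshgrid(
            jnp.arange(img_h // 2 - dh, img_h // 2 + dh),
            jnp.arange(img_w // 2 - dw, img_w // 2 + dw),
            indexing='ij',
        )
    return jnp.stack(coords, axis=-1).reshape([-1, 2])  # (num_pixels, 2)


full = make_coords(8, 8)                     # 64 pixel coordinates
crop = make_coords(8, 8, precrop_frac=0.5)   # 16 coordinates from the centre
print(full.shape, crop.shape)                # (64, 2) (16, 2)

With precrop_frac=0.5 on an 8x8 image this keeps 16 of the 64 pixel coordinates, matching the dH/dW computation in the example.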