Example #1
    def test_init_state(self):
        params = np.zeros((3, 2))
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=None,
                                        min_dim_size_to_factor=0)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None,
                                                      0.8, 0, 1.0, None, 0,
                                                      1e-30, 1e-3)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0,
            _AdafactorParamState(np.zeros((2, )), np.zeros((3, )),
                                 np.zeros((1, )), np.zeros((1, ))))
        check_eq(state, expected_state)

        # unfactorized
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=0.0,
                                        min_dim_size_to_factor=32)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0,
                                                      0.8, 0, 1.0, None, 32,
                                                      1e-30, 1e-3)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0,
            _AdafactorParamState(np.zeros((1, )), np.zeros((1, )),
                                 np.zeros((3, 2)), np.zeros((3, 2))))
        check_eq(state, expected_state)
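For readability, the positional fields of the expected hyperparameter tuples above can be written out with keywords. The field order below is an assumption based on flax's _AdafactorHyperParams declaration and should be checked against the flax version in use:

# Assumed keyword form of the first expected_hyper_params (hypothetical field
# order -- verify against your flax version before relying on it):
# _AdafactorHyperParams(learning_rate=0.1, factored=True,
#                       multiply_by_parameter_scale=True, beta1=None,
#                       decay_rate=0.8, step_offset=0, clipping_threshold=1.0,
#                       weight_decay_rate=None, min_dim_size_to_factor=0,
#                       epsilon1=1e-30, epsilon2=1e-3)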
Example #2
  def test_apply_gradient(self):
    optimizer_def = optim.Adafactor(learning_rate=0.1, decay_rate=0.8,
                                    min_dim_size_to_factor=0)
    params = onp.ones((3, 2), onp.float32)
    state = optim.OptimizerState(
        1, _AdafactorParamState(onp.array([0.9, 0.9]),
                                onp.array([0.1, 0.1, 0.1]),
                                onp.zeros((1,)),
                                onp.zeros((1,))))
    grads = onp.ones((3, 2), onp.float32)
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = optim.OptimizerState(
        2, _AdafactorParamState(
            onp.array([0.9574349, 0.9574349]),
            onp.array([0.6169143, 0.6169143, 0.6169143]),
            onp.zeros((1,)),
            onp.zeros((1,))))
    expected_new_params = 0.9 * onp.ones((3, 2))
    onp.testing.assert_allclose(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

    # unfactored, with momentum
    optimizer_def = optim.Adafactor(learning_rate=0.1,
                                    beta1=0.0, decay_rate=0.8,
                                    min_dim_size_to_factor=32)
    params = onp.ones((3, 2), onp.float32)
    state = optim.OptimizerState(
        1, _AdafactorParamState(onp.zeros((1,)),
                                onp.zeros((1,)),
                                0.5*onp.ones((3, 2)),
                                onp.zeros((3, 2))))
    grads = onp.ones((3, 2), onp.float32)
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_params = 0.9 * onp.ones((3, 2))
    onp.testing.assert_allclose(new_params, expected_new_params)
    expected_new_state = optim.OptimizerState(
        2, _AdafactorParamState(
            onp.array([0.0]),
            onp.array([0.0]),
            0.787174 * onp.ones((3, 2)),
            0.1 * onp.ones((3, 2))))
    check_eq(new_state, expected_new_state, rtol=1e-6)
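The expected second-moment values in the factored case can be reproduced by hand. A minimal sketch, assuming the standard Adafactor decay schedule beta2_t = 1 - t**(-decay_rate) with t = step + 1 (here step = 1, so t = 2); the gradient is all ones, so the squared gradient is 1 everywhere:

# Hand-check of the expected factored moments, under the assumed schedule.
t = 2.0                                  # state.step (1) + 1
decay_rate = 0.8
beta2 = 1.0 - t ** (-decay_rate)         # ~0.425651
ema = lambda v_old: beta2 * v_old + (1.0 - beta2) * 1.0  # grad**2 == 1
print(ema(0.9))  # ~0.9574349, the first expected moment vector
print(ema(0.1))  # ~0.6169143, the second expected moment vector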
Example #3
    def test_factorizes(self):
        params = np.zeros((64, 64))
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=None,
                                        min_dim_size_to_factor=32)
        state = optimizer_def.init_state(params)
        self.assertEqual(state.param_states.v.shape, (1, ))
        self.assertEqual(state.param_states.m.shape, (1, ))
        self.assertEqual(state.param_states.v_row.shape, (64, ))
        self.assertEqual(state.param_states.v_col.shape, (64, ))

        params = np.zeros((31, 64))
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=None,
                                        min_dim_size_to_factor=32)
        state = optimizer_def.init_state(params)
        self.assertEqual(state.param_states.v.shape, (31, 64))
        self.assertEqual(state.param_states.m.shape, (1, ))
        self.assertEqual(state.param_states.v_row.shape, (1, ))
        self.assertEqual(state.param_states.v_col.shape, (1, ))
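All three shape checks follow one rule: a 2-D parameter is factored only when every dimension is at least min_dim_size_to_factor. A sketch of that rule (the helper is hypothetical, not flax API):

def uses_factoring(shape, min_dim_size_to_factor):
    # Hypothetical restatement of the rule the test above exercises.
    return len(shape) >= 2 and min(shape) >= min_dim_size_to_factor

assert uses_factoring((64, 64), 32)      # factored: v_row/v_col allocated
assert not uses_factoring((31, 64), 32)  # unfactored: full-shape v allocated
assert uses_factoring((3, 2), 0)         # Example #1: factored at threshold 0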
Example #4
def main(argv):
    global CFG
    CFG = FLAGS.config

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
    _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if CFG.hardware_rng:
        models.set_hardware_bernoulli()

    if 'module_import' in CFG and CFG.module_import:
        for module in CFG.module_import:
            importlib.import_module(module)

    if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
        t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

    num_partitions = CFG.num_partitions
    topology = train_lib.compute_multihost_topology(num_partitions)
    batch_size = CFG.batch_size
    eval_batch_size = CFG.eval_batch_size
    per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
    if batch_size % topology.num_replicas:
        raise ValueError(
            'Batch size must be divisible by the number of replicas.')

    steps_per_epoch = CFG.steps_per_epoch
    logging.info('steps per epoch: %d', steps_per_epoch)

    broadcast = functools.partial(
        train_lib.broadcast,
        num_replicas=topology.per_replica_set_num_replicas,
        num_partitions=topology.per_host_num_partitions,
        devices=topology.this_host_device_assignment)

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        tf.io.gfile.copy(FLAGS['config'].config_filename,
                         os.path.join(FLAGS.model_dir, 'config.py'),
                         overwrite=True)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    if CFG.infeed:
        # Infeed is currently synchronous, so feed it from a background thread
        # to avoid blocking on device sync.
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
        CFG, topology.num_replica_sets, topology.replica_set_id,
        topology.per_replica_set_host_id)

    vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
    encoder = vocab.tf_tokenizer
    eos_id = vocab.tokenizer.eos_id()

    def decode_tokens(toks, eos_id=eos_id, max_id=32000):
        """Decode tokens back to unicode."""
        del eos_id
        # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
        # is the best decoding function or just switch to using tf_decode.
        # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        valid_toks = toks.astype(np.int32)
        valid_toks[valid_toks >= max_id] = 3
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    train_config, eval_config, predict_config = get_configs(CFG)

    rng = random.PRNGKey(CFG.random_seed)
    rng, init_rng = random.split(rng)
    # This is used for infeed conversion from feature dict <--> tuple
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    device_train_input_shape = tuple([
        (batch_size // topology.num_replicas,
         CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
        for k in train_keys
    ])

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        factors=CFG.schedule,
        base_learning_rate=CFG.learning_rate,
        warmup_steps=CFG.warmup_steps)

    # First, we only abstractly initialize the optimizer and model parameters,
    # since the parameters may not even fit in device memory!
    # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
    # in get_initial_params without causing pytree incompatibility
    optimizer_def = optim.Adafactor(CFG.learning_rate,
                                    decay_rate=0.8,
                                    step_offset=CFG.step_offset)
    initialize_params_fn = functools.partial(get_initial_params,
                                             config=CFG,
                                             transformer_config=eval_config,
                                             optimizer_def=optimizer_def)
    optimizer = jax.eval_shape(initialize_params_fn, init_rng)
    # tuple-like pytree leaves for global_arg_shapes
    optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                    optimizer)

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    if num_partitions > 1:
        optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(num_partitions, optimizer.state_dict()))
        per_host_optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(topology.per_host_num_partitions,
                                      optimizer.state_dict()))

    # Restore unreplicated optimizer + model state from last checkpoint.
    # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
    existing_checkpoint_found = False
    if CFG.restore_checkpoints:
        existing_checkpoint_found = train_lib.checkpoint_exists(
            FLAGS.model_dir)
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

    # Import a pretrained-T5 checkpoint only if we didn't restore a local
    # "native" checkpoint (e.g. when resuming a pre-empted finetuning run).
    # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
    if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
        optimizer = checkpoint_importer.restore_from_t5_checkpoint(
            optimizer, CFG.restore_t5_checkpoint)

    if CFG.restore_t5_checkpoint or existing_checkpoint_found:
        if num_partitions > 1:
            # Until checkpoint/restore is sharded, the restored checkpoint is global
            # and we need to slice each sharded parameter into the chunk containing
            # only the partitions that are present on this host.
            def per_host_chunk(x, spec):
                if spec is None or spec is x:  # unsharded or not a parameter
                    return x
                if spec[0] == 1:
                    dim_size = x.shape[1]
                elif spec[1] == 1:
                    dim_size = x.shape[0]
                else:
                    raise NotImplementedError()
                chunk_size = (dim_size * topology.per_host_num_partitions //
                              num_partitions)
                lower = topology.per_replica_set_host_id * chunk_size
                upper = (topology.per_replica_set_host_id + 1) * chunk_size
                if spec[0] == 1:
                    return x[:, lower:upper]
                else:
                    return x[lower:upper]

            optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                          optimizer_partitions)
    else:
        # If pretraining and no checkpoint imported, we jit the (sharded-) init
        # function to minimize fragmentation. We use the same pmap(sharded_jit)
        # setup as the training step/loop to initialize everything "in-place" and
        # avoid communication or OOM.
        if num_partitions > 1:
            initialize_params_fn = sharded_jit(
                initialize_params_fn,
                in_parts=None,
                local_in_parts=None,
                out_parts=optimizer_partitions,
                local_out_parts=per_host_optimizer_partitions,
                # devices=one_replica_device_assignment,
            )
            initialize_params_fn = jax.pmap(initialize_params_fn,
                                            'batch',
                                            in_axes=0,
                                            axis_size=topology.num_replicas,
                                            devices=topology.device_assignment)
            init_rng = broadcast(init_rng)
            optimizer = initialize_params_fn(init_rng)
            # We maintain the optimizer in unbroadcasted form (i.e. with no leading
            # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
            # out_axes=None.
            optimizer = train_lib.unbroadcast(optimizer)
        else:
            optimizer = jax.jit(initialize_params_fn)(init_rng)

    # ---------------------------------------------------------------------------
    # Compile multidevice versions of train/eval/predict step and cache init fn.
    # ---------------------------------------------------------------------------

    # We can either use a single train step inside a host-side training loop:

    # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
    #  --> new_optimizer, metrics, new_dropout_rng
    def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
        return train_lib.train_step(optimizer,
                                    batch,
                                    prev_metrics,
                                    dropout_rng,
                                    config=train_config,
                                    learning_rate_fn=learning_rate_fn,
                                    num_microbatches=CFG.microbatches,
                                    label_smoothing=CFG.label_smoothing,
                                    z_loss=CFG.z_loss,
                                    use_bfloat16=CFG.use_bfloat16)

    if num_partitions > 1:
        p_train_step = sharded_jit(
            p_train_step,
            in_parts=(optimizer_partitions, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None),
            out_parts=(optimizer_partitions, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None))
    # TODO(levskaya): the in_axes spec below might be wrong, double-check.
    p_train_step = jax.pmap(p_train_step,
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0),
                            donate_argnums=(0, ),
                            global_arg_shapes=(optimizer_shapes, None, None,
                                               None),
                            axis_size=topology.num_replicas,
                            devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # OR an on-device loop that feeds the training step via the infeed queue.
    def device_train_loop_cond(args):
        """Stopping criterion for on-device loop."""
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        """On-device loop body."""
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        # Ordering input data from infeed requires threading a symbolic token
        # through the computation.
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(s, jnp.int32)
                                           for s in device_train_input_shape
                                       ]))
        # Rebuild input dict from infeed data tuple.
        batch = {k: v for k, v in zip(train_keys, input_data)}
        # Run the train_step function and return the loop state.
        optimizer, metrics, dropout_rngs = train_lib.train_step(
            optimizer,
            batch,
            metrics,
            dropout_rngs,
            train_config,
            learning_rate_fn,
            num_microbatches=CFG.microbatches,
            label_smoothing=CFG.label_smoothing,
            z_loss=CFG.z_loss)
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        # Create symbolic token for threading infeed data.
        token = lax.create_token(step)
        # Run on-device loop.
        optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, dropout_rngs, metrics, token, step, epoch))
        return optimizer, dropout_rngs, metrics, step

    if num_partitions > 1:
        device_train_loop = sharded_jit(
            device_train_loop,
            in_parts=(optimizer_partitions, None, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None,
                            None),
            out_parts=(optimizer_partitions, None, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None),
                             donate_argnums=(0, ),
                             global_arg_shapes=(optimizer_shapes, None, None,
                                                None, None),
                             axis_size=topology.num_replicas,
                             devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Reduction psum for metric data.

    def p_allreduce_metrics(x):
        return lax.psum(x, axis_name='batch')

    if num_partitions > 1:
        p_allreduce_metrics = sharded_jit(
            p_allreduce_metrics,
            in_parts=None,
            local_in_parts=None,
            out_parts=None,
            local_out_parts=None,
            num_partitions=num_partitions,
            local_num_partitions=topology.per_host_num_partitions)
    p_allreduce_metrics = jax.pmap(p_allreduce_metrics,
                                   axis_name='batch',
                                   global_arg_shapes=None,
                                   axis_size=topology.num_replicas,
                                   devices=topology.device_assignment)

    # Training evaluation computation.

    # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
    def p_eval_step(params, batch):
        return train_lib.eval_step(params,
                                   batch,
                                   config=eval_config,
                                   label_smoothing=CFG.label_smoothing)

    if num_partitions > 1:
        p_eval_step = sharded_jit(
            p_eval_step,
            in_parts=(optimizer_partitions.target, None),
            local_in_parts=(per_host_optimizer_partitions.target, None),
            out_parts=None,
            local_out_parts=None)
    p_eval_step = jax.pmap(p_eval_step,
                           axis_name='batch',
                           in_axes=(None, 0),
                           global_arg_shapes=(optimizer_shapes.target, None),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Fast autoregressive decoding loop.
    # For inference and model evaluation.

    # predict_step(inputs, params,
    #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
    def p_pred_step(inputs, params):
        return train_lib.predict_step(inputs, params, eos_id,
                                      CFG.max_eval_target_length,
                                      predict_config, CFG.beam_size)

    if num_partitions > 1:
        p_pred_step = sharded_jit(
            p_pred_step,
            in_parts=(None, optimizer_partitions.target),
            local_in_parts=(None, per_host_optimizer_partitions.target),
            out_parts=None,
            local_out_parts=None)
    p_pred_step = jax.pmap(p_pred_step,
                           axis_name='batch',
                           in_axes=(0, None),
                           global_arg_shapes=(None, optimizer_shapes.target),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # ---------------------------------------------------------------------------
    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    # There should be a unique dropout key for each replica represented on this
    # host, but the key should be the same for the same replica on other hosts.
    # Again, this is what the replica set abstraction is for.
    dropout_rngs = random.split(random.fold_in(rng, topology.replica_set_id),
                                topology.per_replica_set_num_replicas)
    # restore step from last checkpoint
    host_step = int(optimizer.state.step)
    empty_metrics = broadcast({
        'loss': 0.0,
        'accuracy': 0.0,
        'learning_rate': 0.0,
        'denominator': 0.0
    })
    if CFG.infeed:
        # TODO(jekbradbury): support something like this for the Python-loop case
        logging.info(
            'Precompiling training loop and moving optimizer to device.')
        optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                                 empty_metrics,
                                                 jnp.array(0,
                                                           dtype=jnp.int32), 1)
        optimizer = train_lib.unbroadcast(optimizer)
        metrics['loss'].block_until_ready()

    logging.info('Starting training loop.')

    local_devices = jax.local_devices()
    device_step = broadcast(host_step)
    first_epoch = host_step // steps_per_epoch

    # Main Loop over "epochs".
    train_iter = train_ds.as_numpy_iterator()
    for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
        metrics = empty_metrics

        # NOTE: 'optimizer' is unbroadcast by construction at initialization or
        # when loading a checkpoint. It is maintained in 'unbroadcast' state to
        # enable the XLA cross-replica sharding optimization.  The broadcasting is
        # handled automatically by the pmap'd functions that use it.

        # Gather all task evaluation metrics.
        logging.info('Evaluating tasks.')
        if epoch == first_epoch + 1:
            train_lib.sync_devices()
        for task in eval_cache.tasks:
            logging.info('Evaluating task %s', task.name)
            all_predicted, all_bs = [], []
            for pred_batch in eval_cache.preprocessed_examples[task.name]:
                # Handle final odd-sized batch by padding instead of dropping it.
                input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
                    pred_batch['inputs'], per_replica_set_eval_batch_size)
                all_bs.append(unpadded_batch_size)
                # Split batch dimensions for pmap.
                input_batch = jax.tree_map(
                    lambda x: x.reshape((topology.per_replica_set_num_replicas,
                                         -1) + x.shape[1:]), input_batch)
                # Run fast inference on batch.
                all_predicted.append(p_pred_step(input_batch,
                                                 optimizer.target))

            # Pad out the number of batches so each host has the same number.
            max_host_batch_number = np.max(
                eval_cache.preprocessed_batch_sizes[task.name])
            batch_shortfall = max_host_batch_number - len(all_predicted)
            if batch_shortfall > 0:
                # TODO(levskaya): Fix for case of entirely empty all_predicted.
                # To make sure the cross-host barriers work, we run the program the same
                # number of times on all hosts. The result of this call is ignored, and
                # the predictions are populated with zeros instead.
                p_pred_step(input_batch, optimizer.target)  # Dummy call.
                all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                                     batch_shortfall)
                all_bs.extend([0] * batch_shortfall)
            all_predicted = jnp.concatenate(all_predicted)
            all_bs = jnp.array(all_bs)

            # Collect all batches from across hosts and reverse sharding.
            all_predicted = train_lib.host_allgather(
                all_predicted, topology.num_replica_sets,
                topology.replica_set_id, topology.per_replica_set_host_id == 0)
            seqlength = all_predicted.shape[-1]
            total_examples = np.sum(
                train_lib.host_allgather(
                    all_bs, topology.num_replica_sets, topology.replica_set_id,
                    topology.per_replica_set_host_id == 0))
            del all_bs
            assert total_examples == len(eval_cache.examples[task.name]), (
                'Total number of examples incorrect for task %s.' % task.name)
            # De-shard the collected predicted tokens and remove padding.
            all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
                -1, seqlength)[:total_examples]

            # We now run the post-processing and metric-fns on a single host.
            if jax.host_id() == 0:
                assert eval_summary_writer
                raw_predictions = []
                for tokens in all_predicted:
                    raw_predictions.append(decode_tokens(tokens))

                # post-process predictions for metric fns
                predictions = [
                    task.postprocess_fn(p, example=ex) for p, ex in zip(
                        raw_predictions, eval_cache.examples[task.name])
                ]

                for metric_fn in task.metric_fns:
                    scores = metric_fn(eval_cache.targets[task.name],
                                       predictions)
                    for metric_name, metric_value in scores.items():
                        tag = f'eval/{task.name}/{metric_name}'
                        eval_summary_writer.scalar(tag, metric_value,
                                                   host_step)
                        logging.info('EVAL %s at step %d: %.3f', tag,
                                     host_step, metric_value)
                    eval_summary_writer.flush()

                # Save text samples for tensorboard.
                exemplars = ''
                for n in np.random.choice(np.arange(len(predictions)), 8):
                    tgt_txt = tf.compat.as_text(
                        eval_cache.examples[task.name][n]['targets_plaintext'])
                    pred_txt = raw_predictions[n]
                    exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                                  f'target: {tgt_txt}\n\n'
                                  f'prediction: {pred_txt}\n\n')
                eval_summary_writer.text(f'{task.name} samples', exemplars,
                                         host_step)
                eval_summary_writer.flush()

        # Take an Xprof trace after the first loop has compiled everything.
        if epoch == first_epoch + 1:
            train_lib.sync_devices()

        # For on-device loop, we launch the computation before feeding data.
        logging.info('BEGIN Train loop.')
        if CFG.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                optimizer, dropout_rngs, metrics,
                train_lib.unbroadcast(device_step), epoch)
            optimizer = train_lib.unbroadcast(optimizer)

        # Epoch loop.
        while int(host_step // steps_per_epoch) == epoch:
            batch = next(train_iter)
            batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                batch)
            # Feed the on-device training loop.
            if CFG.infeed:
                for i, device in enumerate(local_devices):
                    # When using infeed to provide data to the computation, we're on our
                    # own for feeding the right values to the right devices. Each device
                    # should get the minibatch corresponding to its replica, a slice of
                    # the larger batch corresponding to the host's replica set.
                    if device.platform == 'tpu':
                        device_coords = (*device.coords, device.id % 2)
                    else:
                        device_coords = (device.host_id, i)
                    per_replica_set_device_coords = tuple(
                        dc % prsm for dc, prsm in zip(
                            device_coords, topology.per_replica_set_mesh))
                    per_replica_set_replica_coords = tuple(
                        prsdc // prm
                        for prsdc, prm in zip(per_replica_set_device_coords,
                                              topology.per_replica_mesh))
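                    # Flatten the replica coordinates into a single replica
                    # index via mixed-radix accumulation over the mesh axes.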
                    per_replica_set_replica_id = 0
                    for prsm, prm, prsrc in zip(
                            topology.per_replica_set_mesh,
                            topology.per_replica_mesh,
                            per_replica_set_replica_coords):
                        per_replica_set_replica_id = (
                            per_replica_set_replica_id * prsm // prm + prsrc)
                    input_tuple = tuple([
                        batch[k][per_replica_set_replica_id]
                        for k in train_keys
                    ])
                    # Safety check: infeed does not check shape or types but requires
                    # them to agree with on-device spec, otherwise the queue and program
                    # stalls.
                    tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
                    tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
                    assert tuple_shapes == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (tuple_shapes, device_train_input_shape))
                    assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
                        ('infeed dtype error %s not all of type %s' % (
                            tuple_dtypes, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            # Host training loop.
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    optimizer, batch, metrics, dropout_rngs)
                optimizer = train_lib.unbroadcast(optimizer)
            host_step += 1
        logging.info('END Train loop.')

        # Maybe save a checkpoint on one host.
        if (CFG.save_checkpoints
                and epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1
                and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

        # Gather training metrics.
        metrics = p_allreduce_metrics(metrics)
        metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
        denominator = metrics.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
        logging.info('train in step: %s, %s', host_step, summary)
        if jax.host_id() == 0:
            assert train_summary_writer
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, host_step)
            train_summary_writer.flush()

        # Gather training evaluation metrics.
        logging.info('Gathering training evaluation metrics.')
        eval_metrics = []
        eval_iter = eval_ds.as_numpy_iterator()
        for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        # average metrics across devices
        eval_metrics = p_allreduce_metrics(eval_metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        # average metrics across steps
        eval_metrics = jax.tree_map(np.sum, eval_metrics)
        eval_denominator = eval_metrics.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics)
        logging.info('eval in step: %s, %s', host_step, eval_summary)
        if jax.host_id() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, host_step)
            eval_summary_writer.flush()

    # Wait until computations are done before exiting
    logging.info('Finished.')
    train_lib.sync_devices()
    # Shut down the infeed threadpool.
    if CFG.infeed:
        infeed_pool.shutdown()
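The batch reshape used before every pmap call above only splits the leading host-batch axis into (replicas on this host, per-replica batch). A standalone sketch with made-up sizes:

import numpy as np

per_replica_set_num_replicas = 4      # made-up size for illustration
host_batch = np.zeros((32, 128))      # (host batch, sequence length)
split = host_batch.reshape(
    (per_replica_set_num_replicas, -1) + host_batch.shape[1:])
assert split.shape == (4, 8, 128)     # pmap maps over the new leading axis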
Example #5
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([0., 0.3, 500., 5.]),
                                 jnp.array([300., 3.]))

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(
            LR, beta=0.9)),  # Different names.
        ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop', alias.rmsprop(
            LR, centered=True), optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        ('adam_w', alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),  # Different name.
        (
            'adagrad',
            alias.adagrad(LR,
                          initial_accumulator_value=0.),  # Different default!
            optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        ('lars',
         alias.lars(LR,
                    weight_decay=.5,
                    trust_coefficient=0.003,
                    momentum=0.9,
                    eps=1e-3),
         optim.LARS(
             LR, weight_decay=.5, trust_coefficient=0.003, beta=0.9,
             eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
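The optax half of the test relies on an `update` module (the internal optax._src.update in the original test file). With the public API, the same loop can be sketched as follows, assuming a recent optax release:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
opt = optax.adam(1e-3)                 # any optax GradientTransformation
state = opt.init(params)

grads = {'w': jnp.ones(3)}             # stand-in per-step gradients
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)  # public apply_updates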