def main(_):
  """Run `train()` with TF2 behavior enabled and Keras in training mode."""
  tf.enable_v2_behavior()
  training_phase = 1  # Keras learning phase: 1 = train, 0 = test.
  tf.keras.backend.set_learning_phase(training_phase)
  train()
Exemple #2
0
            one_hot_inputs, vocab_size)
        inv_inputs = tf.argmax(one_hot_inv, axis=-1)
        inputs_inv_inputs = tf.math.floormod(inputs * inv_inputs, vocab_size)
        self.assertAllEqual(inputs_inv_inputs, np.ones((batch_size, length)))

    def testApproximatelyStochastic(self):
        """Sinkhorn output has rows and columns that each sum to ~1."""
        rng = np.random.RandomState(0)
        tf.random.set_seed(1)
        for dims in [2, 5, 10]:
            for batch_size in [1, 2, 10]:
                result = ed.layers.utils.sinkhorn(
                    rng.randn(batch_size, dims, dims))
                all_ones = np.tile([1.0], (batch_size, dims))
                # Axis 1 and axis 2 sums must both be close to one.
                for axis in (1, 2):
                    self.assertAllClose(np.sum(result, axis), all_ones,
                                        atol=1e-3)

    def testSoftToHardPermutation(self):
        """Matching the identity matrix yields the identity permutation matrix.

        The identity permutation (range(N)) expressed as a matrix is np.eye(N).
        """
        size = 10
        hard_matching = ed.layers.utils.soft_to_hard_permutation(tf.eye(size))
        expected = np.eye(size)
        self.assertAllEqual(hard_matching[0], expected)


if __name__ == '__main__':
    # Enable TF2 behavior before tf.test.main() runs the unit tests.
    tf.enable_v2_behavior()
    tf.test.main()
Exemple #3
0
def main(_):
    """Train an image-captioning encoder/decoder on real plus synthetic data.

    Loads real and synthetic transitions and their captions, encodes the
    captions with a vocabulary lookup table, and trains `encoder`/`decoder`
    with teacher forcing. Each training batch holds roughly 2 * batch_size
    examples, about `mixing_ratio` of them synthetic. When FLAGS.save_dir is
    set, the latest checkpoint there is restored before training, and a new
    checkpoint is written whenever the held-out loss (evaluated every 5
    epochs) improves on the best seen so far.

    Raises:
      ValueError: if the caption/observation index does not cover every
        encoded caption.
    """
    tf.enable_v2_behavior()
    # Length passed to pad_to_max_length for every encoded caption.
    max_sequence_length = 15

    ##############################################################################
    ######################### Data loading and processing ########################
    ##############################################################################
    print('Loading data')

    def _load_observations(path):
        """Load an array with np.load; divide by 255 if any value exceeds 1."""
        # NOTE(review): np.load / pickle.load expect a binary stream on
        # Python 3; confirm this `gfile` yields bytes in 'r' mode.
        with gfile.GFile(path, 'r') as f:
            observations = np.load(f)
        if np.max(observations) > 1.0:
            observations = observations / 255.0
        return observations

    transitions = _load_observations(_TRANSITION_PATH)
    synthetic_transitions = _load_observations(_SYNTHETIC_TRANSITION_PATH)

    with gfile.GFile(transition_label_path, 'r') as f:
        captions = pickle.load(f)
    with gfile.GFile(_SYNTHETIC_TRANSITION_LABEL_PATH, 'r') as f:
        synthetic_captions = pickle.load(f)

    with gfile.GFile(vocab_path, 'r') as f:
        vocab_list = f.readlines()

    # Drop the last character of every line (assumed to be its newline).
    vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
    # Special tokens go first -- presumably so 'eos' -> 0 and 'sos' -> 1,
    # matching the hard-coded start token 1 used by the decoder loops below.
    vocab_list = ['eos', 'sos'] + vocab_list

    v2i, _ = wv.create_look_up_table(vocab_list)
    encode_fn = wv.encode_text_with_lookup_table(v2i)

    def _encode_captions(caption_groups):
        """Encode each caption, wrapped as 'sos <caption> eos', to an array."""
        return [
            np.array(encode_fn('sos ' + cp + ' eos'))
            for all_cp in caption_groups
            for cp in all_cp
        ]

    def _index_captions(observations, caption_groups, expected_n):
        """Pair each flattened caption index with its observation index.

        Returns parallel int arrays (obs_idx, caption_idx): caption number j
        (in flattened order) describes observation obs_idx[j].
        """
        obs_list, caption_list = [], []
        for i, _obs in enumerate(observations):
            for _caption in caption_groups[i]:
                obs_list.append(i)
                caption_list.append(len(caption_list))
        if len(caption_list) != expected_n:
            raise ValueError('Indexed {} captions but encoded {}.'.format(
                len(caption_list), expected_n))
        return np.array(obs_list), np.array(caption_list)

    def _split(num_examples, train_fraction=0.8):
        """Split 0..num_examples-1 into leading train / trailing test indices."""
        all_idx = np.arange(num_examples)
        cut = int(num_examples * train_fraction)
        return all_idx[:cut], all_idx[cut:]

    encoded_captions = _encode_captions(captions)
    synthetic_encoded_captions = _encode_captions(synthetic_captions)
    all_caption_n = len(encoded_captions)
    all_synthetic_caption_n = len(synthetic_encoded_captions)

    encoded_captions = pad_to_max_length(np.array(encoded_captions),
                                         max_l=max_sequence_length)
    synthetic_encoded_captions = pad_to_max_length(
        np.array(synthetic_encoded_captions), max_l=max_sequence_length)

    obs_idx, caption_idx = _index_captions(transitions, captions,
                                           all_caption_n)
    synthetic_obs_idx, synthetic_caption_idx = _index_captions(
        synthetic_transitions, synthetic_captions, all_synthetic_caption_n)

    train_idx, test_idx = _split(len(caption_idx))
    print('Number of training examples: {}'.format(len(train_idx)))
    print('Number of test examples: {}\n'.format(len(test_idx)))

    synthetic_train_idx, synthetic_test_idx = _split(
        len(synthetic_caption_idx))
    print('Number of synthetic training examples: {}'.format(
        len(synthetic_train_idx)))
    print('Number of synthetic test examples: {}\n'.format(
        len(synthetic_test_idx)))

    ##############################################################################
    ############################# Training Setup #################################
    ##############################################################################
    batch_size = 64

    encoder_config = {'name': 'image', 'embedding_dim': 32}
    decoder_config = {
        'name': 'attention',
        'word_embedding_dim': 64,
        'hidden_units': 256,
        'vocab_size': len(vocab_list),
    }

    encoder = get_captioning_encoder(encoder_config)
    decoder = get_captioning_decoder(decoder_config)

    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    def loss_function(real, pred, sos_symbol=1):
        """Cross entropy with `sos_symbol` positions zeroed, then averaged."""
        mask = tf.math.logical_not(tf.math.equal(real, sos_symbol))
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.reduce_mean(loss_)

    @tf.function
    def train_step(input_tensor, target):
        """Train on one batch; return (summed loss, loss per time step)."""
        loss = 0
        # initializing the hidden state for each batch
        # because the captions are not related from image to image
        hidden = decoder.reset_state(batch_size=target.shape[0])

        # Every sequence starts from token 1 (the start token).
        dec_input = tf.expand_dims([1] * target.shape[0], 1)

        with tf.GradientTape() as tape:
            features = encoder(input_tensor, training=True)
            for i in range(1, target.shape[1]):
                # passing the features through the decoder
                predictions, hidden, _ = decoder(dec_input,
                                                 features,
                                                 hidden,
                                                 training=True)
                loss += loss_function(target[:, i], predictions)
                # using teacher forcing
                dec_input = tf.expand_dims(target[:, i], 1)

        total_loss = (loss / int(target.shape[1]))
        trainable_variables = (encoder.trainable_variables +
                               decoder.trainable_variables)
        gradients = tape.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(gradients, trainable_variables))

        return loss, total_loss

    @tf.function
    def evaluate_batch(input_tensor, target):
        """Return the teacher-forced loss per time step on one batch."""
        loss = 0
        # initializing the hidden state for each batch
        # because the captions are not related from image to image
        hidden = decoder.reset_state(batch_size=target.shape[0])
        dec_input = tf.expand_dims([1] * target.shape[0], 1)
        features = encoder(input_tensor, training=False)

        for i in range(1, target.shape[1]):
            # passing the features through the decoder
            predictions, hidden, _ = decoder(dec_input,
                                             features,
                                             hidden,
                                             training=False)
            loss += loss_function(target[:, i], predictions)
            # using teacher forcing
            dec_input = tf.expand_dims(target[:, i], 1)
        total_loss = (loss / int(target.shape[1]))
        return total_loss

    ##############################################################################
    ############################# Training Loop ##################################
    ##############################################################################
    print('Start training...\n')
    start_epoch = 0
    if FLAGS.save_dir:
        checkpoint_path = FLAGS.save_dir
        ckpt = tf.train.Checkpoint(encoder=encoder,
                                   decoder=decoder,
                                   optimizer=optimizer)
        ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                  checkpoint_path,
                                                  max_to_keep=5)
        if ckpt_manager.latest_checkpoint:
            # Restore weights and optimizer state along with the epoch counter;
            # otherwise training would resume mid-schedule from a fresh model.
            ckpt.restore(ckpt_manager.latest_checkpoint)
            start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])

    epochs = 400
    step_per_epoch = int(len(captions) / batch_size) * 10

    previous_best = 100.

    # Fraction of each (2 * batch_size) training batch drawn from synthetic data.
    mixing_ratio = 0.4
    syn_bs = int(batch_size * 2 * mixing_ratio)
    true_bs = int(batch_size * 2 * (1 - mixing_ratio))

    for epoch in range(start_epoch, epochs):
        start = time.time()
        total_loss = 0

        for batch in range(step_per_epoch):
            batch_idx = np.random.choice(train_idx, size=true_bs)
            synthetic_batch_idx = np.random.choice(synthetic_train_idx,
                                                   size=syn_bs)
            input_tensor = transitions[obs_idx[batch_idx], :]
            synthetic_input_tensor = synthetic_transitions[
                synthetic_obs_idx[synthetic_batch_idx], :]
            input_tensor = np.concatenate(
                [input_tensor, synthetic_input_tensor], axis=0)
            # NOTE(review): only training inputs go through
            # encoder.preprocess; the evaluation paths below feed raw
            # transitions. Confirm this train/eval asymmetry is intended.
            input_tensor = encoder.preprocess(input_tensor)
            target = encoded_captions[caption_idx[batch_idx]]
            synthetic_target = synthetic_encoded_captions[
                synthetic_caption_idx[synthetic_batch_idx]]
            target = np.concatenate([target, synthetic_target], axis=0)
            batch_loss, t_loss = train_step(input_tensor, target)
            total_loss += t_loss

            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    epoch + 1, batch,
                    batch_loss.numpy() / int(target.shape[1])))

        if epoch % 5 == 0 and FLAGS.save_dir:
            test_total_loss = 0
            # Evaluation windows are capped at index 196 for real data and, for
            # both sources, at the last valid position of the held-out split,
            # so a small split cannot raise IndexError.
            max_real = min(196, len(test_idx) - 1)
            max_synthetic = len(synthetic_test_idx) - 1
            for batch in range(3):
                batch_idx = np.clip(
                    np.arange(true_bs) + batch * true_bs, 0, max_real)
                idx = test_idx[batch_idx]
                input_tensor = transitions[obs_idx[idx], :]
                target = encoded_captions[caption_idx[idx]]
                t_loss = evaluate_batch(input_tensor, target)
                test_total_loss += t_loss
                batch_idx = np.clip(
                    np.arange(syn_bs) + batch * syn_bs, 0, max_synthetic)
                idx = synthetic_test_idx[batch_idx]
                input_tensor = synthetic_transitions[synthetic_obs_idx[idx], :]
                target = synthetic_encoded_captions[synthetic_caption_idx[idx]]
                t_loss = evaluate_batch(input_tensor, target)
                test_total_loss += t_loss
            # 3 batches x (real + synthetic) = 6 evaluate_batch calls.
            test_total_loss /= 6.
            if test_total_loss < previous_best:
                previous_best = test_total_loss
                ckpt_manager.save(checkpoint_number=epoch)

        print('Epoch {} | Loss {:.6f} | Val loss {:.6f}'.format(
            epoch + 1, total_loss / step_per_epoch, previous_best))
        print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() -
                                                           start))

        # Full pass over the real test split; skipped when it is smaller than
        # one batch (the average below would otherwise divide by zero).
        num_test_batches = len(test_idx) // batch_size
        if epoch % 20 == 0 and num_test_batches > 0:
            test_loss_sum = 0
            for batch in range(num_test_batches):
                batch_idx = np.arange(batch_size) + batch * batch_size
                idx = test_idx[batch_idx]
                input_tensor = transitions[obs_idx[idx], :]
                target = encoded_captions[caption_idx[idx]]
                t_loss = evaluate_batch(input_tensor, target)
                test_loss_sum += t_loss

            print('====================================================')
            print('Test Loss {:.6f}'.format(test_loss_sum / num_test_batches))
            print('====================================================\n')
Exemple #4
0
def main(_):
    """Enable TF2 behavior, then call visualize_tfrecords with flag values.

    Forwards FLAGS.path_to_tfrecord, FLAGS.num_vids and
    FLAGS.num_skip_frames positionally.
    """
    tf.enable_v2_behavior()
    tfrecord_path = FLAGS.path_to_tfrecord
    num_vids = FLAGS.num_vids
    num_skip_frames = FLAGS.num_skip_frames
    visualize_tfrecords(tfrecord_path, num_vids, num_skip_frames)
Exemple #5
0
def main(argv):
    """Evaluate a deep ensemble of ResNet-20 checkpoints on train and test data.

    Every checkpoint whose index file matches `**/*.ckpt.index` under
    FLAGS.output_dir is loaded into one ResNet-20 model. Its logits on the
    batched train and test splits are collected. The ensemble's negative log
    likelihood, Gibbs cross entropy, and the accuracy of the member-averaged
    probabilities are then logged.

    Args:
      argv: unused command-line arguments.

    Raises:
      ValueError: if no checkpoints are found under FLAGS.output_dir.
    """
    del argv  # unused arg
    tf.enable_v2_behavior()

    dataset_train, ds_info = utils.load_dataset(tfds.Split.TRAIN,
                                                with_info=True)
    dataset_test = utils.load_dataset(tfds.Split.TEST)
    dataset_train = dataset_train.batch(FLAGS.batch_size)
    dataset_test = dataset_test.batch(FLAGS.batch_size)

    model = deterministic.resnet_v1(
        input_shape=ds_info.features['image'].shape,
        depth=20,
        num_classes=ds_info.features['label'].num_classes,
        l2=0.)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.output_dir, '**/*.ckpt.index'))
    ensemble_filenames = [
        filename[:-len('.index')] for filename in ensemble_filenames
    ]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    if not ensemble_filenames:
        # Without this check the metric computations below fail on empty
        # lists with an unhelpful error.
        raise ValueError('No checkpoints matching **/*.ckpt.index found under '
                         '{}'.format(FLAGS.output_dir))

    def _predict(dataset, collect_labels):
        """Run `model` over `dataset` in inference mode.

        Returns (logits concatenated along axis 0, labels concatenated along
        axis 0 if `collect_labels` else None).
        """
        batch_logits, batch_labels = [], []
        for features, labels in dataset:
            batch_logits.append(model(features, training=False))
            if collect_labels:
                batch_labels.append(labels)
        all_logits = tf.concat(batch_logits, axis=0)
        all_labels = tf.concat(batch_labels, axis=0) if collect_labels else None
        return all_logits, all_labels

    # Collect the logits output for each ensemble member and train/test data
    # point. We also collect the labels.
    # TODO(trandustin): Refactor data loader so you can get the full dataset in
    # memory without looping.
    logits_train = []
    logits_test = []
    labels_train = None
    labels_test = None
    for m, ensemble_filename in enumerate(ensemble_filenames):
        model.load_weights(ensemble_filename)
        # Labels are gathered only on the first pass; this assumes each
        # dataset yields examples in the same order on every pass.
        first_member = m == 0

        member_logits, member_labels = _predict(dataset_train, first_member)
        logits_train.append(member_logits)
        if first_member:
            labels_train = member_labels

        member_logits, member_labels = _predict(dataset_test, first_member)
        logits_test.append(member_logits)
        if first_member:
            labels_test = member_labels
        logging.info('Predictions completed for checkpoint %s',
                     ensemble_filename)

    metrics = {}

    # Compute the ensemble's NLL and Gibbs cross entropy for each data point.
    # Then average over the dataset.
    nll_train = ensemble_negative_log_likelihood(labels_train, logits_train)
    nll_test = ensemble_negative_log_likelihood(labels_test, logits_test)
    gibbs_ce_train = gibbs_cross_entropy(labels_train, logits_train)
    gibbs_ce_test = gibbs_cross_entropy(labels_test, logits_test)
    metrics['train_nll'] = tf.reduce_mean(nll_train)
    metrics['test_nll'] = tf.reduce_mean(nll_test)
    metrics['train_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_train)
    metrics['test_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)

    # Given the per-element logits tensor of shape [ensemble_size, dataset_size,
    # num_classes], average over the ensemble members' probabilities. Then
    # compute accuracy and average over the dataset.
    probs_train = tf.reduce_mean(tf.nn.softmax(logits_train), axis=0)
    probs_test = tf.reduce_mean(tf.nn.softmax(logits_test), axis=0)
    accuracy_train = tf.keras.metrics.sparse_categorical_accuracy(
        labels_train, probs_train)
    accuracy_test = tf.keras.metrics.sparse_categorical_accuracy(
        labels_test, probs_test)
    metrics['train_accuracy'] = tf.reduce_mean(accuracy_train)
    metrics['test_accuracy'] = tf.reduce_mean(accuracy_test)
    logging.info('Metrics: %s', metrics)
Exemple #6
0
def main(argv):
    """Train and evaluate a TransformerDualEncoder matching classifier.

    Builds the model from FLAGS.config and trains it with a pmap'd
    `train_step` across local devices. Every `eval_frequency` steps it logs
    and records train, eval and test metrics. With FLAGS.test_only, it
    restores the checkpoint in FLAGS.model_dir, evaluates the whole test set,
    writes the summary to `<model_dir>/results.json`, and returns without
    training.

    Args:
      argv: command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: if extra command-line arguments are given.
      ValueError: if the batch size is not divisible by jax.device_count(),
        or if config.model_type is not 'transformer'.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    max_length = config.max_length

    # Only host 0 creates a summary writer; every later use of
    # `summary_writer` is guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length,
        tokenizer=config.tokenizer,
        vocab_file_path=FLAGS.vocab_file_path)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    # Repeat the training data so the step-bounded loop below is what ends
    # training, not dataset exhaustion.
    train_ds = train_ds.repeat()

    train_iter = iter(train_ds)
    # The same (batch_size, max_length) shape is passed for both encoder
    # inputs of the dual encoder in create_model below.
    input_shape = (batch_size, max_length)

    # Classifier head with 2 output classes; pooling mode comes from config.
    model_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': config.emb_dim,
        'num_heads': config.num_heads,
        'num_layers': config.num_layers,
        'qkv_dim': config.qkv_dim,
        'mlp_dim': config.mlp_dim,
        'max_len': max_length,
        'classifier': True,
        'num_classes': 2,
        'classifier_pool': config.pooling_mode
    }

    rng = random.PRNGKey(random_seed)
    # Fold in the host id so each host derives a different PRNG stream.
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    if model_type == 'transformer':
        model = create_model(init_rng, transformer.TransformerDualEncoder,
                             input_shape, input_shape, model_kwargs)
    else:
        raise ValueError('Model type not supported.')

    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if config.restore_checkpoints or FLAGS.test_only:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    # (Done after any restore so the restored state is what gets replicated.)
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # p_pred_step = jax.pmap(predict_step, axis_name='batch')

    def run_eval(eval_ds, num_eval_steps=-1):
        """Evaluate on `eval_ds` and return denominator-normalized metrics.

        Runs `num_eval_steps` batches, or the whole dataset when it is -1.
        Each metric is summed over batches and divided by the summed
        'denominator'. 'perplexity' is exp(loss) clipped at 1e4.
        """
        eval_metrics = []
        eval_iter = iter(eval_ds)
        if num_eval_steps == -1:
            num_iter = itertools.count()
        else:
            num_iter = range(num_eval_steps)
        for _, eval_batch in zip(num_iter, eval_iter):
            # Convert each element to NumPy via the private `_numpy()`
            # accessor (presumably tf.data eager tensors), then shard it
            # across local devices for the pmap'd eval step.
            # pylint: disable=protected-access
            eval_batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), eval_batch))
            # pylint: enable=protected-access
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        # Calculate (clipped) perplexity after averaging log-perplexities:
        eval_summary['perplexity'] = jnp.clip(jnp.exp(eval_summary['loss']),
                                              a_max=1.0e4)
        return eval_summary

    if FLAGS.test_only:
        # NOTE(review): results.json is opened for writing before run_eval
        # executes, so an exception during evaluation leaves it truncated;
        # this branch also runs (and writes the same path) on every host.
        with tf.io.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.json'),
                               'w') as f:
            test_summary = run_eval(test_ds)
            json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
        return

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    # The step range is zipped first, so once it is exhausted no extra batch
    # is pulled from train_iter.
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        # logging.info(batch)
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.host_id() == 0:
                tock = time.time()
                # NOTE(review): assumes exactly eval_freq steps since `tick`;
                # the first window after a resume (start_step > 0) may be
                # shorter, and it also includes pmap compilation time.
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_summary = run_eval(eval_ds, num_eval_steps)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'eval_{key}', val, step)
                summary_writer.flush()

            # Test eval
            # Eval Metrics
            logging.info('Testing...')
            test_summary = run_eval(test_ds, num_eval_steps)
            logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                         test_summary['loss'], test_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in test_summary.items():
                    summary_writer.scalar(f'test_{key}', val, step)
                summary_writer.flush()
Exemple #7
0
def main(argv):
    """Evaluate an ImageNet ResNet-50 ensemble on clean and corrupted data.

    Every checkpoint found via its `.index` file under FLAGS.output_dir is
    restored into one ResNet-50 model. Its logits are collected on the clean
    test set and on every (corruption type, intensity) test set. Ensemble
    NLL, Gibbs cross entropy, the accuracy of the member-averaged
    probabilities, and ECE are then logged. Corrupted-data metrics are
    combined with `utils.aggregate_corrupt_metrics`.

    Args:
      argv: unused command-line arguments.

    Raises:
      ValueError: if FLAGS.num_cores > 1, or if no checkpoints are found.
    """
    del argv  # unused arg
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)

    dataset_test = utils.ImageNetInput(is_training=False,
                                       data_dir=FLAGS.data_dir,
                                       batch_size=FLAGS.per_core_batch_size,
                                       use_bfloat16=False).input_fn()
    test_datasets = {'clean': dataset_test}

    model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                         num_classes=NUM_CLASSES)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.output_dir, '**/*.index'))
    ensemble_filenames = [
        filename[:-len('.index')] for filename in ensemble_filenames
    ]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    if not ensemble_filenames:
        # Fail fast: with no members the metric code below would fail on
        # empty label/logit lists with an unhelpful error.
        raise ValueError('No checkpoints (*.index files) found under '
                         '{}'.format(FLAGS.output_dir))
    checkpoint = tf.train.Checkpoint(model=model)

    # Collect the logits output for each ensemble member and test data
    # point. We also collect the labels.
    logits_test = {'clean': []}
    labels_test = {'clean': []}
    corruption_types, max_intensity = utils.load_corrupted_test_info()
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            logits_test[dataset_name] = []
            labels_test[dataset_name] = []

            test_datasets[dataset_name] = utils.load_corrupted_test_dataset(
                name=name,
                intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                drop_remainder=True,
                use_bfloat16=False)

    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        logging.info('Working on test data for ensemble member %s', m)
        for name, test_dataset in test_datasets.items():
            logits = []
            for features, labels in test_dataset:
                logits.append(model(features, training=False))
                # Labels are gathered only on the first member's pass; this
                # assumes each dataset yields the same order on every pass.
                if m == 0:
                    labels_test[name].append(labels)

            logits_test[name].append(tf.concat(logits, axis=0))
            if m == 0:
                labels_test[name] = tf.concat(labels_test[name], axis=0)
            logging.info('Finished testing on %s', name)

    metrics = {
        'test/ece':
        ed.metrics.ExpectedCalibrationError(num_classes=NUM_CLASSES,
                                            num_bins=15)
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/ece_{}'.format(
            name)] = ed.metrics.ExpectedCalibrationError(
                num_classes=NUM_CLASSES, num_bins=15)
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(
            name)] = tf.keras.metrics.Mean()

    for name in test_datasets:
        labels = labels_test[name]
        logits = logits_test[name]
        nll_test = ensemble_negative_log_likelihood(labels, logits)
        gibbs_ce_test = gibbs_cross_entropy(labels, logits)
        labels = tf.cast(labels, tf.int32)
        logits = tf.convert_to_tensor(logits)
        per_probs = tf.nn.softmax(logits)
        # Average member probabilities over the leading (ensemble) axis.
        probs = tf.reduce_mean(per_probs, axis=0)
        accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels, probs)
        if name == 'clean':
            metrics['test/negative_log_likelihood'] = tf.reduce_mean(nll_test)
            metrics['test/gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)
            metrics['test/accuracy'] = tf.reduce_mean(accuracy)
            metrics['test/ece'].update_state(labels, probs)
        else:
            corrupt_metrics['test/nll_{}'.format(name)].update_state(
                tf.reduce_mean(nll_test))
            corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                tf.reduce_mean(accuracy))
            corrupt_metrics['test/ece_{}'.format(name)].update_state(
                labels, probs)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types,
                                                      max_intensity)
    metrics['test/ece'] = metrics['test/ece'].result()
    total_results = dict(metrics)
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
Exemple #8
0
def main(argv):
  """Evaluates a deep ensemble of ResNet-50 checkpoints on ImageNet(-C).

  For every checkpoint found under FLAGS.checkpoint_dir and every test dataset
  (clean plus each corruption type/intensity), logits are computed once and
  cached as `<dataset>_<member>.npy` in FLAGS.output_dir; existing files are
  reused instead of recomputed. The cached logits are then averaged in
  probability space to report ensemble NLL, Gibbs cross entropy (clean only),
  accuracy and ECE, plus aggregated corruption metrics.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: if a non-GPU or multi-accelerator setup is requested, or if no
      checkpoints are found under FLAGS.checkpoint_dir.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  dataset_test = utils.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=False).input_fn()
  test_datasets = {'clean': dataset_test}
  corruption_types, max_intensity = utils.load_corrupted_test_info()
  for name in corruption_types:
    for intensity in range(1, max_intensity + 1):
      dataset_name = '{0}_{1}'.format(name, intensity)
      test_datasets[dataset_name] = utils.load_corrupted_test_dataset(
          name=name,
          intensity=intensity,
          batch_size=FLAGS.per_core_batch_size,
          drop_remainder=True,
          use_bfloat16=False)

  model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                       num_classes=NUM_CLASSES)

  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())
  # Search for checkpoints from their index file; then remove the index suffix.
  # Glob order is not guaranteed, so sort to keep the member index ->
  # checkpoint mapping (and thus the cached prediction file names) stable.
  index_suffix = '.index'
  ensemble_filenames = sorted(
      filename[:-len(index_suffix)]
      for filename in tf.io.gfile.glob(
          os.path.join(FLAGS.checkpoint_dir, '**/*' + index_suffix)))
  ensemble_size = len(ensemble_filenames)
  if not ensemble_size:
    # Without this check the evaluation loop below fails with an opaque
    # indexing error on an empty logits tensor.
    raise ValueError(
        'No checkpoints found under {}.'.format(FLAGS.checkpoint_dir))
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Write model predictions to files.
  num_datasets = len(test_datasets)
  for m, ensemble_filename in enumerate(ensemble_filenames):
    checkpoint.restore(ensemble_filename)
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      # Predictions already on disk are reused (acts as a cache across runs).
      if not tf.io.gfile.exists(filename):
        logits = []
        test_iterator = iter(test_dataset)
        for _ in range(steps_per_eval):
          features, _ = next(test_iterator)  # pytype: disable=attribute-error
          logits.append(model(features, training=False))

        logits = tf.concat(logits, axis=0)
        # np.save emits bytes; open in binary mode to match the 'rb' read below.
        with tf.io.gfile.GFile(filename, 'wb') as f:
          np.save(f, logits.numpy())
      percent = (m * num_datasets + (n + 1)) / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            m + 1,
                                            ensemble_size,
                                            n + 1,
                                            num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  corrupt_metrics = {}
  for name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
    corrupt_metrics['test/accuracy_{}'.format(name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(
        name)] = ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)

  # Evaluate model predictions.
  for n, (name, test_dataset) in enumerate(test_datasets.items()):
    logits_dataset = []
    for m in range(ensemble_size):
      filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
      filename = os.path.join(FLAGS.output_dir, filename)
      with tf.io.gfile.GFile(filename, 'rb') as f:
        logits_dataset.append(np.load(f))

    # Stacked as [ensemble_size, num_examples, ...]; slices below take a
    # batch of examples across all members.
    logits_dataset = tf.convert_to_tensor(logits_dataset)
    test_iterator = iter(test_dataset)
    for step in range(steps_per_eval):
      _, labels = next(test_iterator)  # pytype: disable=attribute-error
      logits = logits_dataset[:, (step*batch_size):((step+1)*batch_size)]
      labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
      negative_log_likelihood = tf.reduce_mean(
          ensemble_negative_log_likelihood(labels, logits))
      per_probs = tf.nn.softmax(logits)
      # Ensemble prediction: mean of member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      if name == 'clean':
        gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(name)].update_state(
            labels, probs)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (n + 1) / num_datasets, n + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity,
                                                    FLAGS.alexnet_errors_path)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
Exemple #9
0
def main(argv):
    """Trains a Transformer on WMT with JAX/Flax and evaluates it periodically.

    At step 0 and every FLAGS.eval_frequency steps thereafter, this function:
    - logs training metrics;
    - computes evaluation loss over FLAGS.num_eval_steps batches;
    - decodes the predict dataset with beam search and reports corpus BLEU.

    Summaries are written only on host 0. Checkpoints are saved on host 0
    every FLAGS.checkpoint_freq steps when FLAGS.save_checkpoints is set.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: if extra command-line arguments are given.
      ValueError: if FLAGS.batch_size is not divisible by the number of local
        devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        """Detokenizes `toks` up to and including the first EOS token.

        If `toks` contains no EOS token the whole sequence is kept. The former
        `np.argmax(toks == eos_id)` form returned 0 in that case and silently
        truncated the text to a single token.
        """
        eos_positions = np.flatnonzero(toks == eos_id)
        end = eos_positions[0] + 1 if eos_positions.size else len(toks)
        valid_toks = toks[:end].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    # call a jitted initialization function to get the initial parameter tree
    @jax.jit
    def initialize_variables(rng):
        return models.Transformer(eval_config).init(
            rng, jnp.ones(input_shape, jnp.float32),
            jnp.ones(target_shape, jnp.float32))

    initial_variables = initialize_variables(init_rng)

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['param'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=FLAGS.label_smoothing),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    # Step at which training metrics were last reported. Initialized so that
    # the first report counts every step run so far in this process; this
    # keeps steps/sec correct when resuming from a checkpoint, where the first
    # report can come after fewer than eval_frequency steps.
    last_report_step = start_step - 1
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        # Training Metrics
        logging.info('Gathering training metrics.')
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_since_report = step - last_report_step
        last_report_step = step
        steps_per_sec = steps_since_report / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for pred_batch in predict_iter:
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            cache = p_init_cache(pred_batch['inputs'])
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_id, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        # np.random.choice raises on an empty population, e.g. when this host
        # received no predict examples.
        if predictions:
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
Exemple #10
0
def main(argv):
    """Trains and evaluates a BatchEnsemble ResNet-32 on FLAGS.dataset.

    Runs under a MirroredStrategy when FLAGS.use_gpu is set, otherwise on a
    TPU. Resumes from the latest checkpoint in FLAGS.output_dir if one exists.
    Each epoch it writes summaries for:
    - train loss, NLL and accuracy;
    - ensemble test NLL and accuracy;
    - per-member test NLL and accuracy.

    A checkpoint is saved every 20 epochs.

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # unused arg
    tf.enable_v2_behavior()
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    def train_input_fn(ctx):
        """Sets up local (per-core) dataset batching."""
        dataset = utils.load_distributed_dataset(
            split=tfds.Split.TRAIN,
            name=FLAGS.dataset,
            batch_size=FLAGS.per_core_batch_size // FLAGS.num_models,
            drop_remainder=True,
            use_bfloat16=FLAGS.use_bfloat16,
            proportion=FLAGS.train_proportion)
        if ctx and ctx.num_input_pipelines > 1:
            dataset = dataset.shard(ctx.num_input_pipelines,
                                    ctx.input_pipeline_id)
        return dataset

    # No matter what percentage of training proportion, we still evaluate the
    # model on the full test dataset.
    def test_input_fn(ctx):
        """Sets up local (per-core) dataset batching."""
        dataset = utils.load_distributed_dataset(
            split=tfds.Split.TEST,
            name=FLAGS.dataset,
            batch_size=FLAGS.per_core_batch_size // FLAGS.num_models,
            drop_remainder=True,
            use_bfloat16=FLAGS.use_bfloat16)
        if ctx and ctx.num_input_pipelines > 1:
            dataset = dataset.shard(ctx.num_input_pipelines,
                                    ctx.input_pipeline_id)
        return dataset

    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)
    test_dataset = strategy.experimental_distribute_datasets_from_function(
        test_input_fn)
    ds_info = tfds.builder(FLAGS.dataset).info

    batch_size = ((FLAGS.per_core_batch_size // FLAGS.num_models) *
                  FLAGS.num_cores)
    # Train_proportion is a float so need to convert steps_per_epoch to int.
    steps_per_epoch = int(
        (ds_info.splits['train'].num_examples * FLAGS.train_proportion) //
        batch_size)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    with strategy.scope():
        logging.info('Building Keras ResNet-32 model')
        model = batchensemble_model.ensemble_resnet_v1(
            input_shape=ds_info.features['image'].shape,
            depth=32,
            num_classes=ds_info.features['label'].num_classes,
            width_multiplier=4,
            num_models=FLAGS.num_models,
            random_sign_init=FLAGS.random_sign_init,
            dropout_rate=FLAGS.dropout_rate,
            l2=FLAGS.l2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linear learning-rate scaling relative to a reference batch of 128.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_schedule = utils.ResnetLearningRateSchedule(steps_per_epoch,
                                                       base_lr, _LR_SCHEDULE)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        train_nll = tf.keras.metrics.Mean('train_nll', dtype=tf.float32)
        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        test_nll = tf.keras.metrics.Mean('test_nll', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)
        test_nlls = []
        test_accs = []
        for i in range(FLAGS.num_models):
            test_nlls.append(
                tf.keras.metrics.Mean('test_nll_{}'.format(i),
                                      dtype=tf.float32))
            test_accs.append(
                tf.keras.metrics.SparseCategoricalAccuracy(
                    'test_accuracy_{}'.format(i), dtype=tf.float32))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries/'))

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            if FLAGS.version2:
                images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
                labels = tf.tile(labels, [FLAGS.num_models])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply different learning rate on the fast weight approximate
                    # posterior/prior parameters. This is excludes BN and slow weights,
                    # but pay caution to the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            train_loss.update_state(loss)
            train_nll.update_state(negative_log_likelihood)
            train_accuracy.update_state(labels, logits)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            # Every member sees the full batch; outputs are stacked member-major
            # along axis 0 and split back apart below.
            images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)
            per_probs = tf.split(probs,
                                 num_or_size_splits=FLAGS.num_models,
                                 axis=0)
            for i in range(FLAGS.num_models):
                member_probs = per_probs[i]
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                test_nlls[i].update_state(member_loss)
                test_accs[i].update_state(labels, member_probs)

            probs = tf.reduce_mean(per_probs, axis=0)

            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
            test_nll.update_state(negative_log_likelihood)
            test_accuracy.update_state(labels, probs)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    # Global step this process starts from (non-zero when resuming from a
    # checkpoint). Throughput and ETA must count only steps run since
    # start_time; dividing the absolute step by the elapsed time inflates
    # steps/s after a restore.
    initial_step = initial_epoch * steps_per_epoch
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        with summary_writer.as_default():
            for step in range(steps_per_epoch):
                train_step(train_iterator)

                # Progress is only reported every 20 steps; skip the timing
                # bookkeeping on the others.
                if step % 20 == 0:
                    current_step = epoch * steps_per_epoch + (step + 1)
                    time_elapsed = time.time() - start_time
                    steps_per_sec = (float(current_step - initial_step) /
                                     time_elapsed)
                    eta_seconds = (max_steps - current_step) / steps_per_sec
                    message = (
                        '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                        'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                            current_step / max_steps, epoch + 1,
                            FLAGS.train_epochs, steps_per_sec,
                            eta_seconds / 60, time_elapsed / 60))
                    logging.info(message)

            tf.summary.scalar('train/loss',
                              train_loss.result(),
                              step=epoch + 1)
            tf.summary.scalar('train/negative_log_likelihood',
                              train_nll.result(),
                              step=epoch + 1)
            tf.summary.scalar('train/accuracy',
                              train_accuracy.result(),
                              step=epoch + 1)
            logging.info('Train Loss: %s, Accuracy: %s%%',
                         round(float(train_loss.result()), 4),
                         round(float(train_accuracy.result() * 100), 2))

            train_loss.reset_states()
            train_nll.reset_states()
            train_accuracy.reset_states()

            test_iterator = iter(test_dataset)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator)
            tf.summary.scalar('test/negative_log_likelihood',
                              test_nll.result(),
                              step=epoch + 1)
            tf.summary.scalar('test/accuracy',
                              test_accuracy.result(),
                              step=epoch + 1)
            logging.info('Test NLL: %s, Accuracy: %s%%',
                         round(float(test_nll.result()), 4),
                         round(float(test_accuracy.result() * 100), 2))

            test_nll.reset_states()
            test_accuracy.reset_states()

            for i in range(FLAGS.num_models):
                tf.summary.scalar('test/ensemble_nll_member{}'.format(i),
                                  test_nlls[i].result(),
                                  step=epoch + 1)
                tf.summary.scalar('test/ensemble_accuracy_member{}'.format(i),
                                  test_accs[i].result(),
                                  step=epoch + 1)
                logging.info('Member %d Test loss: %s, accuracy: %s%%', i,
                             round(float(test_nlls[i].result()), 4),
                             round(float(test_accs[i].result() * 100), 2))
                test_nlls[i].reset_states()
                test_accs[i].reset_states()

        if (epoch + 1) % 20 == 0:
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
Exemple #11
0
def main(unused_argv):
    """Trains and evaluates a Keras ResNet-50 on ImageNet with a TPUStrategy.

    Builds the ImageNet train/eval input pipelines, the model, the optimizer
    and the metrics under the TPU distribution strategy, and restores the
    latest checkpoint in the model directory if one exists. For each
    remaining epoch it then runs `steps_per_epoch` training steps followed by
    `steps_per_eval` evaluation steps. Summaries are written to
    `<model_dir>/summaries/{train,test}` and a checkpoint is saved after
    every epoch.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    tf.enable_v2_behavior()
    num_workers = 1
    job_name = 'worker'
    primary_cpu_task = '/job:%s' % job_name

    # With the hard-coded single worker above this is always False; the pod
    # branch below is kept for multi-worker configurations.
    is_tpu_pod = num_workers > 1
    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
    steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    # Round up so the eval step count covers all IMAGENET_VALIDATION_IMAGES.
    steps_per_eval = int(math.ceil(IMAGENET_VALIDATION_IMAGES / batch_size))

    logging.info('Saving checkpoints at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, job_name=job_name)
    tf.config.experimental_connect_to_host(resolver.master())  # pylint: disable=line-too-long
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with tf.device(primary_cpu_task):
        # TODO(b/130307853): In TPU Pod, we have to use
        # `strategy.experimental_distribute_datasets_from_function` instead of
        # `strategy.experimental_distribute_dataset` because dataset cannot be
        # cloned in eager mode. And when using
        # `strategy.experimental_distribute_datasets_from_function`, we should use
        # per core batch size instead of global batch size, because no re-batch is
        # happening in this case.
        if is_tpu_pod:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_train.input_fn)
            test_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_eval.input_fn)
        else:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_dataset(
                imagenet_train.input_fn())
            test_dataset = strategy.experimental_distribute_dataset(
                imagenet_eval.input_fn())

        # Variables (model weights, optimizer slots, metric accumulators) must
        # be created under the strategy scope so they are replicated on TPU.
        with strategy.scope():
            logging.info('Building Keras ResNet-50 model')
            model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=_BASE_LEARNING_RATE, momentum=0.9, nesterov=True)
            training_loss = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
            training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'training_accuracy', dtype=tf.float32)
            test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
            test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'test_accuracy', dtype=tf.float32)
            logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        initial_epoch = 0
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            # Resume from the epoch implied by the restored step counter.
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

        # Create summary writers
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/train'))
        test_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/test'))

        @tf.function
        def train_step(iterator):
            """Runs one distributed training step on the next batch."""
            def step_fn(inputs):
                """Per-replica forward/backward pass and metric update."""
                images, labels = inputs
                with tf.GradientTape() as tape:
                    logits = model(images, training=True)

                    # Loss calculations.
                    #
                    # Part 1: Prediction loss.
                    prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits)
                    loss1 = tf.reduce_mean(prediction_loss)
                    # Part 2: Model weights regularization
                    loss2 = tf.reduce_sum(model.losses)
                    loss = loss1 + loss2

                    # Scale the loss given the TPUStrategy will reduce sum all
                    # gradients; the scaled value is only for the gradient.
                    scaled_loss = loss / strategy.num_replicas_in_sync

                grads = tape.gradient(scaled_loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                # Record the *unscaled* loss: the Mean metric already averages
                # over replicas, so feeding the scaled value would report the
                # loss divided by the number of replicas.
                training_loss.update_state(loss)
                training_accuracy.update_state(labels, logits)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        @tf.function
        def test_step(iterator):
            """Runs one distributed evaluation step on the next batch."""
            def step_fn(inputs):
                """Per-replica forward pass and metric update."""
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                # No gradient here, so no per-replica scaling: the Mean metric
                # averages across replicas by itself.
                loss = tf.reduce_mean(loss)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        train_iterator = iter(train_dataset)
        for epoch in range(initial_epoch, FLAGS.num_epochs):
            logging.info('Starting to run epoch: %s', epoch)
            with train_summary_writer.as_default():
                for step in range(steps_per_epoch):
                    # Learning rate is a function of fractional epoch progress.
                    learning_rate = compute_learning_rate(epoch + 1 +
                                                          (float(step) /
                                                           steps_per_epoch))
                    optimizer.lr = learning_rate
                    if step % 20 == 0:
                        logging.info(
                            'Learning rate at step %s in epoch %s is %s', step,
                            epoch, optimizer.lr.numpy())
                    train_step(train_iterator)
                tf.summary.scalar('loss',
                                  training_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  training_accuracy.result(),
                                  step=optimizer.iterations)
                # Metric results are EagerTensors, which do not support
                # round(); convert to Python floats first.
                logging.info('Training loss: %s, accuracy: %s%%',
                             round(float(training_loss.result()), 4),
                             round(float(training_accuracy.result() * 100), 2))
                training_loss.reset_states()
                training_accuracy.reset_states()

            with test_summary_writer.as_default():
                test_iterator = iter(test_dataset)
                for step in range(steps_per_eval):
                    if step % 20 == 0:
                        logging.info(
                            'Starting to run eval step %s of epoch: %s', step,
                            epoch)
                    test_step(test_iterator)
                tf.summary.scalar('loss',
                                  test_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  test_accuracy.result(),
                                  step=optimizer.iterations)
                logging.info('Test loss: %s, accuracy: %s%%',
                             round(float(test_loss.result()), 4),
                             round(float(test_accuracy.result() * 100), 2))
                test_loss.reset_states()
                test_accuracy.reset_states()

            checkpoint_name = checkpoint.save(
                os.path.join(model_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
Example #12
0
def main(_):
  """Runs imitation-learning training from expert demonstrations.

  The algorithm is selected by substring match on `FLAGS.algo` ('dac',
  'value_dice', 'bc'). Expert transitions are loaded from
  `<expert_dir>/<env_name>.npz`, subsampled, optionally state-normalized and
  augmented with absorbing states. The agent then interacts with the
  environment for `FLAGS.max_timesteps` steps. Along the way it:
    * evaluates the actor every `FLAGS.eval_interval` steps and np.save-s
      the history of average returns under `<save_dir>/logs/`;
    * writes TensorBoard summaries under `<save_dir>/tb/<hparam_str>`;
    * once `FLAGS.start_training_timesteps` is reached, performs
      `FLAGS.updates_per_step` updates per environment step.

  Args:
    _: Unused command-line arguments.
  """
  tf.enable_v2_behavior()

  # Seed the TensorFlow, NumPy and `random` RNGs from the same flag.
  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  filename = os.path.join(FLAGS.expert_dir, FLAGS.env_name + '.npz')
  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.load_expert_data(filename)

  # Restrict the demonstrations to FLAGS.num_trajectories trajectories
  # (selection rule lives in data_utils -- confirm there).
  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.subsample_trajectories(expert_states,
                                                     expert_actions,
                                                     expert_next_states,
                                                     expert_dones,
                                                     FLAGS.num_trajectories)
  print('# of demonstraions: {}'.format(expert_states.shape[0]))

  if FLAGS.normalize_states:
    # Standardize each state dimension with the expert mean/std; the 1e-3
    # avoids division by zero for constant dimensions. The same shift/scale
    # is passed to both environments below (presumably so the wrappers
    # normalize observations identically -- confirm in `wrappers`).
    shift = -np.mean(expert_states, 0)
    scale = 1.0 / (np.std(expert_states, 0) + 1e-3)
    expert_states = (expert_states + shift) * scale
    expert_next_states = (expert_next_states + shift) * scale
  else:
    shift = None
    scale = None

  env = wrappers.create_il_env(FLAGS.env_name, FLAGS.seed, shift, scale)

  # Separate evaluation env with a different seed.
  eval_env = wrappers.create_il_env(FLAGS.env_name, FLAGS.seed + 1, shift,
                                    scale)

  unwrap_env = env

  # Walk down the wrapper chain. If a NormalizeBoxActionWrapper is present,
  # map the expert actions through its `reverse_action` (presumably so they
  # live in the same normalized action space the agent acts in).
  while hasattr(unwrap_env, 'env'):
    if isinstance(unwrap_env, wrappers.NormalizeBoxActionWrapper):
      expert_actions = unwrap_env.reverse_action(expert_actions)
      break
    unwrap_env = unwrap_env.env

  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.add_absorbing_states(expert_states,
                                                   expert_actions,
                                                   expert_next_states,
                                                   expert_dones, env)

  # Transition layout shared by both replay buffers:
  # (observation, action, next_observation, reward[1], mask[1]).
  spec = (
      tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                             'observation'),
      tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32, 'action'),
      tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                             'next_observation'),
      tensor_spec.TensorSpec([1], tf.float32, 'reward'),
      tensor_spec.TensorSpec([1], tf.float32, 'mask'),
  )

  # We need to store at most twice more transition due to
  # an extra absorbing to itself transition.
  # `replay_buffer` receives the expert transitions here and, later, every
  # agent transition as well.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_timesteps * 2)

  for i in range(expert_states.shape[0]):
    # Overwrite rewards for safety. We still have to add them to the replay
    # buffer to maintain the same interface. Also always use a zero mask
    # since we need to always bootstrap for imitation learning.
    add_samples_to_replay_buffer(replay_buffer, expert_states[i],
                                 expert_actions[i], expert_next_states[i])

  replay_buffer_iter = iter(
      replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size))

  # `policy_replay_buffer` receives only the agent's own transitions.
  policy_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_timesteps * 2)

  policy_replay_buffer_iter = iter(
      policy_replay_buffer.as_dataset(
          sample_batch_size=FLAGS.sample_batch_size))

  expert_states = tf.Variable(expert_states, dtype=tf.float32)
  expert_actions = tf.Variable(expert_actions, dtype=tf.float32)
  expert_next_states = tf.Variable(expert_next_states, dtype=tf.float32)
  expert_dones = tf.Variable(expert_dones, dtype=tf.float32)

  # Infinite, fully shuffled stream of fixed-size expert batches
  # (state, action, next_state); `expert_dones` is not included.
  expert_dataset = tf.data.Dataset.from_tensor_slices(
      (expert_states, expert_actions, expert_next_states))
  expert_dataset = expert_dataset.repeat().shuffle(
      expert_states.shape[0]).batch(
          FLAGS.sample_batch_size, drop_remainder=True)

  expert_dataset_iter = iter(expert_dataset)

  # Run identifier, e.g. 'algo=<algo>,env_name=<env>,seed=<seed>' (keys
  # sorted); used for both the TensorBoard and the eval-returns paths.
  hparam_str_dict = dict(
      seed=FLAGS.seed, algo=FLAGS.algo, env_name=FLAGS.env_name)
  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  log_dir = os.path.join(FLAGS.save_dir, 'logs')
  log_filename = os.path.join(log_dir, hparam_str)
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)

  # NOTE(review): if FLAGS.algo contains neither 'dac' nor 'value_dice'
  # (e.g. plain 'bc'), `imitator` is never bound. That only works as long as
  # `get_imitation_learning_rewards` is never invoked on that path -- verify
  # twin_sac.SAC does not call `rewards_fn` outside `train`.
  if 'dac' in FLAGS.algo:
    imitator = gail.RatioGANGP(env.observation_space.shape[0],
                               env.action_space.shape[0], FLAGS.log_interval)
  elif 'value_dice' in FLAGS.algo:
    imitator = value_dice.ValueDICE(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        nu_lr=FLAGS.nu_lr,
        actor_lr=FLAGS.actor_lr,
        alpha_init=FLAGS.sac_alpha,
        hidden_size=FLAGS.hidden_size,
        log_interval=FLAGS.log_interval)

  def get_imitation_learning_rewards(states, actions, _):
    """Reward for SAC: the imitator's log occupancy ratio for (s, a)."""
    return imitator.get_log_occupancy_ratio(states, actions)

  # ValueDICE owns its own actor, so it is used directly as the policy;
  # otherwise a twin-critic SAC is trained on imitation rewards.
  if 'value_dice' in FLAGS.algo:
    sac = imitator
  else:
    sac = twin_sac.SAC(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        FLAGS.log_interval,
        actor_lr=FLAGS.actor_lr,
        critic_lr=FLAGS.critic_lr,
        learn_alpha=FLAGS.learn_alpha,
        alpha_init=FLAGS.sac_alpha,
        rewards_fn=get_imitation_learning_rewards)

  episode_return = 0
  episode_timesteps = 0
  # Start with done=True so the first loop iteration resets the env.
  done = True

  total_timesteps = 0
  previous_time = time.time()

  eval_returns = []
  with tqdm(total=FLAGS.max_timesteps, desc='') as pbar:
    while total_timesteps < FLAGS.max_timesteps:
      _update_pbar_msg(pbar, total_timesteps)

      # Periodic evaluation of the current actor on `eval_env`.
      if total_timesteps % FLAGS.eval_interval == 0:
        logging.info('Performing policy eval.')
        average_returns, evaluation_timesteps = evaluate(sac.actor, eval_env)

        eval_returns.append(average_returns)
        np.save(log_filename, np.array(eval_returns))

        with summary_writer.as_default():
          tf.summary.scalar(
              'eval gym/average returns', average_returns, step=total_timesteps)
        with summary_writer.as_default():
          tf.summary.scalar(
              'eval gym/average episode length',
              evaluation_timesteps,
              step=total_timesteps)
        logging.info('Eval: ave returns=%f, ave episode length=%f',
                     average_returns, evaluation_timesteps)

      # Episode boundary: log the finished episode (return and steps/second
      # since it started), then reset the environment and counters.
      if done:
        if episode_timesteps > 0:
          current_time = time.time()
          with summary_writer.as_default():
            tf.summary.scalar(
                'train gym/returns', episode_return, step=total_timesteps)
            tf.summary.scalar(
                'train gym/FPS',
                episode_timesteps / (current_time - previous_time),
                step=total_timesteps)

        obs = env.reset()
        episode_return = 0
        episode_timesteps = 0
        previous_time = time.time()

      # Action selection: uniform random during warm-up; afterwards 'dac'
      # uses the actor's sampled action, all other algorithms use the mean
      # action plus N(0, 0.1) noise clipped to [-1, 1].
      if total_timesteps < FLAGS.num_random_actions:
        action = env.action_space.sample()
      else:
        if 'dac' in FLAGS.algo:
          _, sampled_action, _ = sac.actor(np.array([obs]))
          action = sampled_action[0].numpy()
        else:
          mean_action, _, _ = sac.actor(np.array([obs]))
          action = mean_action[0].numpy()
          action = (action + np.random.normal(
              0, 0.1, size=action.shape)).clip(-1, 1)

      next_obs, reward, done, _ = env.step(action)

      # done caused by episode truncation.
      truncated_done = done and episode_timesteps + 1 == env._max_episode_steps  # pylint: disable=protected-access

      # A genuine terminal step transitions into the absorbing state; a
      # time-limit truncation keeps the real next observation.
      if done and not truncated_done:
        next_obs = env.get_absorbing_state()

      # Overwrite rewards for safety. We still have to add them to the replay
      # buffer to maintain the same interface. Also always use a zero mask
      # since we need to always bootstrap for imitation learning.
      add_samples_to_replay_buffer(replay_buffer, obs, action, next_obs)

      add_samples_to_replay_buffer(policy_replay_buffer, obs, action, next_obs)
      if done and not truncated_done:
        # Add several absorbing-state to absorbing-state transitions (random
        # actions), without exceeding the episode step limit.
        for abs_i in range(FLAGS.absorbing_per_episode):
          if abs_i + episode_timesteps < env._max_episode_steps:  # pylint: disable=protected-access
            obs = env.get_absorbing_state()
            action = env.action_space.sample()
            next_obs = env.get_absorbing_state()

            add_samples_to_replay_buffer(replay_buffer, obs, action, next_obs)
            add_samples_to_replay_buffer(policy_replay_buffer, obs, action,
                                         next_obs)

      episode_return += reward
      episode_timesteps += 1
      total_timesteps += 1
      pbar.update(1)

      obs = next_obs

      # Learning updates once enough environment steps have been collected.
      if total_timesteps >= FLAGS.start_training_timesteps:
        with summary_writer.as_default():
          for _ in range(FLAGS.updates_per_step):
            # Imitator update from expert vs. policy samples (for ValueDICE
            # this presumably also updates the actor, since sac is imitator).
            if 'dac' in FLAGS.algo:
              imitator.update(expert_dataset_iter, policy_replay_buffer_iter)
            elif 'value_dice' in FLAGS.algo:
              imitator.update(
                  expert_dataset_iter,
                  policy_replay_buffer_iter,
                  FLAGS.discount,
                  replay_regularization=FLAGS.replay_regularization)

            # Policy update: behavior cloning on expert batches, or SAC on
            # the mixed replay buffer with imitation rewards for DAC.
            if 'bc' in FLAGS.algo:
              sac.train_bc(expert_dataset_iter)
            elif 'dac' in FLAGS.algo:
              sac.train(
                  replay_buffer_iter,
                  discount=FLAGS.discount,
                  tau=FLAGS.tau,
                  target_entropy=-env.action_space.shape[0],
                  actor_update_freq=FLAGS.actor_update_freq)
def main(argv):
    global BLEU_THRESHOLD_REACHED
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    init_mllogger()
    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'transformer')
    mllogger.event('train_samples', input_pipeline.N_TRAIN)
    mllogger.event('eval_samples', input_pipeline.N_EVAL)

    tf.enable_v2_behavior()

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if FLAGS.hardware_rng:
        models.set_hardware_bernoulli()

    num_partitions = FLAGS.num_partitions
    batch_size = FLAGS.batch_size
    if batch_size is None:
        batch_size = min(16 * jax.device_count() // num_partitions, 2048)
    mllogger.event('global_batch_size', batch_size)

    num_eval_steps = FLAGS.num_eval_steps
    max_target_length = FLAGS.max_target_length
    max_eval_target_length = FLAGS.max_eval_target_length
    max_length = max(max_target_length, max_eval_target_length)
    mllogger.event('max_sequence_length',
                   max_length,
                   metadata={'method': 'discard'})
    if FLAGS.random_seed is not None:
        seed = FLAGS.random_seed
    else:
        seed = np.int32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)
    mllogger.event('seed', int(seed))
    steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size))
    logging.info('steps per epoch: %d', steps_per_epoch)
    num_replicas = jax.local_device_count() // num_partitions
    device_train_input_shape = (batch_size //
                                (num_replicas * jax.host_count()),
                                max_target_length)
    # This is per-host; in principle 64/replica or more should fit
    eval_batch_size = min(
        32 * num_replicas,
        int(
            math.ceil(input_pipeline.N_EVAL /
                      (num_replicas * jax.host_count()))) * num_replicas)
    logging.info('eval batch size: %d', eval_batch_size)
    pred_batches = int(
        math.ceil(input_pipeline.N_EVAL /
                  (jax.host_count() * eval_batch_size)))
    logging.info('pred batches: %d', pred_batches)
    broadcast = functools.partial(_broadcast,
                                  num_replicas=num_replicas,
                                  num_partitions=num_partitions)

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    if FLAGS.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    def maybe_start_xprof(seconds):
        if jax.host_id() == 0 and FLAGS.xprof:
            xprof = xprof_session.XprofSession()
            xprof.start_session('REDACTED', True, 2)

            def sleep_and_end_xprof():
                time.sleep(seconds)
                logging.info(
                    'Xprof URL: %s',
                    xprof.end_session_and_get_url(
                        tag=
                        'flax transformer, {} devices, {}-way, batch {} per replica'
                        .format(jax.device_count(), num_partitions,
                                device_train_input_shape[0])))

            thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

    # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset:
    #   Shared 32K subword tokenization
    #   256-length packed training examples from WMT17
    #   97-length unpacked evaluation examples from WMT14
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path)
    input_encoder = encoder
    target_encoder = encoder
    vocab_size = input_encoder.vocab_size
    output_vocab_size = target_encoder.vocab_size

    input_shape = (batch_size, max_target_length)
    target_shape = (batch_size, max_target_length)

    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': output_vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max_length,
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
        'num_partitions': num_partitions,
    }

    rng = random.PRNGKey(seed)
    rng, init_rng = random.split(rng)
    model, cache_def = create_model(init_rng, tuple(input_shape),
                                    tuple(target_shape), transformer_kwargs)
    mllogger.event('opt_name', 'adam')
    if batch_size < 1024:
        learning_rate = 4.0  # 0.0625
        warmup_steps = 1000
        beta1 = 0.9
        beta2 = 0.98
    if batch_size < 2048:
        learning_rate = 2.0
        warmup_steps = 500  # ??
        beta1 = 0.9  # ??
        beta2 = 0.98  # ??
    else:
        learning_rate = 3.3092157691415953
        warmup_steps = 664
        beta1 = 0.9086575725261137
        beta2 = 0.9198719118104947
    epsilon = 1e-9
    if FLAGS.learning_rate is not None:
        learning_rate = FLAGS.learning_rate
    mllogger.event('opt_adam_beta_1', beta1)
    mllogger.event('opt_adam_beta_2', beta2)
    mllogger.event('opt_adam_epsilon', epsilon)
    optimizer_def = optim.Adam(learning_rate,
                               beta1=beta1,
                               beta2=beta2,
                               eps=epsilon,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(model)
    del model  # don't keep a copy of the initial model

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    partition_rules = [
        (('encoder', 'posembed_input'), partitions.empty_dict),
        (('decoder', 'posembed_targets'), partitions.empty_dict),
        (('embedding', ), partitions.spec(num_partitions, 1)),
        ((r'LayerNorm_\d+', '(bias|scale)'), None),
        ((r'encoder(decoder)?_norm', '(bias|scale)'), None),
        ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'),
         partitions.spec(1, num_partitions, 1)),
        ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'),
         partitions.spec(num_partitions, 1, 1)),
        ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None),
        ((r'MlpBlock_\d+', 'Dense_0', 'kernel'),
         partitions.spec(1, num_partitions)),
        ((r'MlpBlock_\d+', 'Dense_1', 'kernel'),
         partitions.spec(num_partitions, 1)),
        (('state', 'step'), None),
    ]
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(partition_rules, optimizer.state_dict()))

    optimizer = broadcast(optimizer)
    empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0})

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        hidden_size=transformer_kwargs['qkv_dim'])

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0))
    if num_partitions > 1:
        sharded_predict_step = sharded_jit(
            predict_step,
            in_parts=(None, optimizer_partitions.target, None),
            out_parts=None)
    else:
        sharded_predict_step = predict_step
    if FLAGS.extra_eval_metrics:
        p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0))
    p_pred_step = jax.pmap(sharded_predict_step,
                           axis_name='batch',
                           in_axes=(0, None, None))
    p_allreduce_metrics = jax.pmap(functools.partial(lax.psum,
                                                     axis_name='batch'),
                                   axis_name='batch')

    def device_train_loop_cond(args):
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(
                                               device_train_input_shape,
                                               jnp.int32) for _ in train_keys
                                       ]))
        batch = {k: v for k, v in zip(train_keys, input_data)}
        optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                      batch,
                                                      metrics,
                                                      learning_rate_fn,
                                                      dropout_rng=dropout_rngs)
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        token = lax.create_token(step)
        optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, dropout_rngs, metrics, token, step, epoch))
        return optimizer, dropout_rngs, metrics, step

    if num_partitions > 1:
        device_train_loop = sharded_jit(device_train_loop,
                                        in_parts=(optimizer_partitions, None,
                                                  None, None, None),
                                        out_parts=(optimizer_partitions, None,
                                                   None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None))

    p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch')
    if num_partitions > 1:
        p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train,
                                                in_parts=None,
                                                out_parts=None,
                                                num_partitions=num_partitions)
    p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train,
                                         axis_name='batch')

    # Precompile all needed computations with fake data so as not to include
    # compilation time in MLPerf metrics.
    if FLAGS.precompile:
        logging.info('precompiling step/epoch functions')
        if FLAGS.infeed:
            # the device training loop condition will immediately be false, but
            # the optimizer tree will be resharded here
            optimizer, *_ = p_train_epoch(unbroadcast(optimizer),
                                          random.split(rng, num_replicas),
                                          empty_metrics,
                                          jnp.array(0, dtype=jnp.int32), 1)
        else:
            metrics = empty_metrics
            train_input_shape = (num_replicas, batch_size // num_replicas,
                                 input_pipeline.MAX_TRAIN_LEN)
            fake_batch = {
                k: jnp.ones(train_input_shape, jnp.int32)
                for k in train_keys
            }
            p_train_step(unbroadcast(optimizer),
                         fake_batch,
                         metrics,
                         dropout_rng=random.split(rng, num_replicas))
        eval_input_shape = (num_replicas, eval_batch_size // num_replicas,
                            input_pipeline.MAX_EVAL_LEN)
        fake_eval_batch = {
            'inputs': jnp.ones(eval_input_shape, jnp.int32),
            'targets': jnp.ones(eval_input_shape, jnp.int32),
        }
        if FLAGS.extra_eval_metrics:
            p_eval_step(unbroadcast(optimizer.target), fake_eval_batch)
        fake_cache = cache_def.initialize_cache(
            (eval_input_shape[1], FLAGS.max_predict_length))
        maybe_start_xprof(20)
        p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target),
                    fake_cache)
        time.sleep(20)
        sync_devices()
        fake_bleu_1 = np.zeros((4, ), dtype=np.int32)
        fake_bleu_2 = np.zeros((), dtype=np.int32)
        per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2))
        sync_devices()
        p_allreduce_metrics_train(empty_metrics)
        sync_devices()
        logging.info('finished precompiling step/epoch functions')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, num_replicas)

    # Record time-0 metrics for proper tensorboard plot x-axis scaling.
    if jax.host_id() == 0:
        if FLAGS.compute_train_metrics:
            train_summary_writer.scalar('loss', 9.999, 0)
            train_summary_writer.scalar('accuracy', 0.0, 0)
            train_summary_writer.flush()
        eval_summary_writer.scalar('bleu', 0.0, 0)
        eval_summary_writer.flush()

    train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size //
                                              jax.host_count(),
                                              train=True)
    eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size,
                                             train=False)
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)
    local_devices = jax.local_devices()
    maybe_start_xprof(max(30, 60 / (jax.device_count() / 2048)))
    host_step, device_step = 0, broadcast(0)
    gc.disable()
    mllogger.end('init_stop')
    if jax.host_id() == 0:
        mllogger.start('run_start')
    for epoch in range(FLAGS.num_epochs):
        if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED:
            mllogger.start('block_start',
                           metadata={
                               'first_epoch_num': epoch + 1,
                               'epoch_count': 1
                           })
        metrics = empty_metrics
        if FLAGS.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                unbroadcast(optimizer), dropout_rngs, metrics,
                unbroadcast(device_step), epoch)
        while int(host_step // steps_per_epoch) == epoch:
            # pylint: disable=protected-access
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))
            # Shard data to devices and do a training step.
            batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch)
            if FLAGS.infeed:
                for i, device in enumerate(local_devices):
                    replica_id = i // num_partitions
                    input_tuple = tuple(
                        [batch[k][replica_id] for k in train_keys])
                    assert input_tuple[0].shape == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (input_tuple[0].shape, device_train_input_shape))
                    assert input_tuple[0].dtype == jnp.int32, (
                        'infeed dtype error %s != %s' %
                        (input_tuple[0].dtype, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    unbroadcast(optimizer),
                    batch,
                    metrics,
                    dropout_rng=dropout_rngs)
            host_step += 1

        if FLAGS.compute_train_metrics:
            metrics = p_allreduce_metrics_train(metrics)
            # Schedule training metric handling.
            summary_thread.submit(
                functools.partial(write_train_summary, metrics,
                                  train_summary_writer, host_step))

        # Optional, extra evaluation metrics.
        if FLAGS.extra_eval_metrics:
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(num_eval_steps), eval_iter):
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(unbroadcast(optimizer.target),
                                      eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = p_allreduce_metrics(eval_metrics)
            # Schedule metric summarization/logging.
            summary_thread.submit(
                functools.partial(write_eval_summary, eval_metrics,
                                  eval_summary_writer, host_step))

        # Translation and BLEU Score.
        all_predicted, all_targets, all_bs = [], [], []
        for i in range(pred_batches):
            # pylint: disable=protected-access
            pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
            logging.info('Predicting on input of shape %s.',
                         str(pred_batch['inputs'].shape))
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size != eval_batch_size:
                logging.info('Translation: uneven batch size %d.',
                             cur_pred_batch_size)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, eval_batch_size), pred_batch)
            pred_batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]),
                pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache = cache_def.initialize_cache(
                (per_device_batchsize, FLAGS.max_predict_length))
            all_predicted.append(
                p_pred_step(pred_batch['inputs'],
                            unbroadcast(optimizer.target), cache))
            all_targets.append(pred_batch['targets'])
            all_bs.append(cur_pred_batch_size)
        # Schedule BLEU calculation and summarization/logging.
        # We use the ICI as part of BLEU score computation, so we call this from the
        # main thread so the BLEU pmap runs before the next train epoch pmap
        write_predict_summary(all_predicted, all_targets, all_bs,
                              target_encoder, eval_summary_writer, epoch,
                              host_step, summary_thread)

    # Wait until computations are done before exiting
    sync_devices()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        if not BLEU_THRESHOLD_REACHED:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Exemple #14
0
def main(_):
  """Builds the dataset and model from FLAGS and runs `flax_training.train`.

  Results are written to a subdirectory of FLAGS.output_dir keyed by the
  learning rate, weight decay, SAM rho and run seed, so that each gridsearch
  point (and each replicated seed) gets its own directory.

  Raises:
    ValueError: if FLAGS.batch_size is not divisible by the total number of
      devices, or if FLAGS.dataset is not recognized.
  """

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Performance gains on TPU by switching to hardware bernoulli.
  def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
    lax_key = jax.lax.tie_in(rng_key, 0.0)
    return jax.lax.rng_uniform(lax_key, 1.0, shape) < p

  # NOTE: this monkey-patches `jax.random.bernoulli` for the whole process, so
  # every later call to it uses the hardware RNG above.
  jax.random.bernoulli = hardware_bernoulli

  # As we gridsearch the weight decay and the learning rate, we add them to the
  # output directory path so that each model has its own directory to save the
  # results in. We also add the `run_seed` which is "gridsearched" on to
  # replicate an experiment several times.
  output_dir_suffix = os.path.join(
      'lr_' + str(FLAGS.learning_rate),
      'wd_' + str(FLAGS.weight_decay),
      'rho_' + str(FLAGS.sam_rho),
      'seed_' + str(FLAGS.run_seed))

  output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

  if not gfile.exists(output_dir):
    gfile.makedirs(output_dir)

  num_devices = jax.local_device_count() * jax.host_count()
  # Explicit check instead of `assert` so it is not stripped under `python -O`.
  if FLAGS.batch_size % num_devices != 0:
    raise ValueError(
        'Batch size {} is not divisible by the number of devices {}.'.format(
            FLAGS.batch_size, num_devices))
  local_batch_size = FLAGS.batch_size // num_devices
  info = 'Total batch size: {} ({} x {} replicas)'.format(
      FLAGS.batch_size, local_batch_size, num_devices)
  logging.info(info)

  # Every branch sets dataset_source, image_size, num_channels and num_classes,
  # so no branch can reach the model construction with an unbound name.
  if FLAGS.dataset in ('cifar10', 'cifar100'):
    # The dataset source only receives an explicit image size when starting
    # from a pretrained checkpoint; otherwise it gets None and the model is
    # built for 32x32 inputs.
    if FLAGS.from_pretrained_checkpoint:
      pipeline_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    else:
      pipeline_image_size = None
    if FLAGS.dataset == 'cifar10':
      cifar_source_cls = dataset_source_lib.Cifar10
      num_classes = 10
    else:
      cifar_source_cls = dataset_source_lib.Cifar100
      num_classes = 100
    dataset_source = cifar_source_cls(
        FLAGS.batch_size // jax.host_count(),
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations,
        image_size=pipeline_image_size)
    image_size = 32 if pipeline_image_size is None else pipeline_image_size
    num_channels = 3
  elif FLAGS.dataset == 'fashion_mnist':
    # NOTE(review): unlike CIFAR/ImageNet, the global batch size is not divided
    # by jax.host_count() here; confirm this is intended for multi-host runs.
    dataset_source = dataset_source_lib.FashionMnist(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
    image_size = 28  # For Fashion Mnist
    num_channels = 1
    num_classes = 10
  elif FLAGS.dataset == 'svhn':
    # NOTE(review): same per-host batch-size caveat as fashion_mnist above.
    dataset_source = dataset_source_lib.SVHN(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
    image_size = 32
    num_channels = 3
    num_classes = 10
  elif FLAGS.dataset == 'imagenet':
    image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    dataset_source = dataset_source_imagenet.Imagenet(
        FLAGS.batch_size // jax.host_count(), image_size,
        FLAGS.image_level_augmentations)
    num_channels = 3
    num_classes = 1000
  else:
    raise ValueError('Dataset not recognized.')

  # Try the ImageNet model zoo first; fall back to the generic model loader
  # when the name is not one of its models.
  try:
    model, state = load_imagenet_model.get_model(FLAGS.model_name,
                                                 local_batch_size, image_size,
                                                 num_classes)
  except load_imagenet_model.ModelNameError:
    model, state = load_model.get_model(FLAGS.model_name,
                                        local_batch_size, image_size,
                                        num_classes, num_channels)

  # Learning rate will be overwritten by the lr schedule, we set it to zero.
  optimizer = flax_training.create_optimizer(model, 0.0)

  flax_training.train(optimizer, state, dataset_source, output_dir,
                      FLAGS.num_epochs)
Exemple #15
0
def main(argv):
    """Trains and evaluates a rank-1 ResNet-32 (`rank1_resnet_v1`).

    All configuration comes from FLAGS. Uses a MirroredStrategy when
    FLAGS.use_gpu is set and a TPUStrategy otherwise, and resumes from the
    latest checkpoint found in FLAGS.output_dir.

    For every epoch it:
      * runs `steps_per_epoch` training steps;
      * evaluates on the clean test split and, every
        FLAGS.corruptions_interval epochs, on each corrupted test split;
      * writes every metric to TensorBoard and to the work-unit measurement
        series, then resets the metrics for the next epoch;
      * saves a checkpoint every FLAGS.checkpoint_interval epochs.

    The training loss is NLL + L2 on kernels/biases + an annealed KL term
    taken from `model.losses`.

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # Unused arg.

    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)

    # Shrink the per-core input batch by the replication factors the step
    # functions later tile it by.
    # NOTE(review): when not in version2 mode, test_step still tiles by
    # ensemble_size (when > 1), which is not divided out here.
    if FLAGS.version2:
        per_core_bs_train = FLAGS.per_core_batch_size // (
            FLAGS.ensemble_size * FLAGS.num_train_samples)
        per_core_bs_eval = FLAGS.per_core_batch_size // (
            FLAGS.ensemble_size * FLAGS.num_eval_samples)
    else:
        per_core_bs_train = FLAGS.per_core_batch_size // FLAGS.num_train_samples
        per_core_bs_eval = FLAGS.per_core_batch_size // FLAGS.num_eval_samples
    batch_size_train = per_core_bs_train * FLAGS.num_cores
    batch_size_eval = per_core_bs_eval * FLAGS.num_cores

    logging.info('Saving checkpoints at %s', FLAGS.output_dir)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    train_input_fn = utils.load_input_fn(split=tfds.Split.TRAIN,
                                         name=FLAGS.dataset,
                                         batch_size=per_core_bs_train,
                                         use_bfloat16=FLAGS.use_bfloat16,
                                         normalize=False)
    clean_test_input_fn = utils.load_input_fn(split=tfds.Split.TEST,
                                              name=FLAGS.dataset,
                                              batch_size=per_core_bs_eval,
                                              use_bfloat16=FLAGS.use_bfloat16,
                                              normalize=False)
    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)
    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            clean_test_input_fn),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_input_fn = utils.load_cifar10_c_input_fn
        else:
            load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                                path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                input_fn = load_c_input_fn(corruption_name=corruption,
                                           corruption_intensity=intensity,
                                           batch_size=per_core_bs_eval,
                                           use_bfloat16=FLAGS.use_bfloat16,
                                           normalize=False)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_datasets_from_function(
                        input_fn))

    ds_info = tfds.builder(FLAGS.dataset).info
    train_dataset_size = ds_info.splits['train'].num_examples
    test_dataset_size = ds_info.splits['test'].num_examples
    num_classes = ds_info.features['label'].num_classes

    steps_per_epoch = train_dataset_size // batch_size_train
    steps_per_eval = test_dataset_size // batch_size_eval

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras ResNet-32 model')
        model = resnet_cifar_model.rank1_resnet_v1(
            input_shape=ds_info.features['image'].shape,
            depth=32,
            num_classes=num_classes,
            width_multiplier=4,
            alpha_initializer=FLAGS.alpha_initializer,
            gamma_initializer=FLAGS.gamma_initializer,
            alpha_regularizer=FLAGS.alpha_regularizer,
            gamma_regularizer=FLAGS.gamma_regularizer,
            use_additive_perturbation=FLAGS.use_additive_perturbation,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            dropout_rate=FLAGS.dropout_rate)
        # `summary()` prints and returns None; route its lines to the logger
        # instead of logging the None return value.
        model.summary(print_fn=logging.info)
        # Linear learning-rate scaling relative to a reference batch of 128.
        base_lr = FLAGS.base_learning_rate * batch_size_train / 128
        # Decay epochs are specified for a 200-epoch schedule; rescale them to
        # the configured number of training epochs.
        lr_decay_epochs = [(start_epoch * FLAGS.train_epochs) // 200
                           for start_epoch in FLAGS.lr_decay_epochs]
        lr_schedule = utils.LearningRateSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/loss':
            tf.keras.metrics.Mean(),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        ed.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        test_diversity = {}
        training_diversity = {}
        if FLAGS.ensemble_size > 1:
            for i in range(FLAGS.ensemble_size):
                metrics['test/nll_member_{}'.format(
                    i)] = tf.keras.metrics.Mean()
                metrics['test/accuracy_member_{}'.format(i)] = (
                    tf.keras.metrics.SparseCategoricalAccuracy())
            test_diversity = {
                'test/disagreement': tf.keras.metrics.Mean(),
                'test/average_kl': tf.keras.metrics.Mean(),
                'test/cosine_similarity': tf.keras.metrics.Mean(),
            }
            training_diversity = {
                'train/disagreement': tf.keras.metrics.Mean(),
                'train/average_kl': tf.keras.metrics.Mean(),
                'train/cosine_similarity': tf.keras.metrics.Mean(),
            }

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Runs one distributed training step on the next batch of `iterator`."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            if FLAGS.version2 and FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
                # Labels are only tiled when every member's predictions are
                # scored separately; member sampling / expected probs reduce
                # the ensemble axis before the loss.
                if not (FLAGS.member_sampling or FLAGS.expected_probs):
                    labels = tf.tile(labels, [FLAGS.ensemble_size])

            if FLAGS.num_train_samples > 1:
                images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                probs = tf.nn.softmax(logits)
                # Diversity evaluation.
                if FLAGS.version2 and FLAGS.ensemble_size > 1:
                    per_probs = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))

                    diversity_results = ed.metrics.average_pairwise_diversity(
                        per_probs, FLAGS.ensemble_size)

                # Average the predictive distribution over the MC samples.
                if FLAGS.num_train_samples > 1:
                    probs = tf.reshape(
                        probs,
                        tf.concat(
                            [[FLAGS.num_train_samples, -1], probs.shape[1:]],
                            0))
                    probs = tf.reduce_mean(probs, 0)

                if FLAGS.member_sampling and FLAGS.version2 and FLAGS.ensemble_size > 1:
                    # Keep the predictions of one uniformly sampled member by
                    # multiplying with a one-hot selector over the ensemble axis.
                    idx = tf.random.uniform([],
                                            maxval=FLAGS.ensemble_size,
                                            dtype=tf.int64)
                    idx_one_hot = tf.expand_dims(
                        tf.one_hot(idx, FLAGS.ensemble_size,
                                   dtype=probs.dtype), 0)
                    probs_shape = probs.shape
                    probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
                    probs = tf.matmul(idx_one_hot, probs)
                    probs = tf.reshape(probs,
                                       tf.concat([[-1], probs_shape[1:]], 0))

                elif FLAGS.expected_probs and FLAGS.version2 and FLAGS.ensemble_size > 1:
                    # Average the predictions over ensemble members.
                    probs = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))
                    probs = tf.reduce_mean(probs, 0)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, probs))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 on the slow weights and bias terms. This excludes BN
                    # parameters and fast weight approximate posterior/prior parameters,
                    # but pay caution to their naming scheme.
                    if 'kernel' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                # KL term from the model's regularization losses, normalized by
                # the dataset size and linearly annealed from ~0 to 1 over
                # FLAGS.kl_annealing_steps optimizer steps.
                kl = sum(model.losses) / train_dataset_size
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= FLAGS.kl_annealing_steps
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            grad_list = []
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = list(zip(grads, model.trainable_variables))
                for vec, var in grads_and_vars:
                    # Apply different learning rate on the fast weight approximate
                    # posterior/prior parameters. This is excludes BN and slow weights,
                    # but pay caution to the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grad_list.append(
                            (vec * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grad_list.append((vec, var))
                optimizer.apply_gradients(grad_list)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            if FLAGS.version2 and FLAGS.ensemble_size > 1:
                for k, v in diversity_results.items():
                    training_diversity['train/' + k].update_state(v)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Runs one distributed eval step, updating the metrics of `dataset_name`."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            if FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            if FLAGS.num_eval_samples > 1:
                images = tf.tile(images, [FLAGS.num_eval_samples, 1, 1, 1])
            logits = model(images, training=False)
            probs = tf.nn.softmax(logits)

            # Average the predictive distribution over the MC samples.
            if FLAGS.num_eval_samples > 1:
                probs = tf.reshape(
                    probs,
                    tf.concat([[FLAGS.num_eval_samples, -1], probs.shape[1:]],
                              0))
                probs = tf.reduce_mean(probs, 0)

            if FLAGS.ensemble_size > 1:
                per_probs = tf.split(probs,
                                     num_or_size_splits=FLAGS.ensemble_size,
                                     axis=0)
                # Per-member and diversity metrics are only tracked on the
                # clean split.
                if dataset_name == 'clean':
                    per_probs_tensor = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))
                    diversity_results = ed.metrics.average_pairwise_diversity(
                        per_probs_tensor, FLAGS.ensemble_size)

                    for k, v in diversity_results.items():
                        test_diversity['test/' + k].update_state(v)

                    for i in range(FLAGS.ensemble_size):
                        member_probs = per_probs[i]
                        member_nll = tf.keras.losses.sparse_categorical_crossentropy(
                            labels, member_probs)
                        metrics['test/nll_member_{}'.format(i)].update_state(
                            member_nll)
                        metrics['test/accuracy_member_{}'.format(
                            i)].update_state(labels, member_probs)

                # Ensemble prediction: mean over members.
                probs = tf.reduce_mean(per_probs, axis=0)

            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
            filtered_variables = []
            for var in model.trainable_variables:
                if 'kernel' in var.name or 'bias' in var.name:
                    filtered_variables.append(tf.reshape(var, (-1, )))

            kl = sum(model.losses) / test_dataset_size
            l2_loss = kl + FLAGS.l2 * 2 * tf.nn.l2_loss(
                tf.concat(filtered_variables, axis=0))
            loss = negative_log_likelihood + l2_loss
            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
                metrics['test/loss'].update_state(loss)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    # Steps completed before this process started (non-zero when resuming
    # from a checkpoint). They are excluded from the throughput estimate,
    # since `start_time` only covers this run.
    initial_step = initial_epoch * steps_per_epoch
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step - initial_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            work_unit.set_notes(message)
            if step % 20 == 0:
                logging.info(message)

        evaluate_corruptions = (FLAGS.corruptions_interval > 0 and
                                (epoch + 1) % FLAGS.corruptions_interval == 0)
        if evaluate_corruptions:
            datasets_to_evaluate = test_datasets
        else:
            datasets_to_evaluate = {'clean': test_datasets['clean']}
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if evaluate_corruptions:
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        # Per-member metrics only exist when ensemble_size > 1; without this
        # guard an ensemble_size of 1 raises KeyError here.
        if FLAGS.ensemble_size > 1:
            for i in range(FLAGS.ensemble_size):
                logging.info(
                    'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                    metrics['test/nll_member_{}'.format(i)].result(),
                    metrics['test/accuracy_member_{}'.format(i)].result() * 100)
        # Materialized as a list: it is iterated twice (to read results and
        # again to reset). A bare itertools.chain would be exhausted by the
        # first pass, silently skipping the reset.
        total_metrics = list(
            itertools.chain(metrics.items(), training_diversity.items(),
                            test_diversity.items()))
        total_results = {
            name: metric.result()
            for name, metric in total_metrics
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for name, result in total_results.items():
            name = name.replace('/', '_')
            if 'negative_log_likelihood' in name:
                # Plots sort WIDs from high-to-low so look at maximization objectives.
                name = name.replace('negative_log_likelihood',
                                    'log_likelihood')
                result = -result
            objective = work_unit.get_measurement_series(name)
            objective.create_measurement(result, epoch + 1)

        # Reset so each epoch's summaries reflect only that epoch's data.
        for _, metric in total_metrics:
            metric.reset_states()
        if FLAGS.corruptions_interval > 0:
            # Results were aggregated above; clear them so the next corruption
            # evaluation does not mix in this one.
            for metric in corrupt_metrics.values():
                metric.reset_states()
        summary_writer.flush()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
def main(_):
    """Trains a caption-answering classifier on transition observations.

    Loads real and synthetic transitions, their per-observation caption lists
    and a vocabulary. It encodes and pads every caption, then trains an
    encoder / decoder / sigmoid projection that scores whether a caption
    matches an observation, using binary cross-entropy on matching and
    non-matching pairs.

    Training and evaluation batches are both drawn from the synthetic data.
    When `FLAGS.save_dir` is set, the latest checkpoint there is restored on
    start, and a checkpoint is written whenever test accuracy improves.

    Args:
      _: Unused (the argv forwarded by the application runner).
    """
    tf.enable_v2_behavior()
    ##############################################################################
    ######################### Data loading and processing ########################
    ##############################################################################
    print('Loading data')

    # np.load and pickle.load need binary file objects under Python 3.
    with gfile.GFile(transition_path, 'rb') as f:
        transitions = np.load(f)
    # Rescale data stored in [0, 255] into [0, 1].
    if np.max(transitions) > 1.0:
        transitions = transitions / 255.0
    with gfile.GFile(synthetic_transition_path, 'rb') as f:
        synthetic_transitions = np.load(f)
    if np.max(synthetic_transitions) > 1.0:
        synthetic_transitions = synthetic_transitions / 255.0

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted
    # label files.
    with gfile.GFile(transition_label_path, 'rb') as f:
        captions = pickle.load(f)
    with gfile.GFile(synthetic_transition_label_path, 'rb') as f:
        synthetic_captions = pickle.load(f)

    # Read as bytes so that each line can be decoded explicitly as UTF-8.
    with gfile.GFile(vocab_path, 'rb') as f:
        vocab_list = f.readlines()

    # rstrip (instead of slicing off the last character) keeps the final word
    # intact when the file does not end with a newline.
    vocab_list = [w.decode('utf-8').rstrip('\n') for w in vocab_list]
    vocab_list = ['eos', 'sos'] + vocab_list

    v2i, _ = wv.create_look_up_table(vocab_list)
    encode_fn = wv.encode_text_with_lookup_table(v2i)

    # Padded length of every encoded caption. The decoder is unrolled for
    # exactly this many steps.
    max_sequence_length = 15

    def encode_captions(caption_groups):
        """Encodes every caption of every group as 'sos <caption> eos'."""
        return [
            np.array(encode_fn('sos ' + cp + ' eos'))
            for all_cp in caption_groups
            for cp in all_cp
        ]

    encoded_captions = encode_captions(captions)
    synthetic_encoded_captions = encode_captions(synthetic_captions)

    all_caption_n = len(encoded_captions)
    all_synthetic_caption_n = len(synthetic_encoded_captions)

    encoded_captions = np.array(encoded_captions)
    encoded_captions = pad_to_max_length(encoded_captions,
                                         max_l=max_sequence_length)

    synthetic_encoded_captions = np.array(synthetic_encoded_captions)
    synthetic_encoded_captions = pad_to_max_length(synthetic_encoded_captions,
                                                   max_l=max_sequence_length)

    def build_caption_index(transitions_s, captions_s, expected_n):
        """Maps captions to observations and splits them on 'nothing'.

        Captions are numbered in the same order `encode_captions` emits them.

        Returns:
          A tuple (obs_idx, caption_idx, negative_caption_idx) of numpy arrays.
          `obs_idx[k]` is the observation index of caption k. `caption_idx`
          lists the caption numbers that do not contain 'nothing', and
          `negative_caption_idx` lists those that do.
        """
        obs_idx_s, caption_idx_s, negative_idx_s = [], [], []
        curr_caption_idx = 0
        for i, _ in enumerate(transitions_s):
            for cp in captions_s[i]:
                obs_idx_s.append(i)
                if 'nothing' not in cp:
                    caption_idx_s.append(curr_caption_idx)
                else:
                    negative_idx_s.append(curr_caption_idx)
                curr_caption_idx += 1
        assert curr_caption_idx == expected_n
        return (np.array(obs_idx_s), np.array(caption_idx_s),
                np.array(negative_idx_s))

    # The negative ('nothing') caption indices are computed for reference only;
    # sample_batch does not use them.
    obs_idx, caption_idx, negative_caption_idx = build_caption_index(
        transitions, captions, all_caption_n)
    (synthetic_obs_idx, synthetic_caption_idx,
     synthetic_negative_caption_idx) = build_caption_index(
         synthetic_transitions, synthetic_captions, all_synthetic_caption_n)

    # Split positive-caption positions 80/20 into train/test, in order (no
    # shuffling).
    all_idx = np.arange(len(caption_idx))
    train_idx = all_idx[:int(len(all_idx) * 0.8)]
    test_idx = all_idx[int(len(all_idx) * 0.8):]
    print('Number of training examples: {}'.format(len(train_idx)))
    print('Number of test examples: {}\n'.format(len(test_idx)))

    synthetic_all_idx = np.arange(len(synthetic_caption_idx))
    synthetic_train_idx = synthetic_all_idx[:int(len(synthetic_all_idx) * 0.8)]
    synthetic_test_idx = synthetic_all_idx[int(len(synthetic_all_idx) * 0.8):]
    print('Number of synthetic training examples: {}'.format(
        len(synthetic_train_idx)))
    print('Number of synthetic test examples: {}\n'.format(
        len(synthetic_test_idx)))

    def sample_batch(data_type, batch_size, mode='train'):
        """Samples `batch_size` matching and `batch_size` non-matching pairs.

        Args:
          data_type: 'synthetic' selects the synthetic data. Any other value
            selects the real data.
          batch_size: Number of positive (and of negative) pairs.
          mode: 'train' samples positives from the train split. Any other value
            samples them from the test split.

        Returns:
          (observations, captions, targets). The first `batch_size` rows pair
          each observation with its own caption (target 1.0). The last
          `batch_size` rows pair the same observations with captions drawn at
          random from the train split (target 0.0). Such random captions are
          not filtered, so one may occasionally match its observation.
        """
        is_synthetic = data_type == 'synthetic'
        transitions_s = synthetic_transitions if is_synthetic else transitions
        encoded_captions_s = (synthetic_encoded_captions
                              if is_synthetic else encoded_captions)
        obs_idx_s = synthetic_obs_idx if is_synthetic else obs_idx
        caption_idx_s = synthetic_caption_idx if is_synthetic else caption_idx
        train_idx_s = synthetic_train_idx if is_synthetic else train_idx
        test_idx_s = synthetic_test_idx if is_synthetic else test_idx
        if mode == 'train':
            batch_idx_s = np.random.choice(train_idx_s, size=batch_size)
        else:
            batch_idx_s = np.random.choice(test_idx_s, size=batch_size)
        positive_idx = caption_idx_s[batch_idx_s]
        # batch_idx_s are positions in caption_idx_s (positive captions only),
        # not caption numbers. Look the observation up through the caption
        # number so that it matches its positive caption.
        batch_obs = transitions_s[obs_idx_s[positive_idx], 1, :]
        input_tensor = tf.convert_to_tensor(
            np.concatenate([batch_obs, batch_obs]))
        negative_idx = caption_idx_s[np.random.choice(train_idx_s,
                                                      size=batch_size)]
        caption_tensor = tf.convert_to_tensor(
            np.concatenate([
                encoded_captions_s[positive_idx],
                encoded_captions_s[negative_idx]
            ],
                           axis=0))
        target_tensor = tf.convert_to_tensor(
            np.float32(
                np.concatenate([np.ones(batch_size),
                                np.zeros(batch_size)],
                               axis=0)))
        return input_tensor, caption_tensor, target_tensor

    ##############################################################################
    ############################# Training Setup #################################
    ##############################################################################
    batch_size = 64

    encoder_config = {'name': 'image', 'embedding_dim': 64}
    decoder_config = {
        'name': 'attention',
        'word_embedding_dim': 64,
        'hidden_units': 256,
        'vocab_size': len(vocab_list),
    }

    encoder = get_answering_encoder(encoder_config)
    decoder = get_answering_decoder(decoder_config)
    projection_layer = tf.keras.layers.Dense(1,
                                             activation='sigmoid',
                                             name='answering_projection')

    optimizer = tf.keras.optimizers.Adam(1e-4)
    bce = tf.keras.losses.BinaryCrossentropy()

    @tf.function
    def compute_loss(obs, instruction, target, training):
        """Runs the model over the whole caption and returns (loss, scores)."""
        # Python print runs only while tracing, so this logs (re)tracing.
        print('Build compute loss...')
        instruction = tf.expand_dims(instruction, axis=-1)
        hidden = decoder.reset_state(batch_size=target.shape[0])
        features = encoder(obs, training=training)
        # Feed one token per step. The final hidden state summarizes the caption.
        for i in tf.range(max_sequence_length):
            _, hidden, _ = decoder(instruction[:, i],
                                   features,
                                   hidden,
                                   training=training)
        projection = tf.squeeze(projection_layer(hidden), axis=1)
        loss = bce(target, projection)
        return loss, projection

    @tf.function
    def train_step(obs, instruction, target):
        """Applies one Adam update to all model variables; returns the loss."""
        print('Build train step...')
        with tf.GradientTape() as tape:
            loss, _ = compute_loss(obs, instruction, target, True)
        trainable_variables = (encoder.trainable_variables +
                               decoder.trainable_variables +
                               projection_layer.trainable_variables)
        print('num trainable: ', len(trainable_variables))
        gradients = tape.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(gradients, trainable_variables))
        return loss

    def evaluate(num_batches):
        """Returns (mean loss, mean accuracy) over synthetic test batches."""
        total_loss, total_accuracy = 0, 0
        for _ in range(num_batches):
            input_tensor, instruction, target = sample_batch(
                'synthetic', batch_size, 'test')
            t_loss, prediction = compute_loss(input_tensor, instruction,
                                              target, False)
            total_loss += t_loss
            total_accuracy += np.mean(
                np.float32(np.float32(prediction > 0.5) == target))
        return total_loss / num_batches, total_accuracy / num_batches

    ##############################################################################
    ############################# Training Loop ##################################
    ##############################################################################
    print('Start training...\n')
    start_epoch = 0
    if FLAGS.save_dir:
        checkpoint_path = FLAGS.save_dir
        ckpt = tf.train.Checkpoint(encoder=encoder,
                                   decoder=decoder,
                                   projection_layer=projection_layer,
                                   optimizer=optimizer)
        ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                  checkpoint_path,
                                                  max_to_keep=5)
        if ckpt_manager.latest_checkpoint:
            # Restore weights and optimizer state as well as the epoch
            # counter. Without this, a "resumed" run would silently restart
            # from freshly initialized weights.
            ckpt.restore(ckpt_manager.latest_checkpoint)
            start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])

    epochs = 400
    # NOTE(review): steps are sized by the real caption count, although the
    # batches are drawn from the synthetic data. Confirm this is intended.
    step_per_epoch = int(all_caption_n / batch_size)

    previous_best, previous_best_accuracy = 100., 0.0
    for epoch in range(start_epoch, epochs):
        start = time.time()
        total_loss = 0
        for batch in range(step_per_epoch):
            input_tensor, instruction, target = sample_batch(
                'synthetic', batch_size, 'train')
            batch_loss = train_step(input_tensor, instruction, target)
            total_loss += batch_loss

            if batch % 1000 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    epoch, batch, batch_loss.numpy()))

        if epoch % 5 == 0 and FLAGS.save_dir:
            test_total_loss, accuracy = evaluate(10)
            if accuracy > previous_best_accuracy:
                previous_best_accuracy, previous_best = accuracy, test_total_loss
                ckpt_manager.save(checkpoint_number=epoch)

        print('\nEpoch {} | Loss {:.6f} | Val loss {:.6f} | Accuracy {:.3f}'.
              format(epoch + 1, total_loss / step_per_epoch, previous_best,
                     previous_best_accuracy))
        print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() -
                                                           start))

        if epoch % 10 == 0:
            # At least one batch, so a test split smaller than batch_size
            # cannot cause a division by zero.
            num_test_batches = max(1, len(test_idx) // batch_size)
            test_total_loss, accuracy = evaluate(num_test_batches)
            if accuracy > previous_best_accuracy and FLAGS.save_dir:
                previous_best_accuracy, previous_best = accuracy, test_total_loss
                ckpt_manager.save(checkpoint_number=epoch)
            print('\n====================================================')
            print('Test Loss {:.6f} | Test Accuracy {:.3f}'.format(
                test_total_loss, accuracy))
            print('====================================================\n')
Exemple #17
0
def main(unused_argv):
  """Trains and evaluates a Keras ResNet-50 on ImageNet with TPUStrategy.

  Connects to the TPU given by --tpu (or a local one) and compiles ResNet-50
  under the TPU distribution strategy, using Nesterov-momentum SGD and sparse
  categorical cross-entropy. It trains with a learning-rate scheduler callback
  and TensorBoard logging, runs validation every 5 epochs, and finally saves
  the model via `model_saving_utils.save_model`.

  Args:
    unused_argv: Unused command-line arguments.

  Raises:
    ValueError: if --data was not provided.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if FLAGS.data is None:
    raise ValueError('Provide training data path via --data.')
  tf.enable_v2_behavior()

  batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

  # A non-zero --steps_per_epoch overrides the value derived from dataset size.
  training_steps_per_epoch = FLAGS.steps_per_epoch or (
      int(APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
  # Round up so that the final partial validation batch is still evaluated.
  validation_steps = int(
      math.ceil(1.0 * IMAGENET_VALIDATION_IMAGES / batch_size))

  model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
  logging.info('Saving tensorboard summaries at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
  tf.config.experimental_connect_to_host(resolver.master())  # pylint: disable=line-too-long
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.experimental.TPUStrategy(resolver)

  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', batch_size)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

  with tf.device('/job:worker'):
    # Model and optimizer variables must be created inside the strategy scope.
    with strategy.scope():
      logging.info('Building Keras ResNet-50 model')
      model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

      logging.info('Compiling model.')
      metrics = ['sparse_categorical_accuracy']

      if FLAGS.eval_top_5_accuracy:
        metrics.append(sparse_top_k_categorical_accuracy)

      model.compile(
          optimizer=tf.keras.optimizers.SGD(
              learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
          loss='sparse_categorical_crossentropy',
          metrics=metrics)

    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True, data_dir=FLAGS.data, batch_size=batch_size,
        use_bfloat16=USE_BFLOAT16)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False, data_dir=FLAGS.data, batch_size=batch_size,
        use_bfloat16=USE_BFLOAT16)

    lr_schedule_cb = LearningRateBatchScheduler(
        schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=model_dir)

    training_callbacks = [lr_schedule_cb, tensorboard_cb]

    model.fit(
        imagenet_train.input_fn(),
        epochs=FLAGS.num_epochs,
        steps_per_epoch=training_steps_per_epoch,
        callbacks=training_callbacks,
        validation_data=imagenet_eval.input_fn(),
        validation_steps=validation_steps,
        validation_freq=5)

    model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
Exemple #18
0
def main(argv):
    """Builds the configuration from flags and launches `train`.

    Selects the ResNet variant from --arch and parses the optional MoCo and
    classifier (clf) learning-rate step schedules. It then forwards them,
    together with the remaining hyperparameter flags, to `train`.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
      ValueError: if --arch is not one of the supported ResNet variants.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    emb_size = FLAGS.emb_size

    if FLAGS.arch == 'resnet50':
        model_def = model_resnet.ResNet50.partial(num_outputs=emb_size)
    elif FLAGS.arch == 'resnet101':
        model_def = model_resnet.ResNet101.partial(num_outputs=emb_size)
    elif FLAGS.arch == 'resnet152':
        model_def = model_resnet.ResNet152.partial(num_outputs=emb_size)
    else:
        raise ValueError(
            'Unsupported --arch {!r}; expected one of resnet50, resnet101, '
            'resnet152.'.format(FLAGS.arch))
    # All supported variants use the same feature size.
    feature_size = 64 * 8 * 4

    # Schedules are given as Python literals on the command line.
    # NOTE(review): entries look like [epoch, lr multiplier] pairs; confirm
    # against lr_schedule.create_stepped_learning_rate_schedule.
    if FLAGS.lr_moco_sched_steps:
        lr_moco_sched_steps = ast.literal_eval(FLAGS.lr_moco_sched_steps)
    else:
        lr_moco_sched_steps = [[120, 0.1], [160, 0.01]]

    if FLAGS.lr_clf_sched_steps:
        lr_clf_sched_steps = ast.literal_eval(FLAGS.lr_clf_sched_steps)
    else:
        lr_clf_sched_steps = [[60, 0.2], [75, 0.04], [90, 0.008]]

    def make_moco_lr_fun(base_lr, steps_per_epoch):
        """Returns the stepped LR schedule for the MoCo phase."""
        return lr_schedule.create_stepped_learning_rate_schedule(
            base_lr,
            steps_per_epoch,
            lr_moco_sched_steps,
            warmup_length=FLAGS.lr_moco_sched_warmup)

    def make_clf_lr_fun(base_lr, steps_per_epoch):
        """Returns the stepped LR schedule for the classifier phase."""
        return lr_schedule.create_stepped_learning_rate_schedule(
            base_lr,
            steps_per_epoch,
            lr_clf_sched_steps,
            warmup_length=FLAGS.lr_clf_sched_warmup)

    train(model_def,
          model_dir=FLAGS.model_dir,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          num_moco_epochs=FLAGS.num_moco_epochs,
          num_clf_epochs=FLAGS.num_clf_epochs,
          moco_learning_rate=FLAGS.moco_learning_rate,
          clf_learning_rate=FLAGS.clf_learning_rate,
          sgd_momentum=FLAGS.sgd_momentum,
          sgd_nesterov=FLAGS.sgd_nesterov,
          make_moco_lr_fun=make_moco_lr_fun,
          make_clf_lr_fun=make_clf_lr_fun,
          moco_l2_reg=FLAGS.moco_l2_reg,
          clf_l2_reg=FLAGS.clf_l2_reg,
          feature_size=feature_size,
          moco_momentum=FLAGS.moco_momentum,
          emb_size=emb_size,
          moco_temperature=FLAGS.moco_temperature,
          dictionary_size=FLAGS.dictionary_size,
          run_seed=FLAGS.rng)
def main(argv):
    """Runs the MLPerf-style ResNet training benchmark with JAX.

    Emits MLPerf log events, then derives defaults for any unset flags from the
    global batch size: batch size, epochs, momentum, weight decay,
    space-to-depth and image layout. It builds either a LARS optimizer
    (FLAGS.lars) or Nesterov momentum SGD; bias and scale parameters always use
    weight_decay=0.

    Training then alternates between loops of FLAGS.epochs_per_loop epochs and
    a full evaluation pass. Batches reach every local device through the device
    infeed and are pushed from a background thread pool. They come from
    `input_pipeline`, or are zeros when FLAGS.fake_data is set. Host 0 writes
    train/eval summaries.

    Args:
      argv: Unused command-line arguments.
    """
    del argv
    # BEGIN GOOGLE-INTERNAL
    xm.setup_work_unit()
    # END GOOGLE-INTERNAL

    tf.enable_v2_behavior()
    init_mllogger()

    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'resnet')
    mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
    mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)

    # summary_writer / summary_thread exist only on host 0; every later use is
    # guarded by the same jax.host_id() == 0 check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
        # Write summaries in background thread to avoid blocking on device sync
        summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

    if FLAGS.seed is not None:
        seed = FLAGS.seed
    else:
        # Host 0 draws a time-based seed and other hosts contribute 0.
        # NOTE(review): per_host_sum_pmap presumably sums over hosts so that
        # every host ends up with host 0's seed; confirm.
        seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)

    mllogger.event('seed', int(seed))
    key = random.PRNGKey(seed)

    # A batch size of -1 selects a default based on the total device count.
    batch_size = FLAGS.batch_size
    if batch_size == -1:
        if jax.device_count() > 4096:
            batch_size = 65536
        else:
            batch_size = min(128 * jax.device_count(), 32768)
    mllogger.event('global_batch_size', batch_size)
    eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
    device_batch_size = batch_size // jax.device_count()
    device_eval_batch_size = int(
        math.ceil(eval_batch_size / jax.device_count()))

    model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
    input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32

    num_epochs = FLAGS.num_epochs
    if num_epochs is None:
        if batch_size < 32768:
            num_epochs = 56
        elif batch_size < 65536:
            num_epochs = 64
        else:
            num_epochs = 92

    steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
    # match TF submission behavior (round steps per loop up)
    steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
    # also apply rounding loop up to next step to "epochs" in LR schedule
    steps_per_epoch *= steps_per_loop / (steps_per_epoch *
                                         FLAGS.epochs_per_loop)

    steps_per_eval = int(
        math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))

    # Linear LR scaling relative to a reference batch size of 256.
    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    beta = FLAGS.momentum
    if beta is None:
        if batch_size < 32768:
            beta = 0.9
        elif batch_size < 65536:
            beta = 0.929
        else:
            beta = 0.9537213777059405
    weight_decay = FLAGS.weight_decay
    if weight_decay is None:
        weight_decay = 2e-4 if batch_size < 32768 else 1e-4

    # By default, enable space-to-depth only for small per-device batches.
    space_to_depth = FLAGS.space_to_depth
    if space_to_depth is None:
        space_to_depth = device_batch_size <= 8

    image_format = FLAGS.image_format
    if image_format is None:
        if space_to_depth and device_batch_size <= 8:
            image_format = 'HWNC'
        else:
            image_format = 'HWCN'

    # Per-device input shapes start as NHWC. Space-to-depth halves H and W and
    # gives 4x the channels (12 = 3 * 4). They are then permuted into the
    # chosen layout.
    image_size = input_pipeline.IMAGE_SIZE
    if space_to_depth:
        train_input_shape = (device_batch_size, image_size // 2,
                             image_size // 2, 12)
        eval_input_shape = (device_eval_batch_size, image_size // 2,
                            image_size // 2, 12)
    else:
        train_input_shape = (device_batch_size, image_size, image_size, 3)
        eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
    if image_format == 'HWCN':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
    elif image_format == 'HWNC':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])

    model, state = create_model(key, device_batch_size, image_size,
                                model_dtype, space_to_depth)

    if FLAGS.lars:
        mllogger.event('opt_name', 'lars')
        mllogger.event('lars_opt_weight_decay', weight_decay)
        mllogger.event('lars_opt_momentum', beta)
        mllogger.event('lars_epsilon', 0)
        weight_opt_def = optim.LARS(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=False)
        learning_rate_fn = polynomial_learning_rate_fn(batch_size,
                                                       steps_per_epoch,
                                                       num_epochs)
    else:
        mllogger.event('opt_name', 'sgd')
        mllogger.event('sgd_opt_momentum', beta)
        weight_opt_def = optim.Momentum(base_learning_rate,
                                        beta,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=True)
        learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                      steps_per_epoch,
                                                      num_epochs)

    # Parameters whose path contains 'bias' or 'scale' go to other_opt_def
    # (weight_decay=0); all others go to weight_opt_def.
    def filter_weights(key, _):
        return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
        return 'bias' in key or 'scale' in key

    weight_traversal = optim.ModelParamTraversal(filter_weights)
    other_traversal = optim.ModelParamTraversal(filter_other)
    optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    optimizer = optimizer_def.create(model)
    del model  # do not keep a copy of the initial model

    # NOTE(review): broadcast presumably replicates values across local
    # devices, and unbroadcast takes a single replica back; confirm.
    optimizer = broadcast(optimizer)
    state = broadcast(state)
    empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})

    p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')

    p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

    # One training step whose (images, labels) batch is read from the infeed.
    def host_loop_train_step(optimizer, state, metrics):
        token = lax.create_token(optimizer.state[0].step)
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        return optimizer, state, metrics

    # The optimizer is passed unmapped; state and metrics are mapped per device.
    p_host_loop_train_step = jax.pmap(host_loop_train_step,
                                      axis_name='batch',
                                      in_axes=(None, 0, 0))

    # One eval step whose batch is read from the infeed; accumulates metrics.
    def host_loop_eval_step(model, state, metrics):
        token = lax.create_token(metrics['samples'])
        batch, token = lax.infeed(
            token,
            shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                   jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
        metrics = eval_step(model, state, batch, metrics, image_format,
                            space_to_depth)
        return metrics

    p_host_loop_eval_step = jax.pmap(host_loop_eval_step,
                                     axis_name='batch',
                                     in_axes=(None, None, 0))

    # Keep stepping while `step` is still inside training loop number `loop`.
    def device_train_loop_cond(args):
        _, _, _, _, step, loop = args
        return step // steps_per_loop == loop

    def device_train_loop_body(args):
        optimizer, state, metrics, token, step, loop = args
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        step += 1
        return optimizer, state, metrics, token, step, loop

    # Runs a whole training loop on device with lax.while_loop, then syncs
    # batch-norm statistics and all-reduces metrics before returning.
    def device_train_loop(optimizer, state, metrics, step, loop):
        token = lax.create_token(step)
        optimizer, state, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, state, metrics, token, step, loop))
        state = sync_batchnorm_stats(state)
        metrics = allreduce_metrics(metrics)
        return optimizer, state, metrics, step

    # Only metrics (the third argument) is mapped per device.
    p_train_loop = jax.pmap(device_train_loop,
                            axis_name='batch',
                            in_axes=(None, None, 0, None, None))

    # BEGIN GOOGLE-INTERNAL
    def maybe_start_xprof(seconds):
        if jax.host_id() == 0 and FLAGS.xprof:
            xprof = xprof_session.XprofSession()
            xprof.start_session('REDACTED', True, 2)

            def sleep_and_end_xprof():
                time.sleep(seconds)
                logging.info(
                    'Xprof URL: %s',
                    xprof.end_session_and_get_url(
                        tag='flax resnet, {} devices, batch {} per device'.
                        format(jax.device_count(), device_batch_size)))

            thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

    # END GOOGLE-INTERNAL

    # Optionally trigger compilation of every pmapped step before the timed
    # run, feeding zero batches where a step consumes infeed.
    if FLAGS.precompile:
        logging.info('precompiling step/loop functions')
        if FLAGS.device_loop:
            # the device training loop condition will immediately be false
            p_train_loop(unbroadcast(optimizer), unbroadcast(state),
                         empty_metrics, jnp.array(0, dtype=jnp.int32), 1)
        else:
            for device in jax.local_devices():
                images = np.zeros(train_input_shape, model_dtype)
                labels = np.zeros((device_batch_size, ), np.int32)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            p_host_loop_train_step(unbroadcast(optimizer), state,
                                   empty_metrics)
            p_sync_batchnorm_stats(state)
        for device in jax.local_devices():
            images = np.zeros(eval_input_shape, model_dtype)
            labels = np.zeros((device_eval_batch_size, ), np.int32)
            infeed_pool.submit(
                partial(device.transfer_to_infeed, (images, labels)))
        p_host_loop_eval_step(unbroadcast(optimizer.target),
                              unbroadcast(state), empty_metrics)
        p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
        logging.info('finished precompiling')

    # BEGIN GOOGLE-INTERNAL
    maybe_start_xprof(20)
    # END GOOGLE-INTERNAL
    if not FLAGS.fake_data:
        logging.info('constructing datasets')
        # pylint: disable=g-complex-comprehension
        train_ds, eval_ds = [
            input_pipeline.load_split(
                device_batch_size if train else device_eval_batch_size,
                dtype=input_dtype,
                train=train,
                image_format=image_format,
                space_to_depth=space_to_depth,
                cache_uncompressed=jax.device_count() > 64)
            for train in (True, False)
        ]
        logging.info('constructing dataset iterators')
        train_iter = iter(train_ds)
        eval_iter = iter(eval_ds)

    local_devices = jax.local_devices()
    host_step, device_step = 0, broadcast(0)
    mllogger.end('init_stop')
    mllogger.start('run_start')
    mllogger.start('block_start',
                   metadata={
                       'first_epoch_num': 1,
                       'epoch_count': FLAGS.epochs_per_loop
                   })
    # NOTE(review): this runs two loops beyond ceil(num_epochs /
    # epochs_per_loop). The reason is not visible here; presumably a stop
    # condition elsewhere (see DONE below) ends the run first. Confirm.
    for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
        # BEGIN GOOGLE-INTERNAL
        if loop == 10: maybe_start_xprof(1)
        # END GOOGLE-INTERNAL
        metrics = empty_metrics
        if FLAGS.device_loop:
            optimizer, state, metrics, device_step = p_train_loop(
                unbroadcast(optimizer), unbroadcast(state), metrics,
                unbroadcast(device_step), loop)
        # The host enqueues exactly steps_per_loop batches per device for this
        # loop. In host-loop mode it also dispatches one step per batch.
        while int(host_step // steps_per_loop) == loop:
            if not FLAGS.device_loop:
                optimizer, state, metrics = p_host_loop_train_step(
                    unbroadcast(optimizer), state, metrics)
            # Back-pressure: wait while more than 100 infeed transfers are
            # still queued.
            # pylint: disable=protected-access
            while infeed_pool._work_queue.qsize() > 100:
                time.sleep(0.01)
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(train_input_shape, model_dtype)
                    labels = np.zeros((device_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(train_iter))
                assert images.shape == train_input_shape and labels.dtype == jnp.int32
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            host_step += 1
        epoch = (loop + 1) * FLAGS.epochs_per_loop
        if FLAGS.train_metrics:
            # The device loop already all-reduced its metrics.
            if not FLAGS.device_loop:
                metrics = p_allreduce_metrics(metrics)
            if jax.host_id() == 0:
                summary_thread.submit(
                    partial(write_summary, summary_writer, metrics, 'train',
                            epoch))
        if not FLAGS.device_loop:
            state = p_sync_batchnorm_stats(state)
        metrics = empty_metrics
        # NOTE(review): each eval step is dispatched before this iteration's
        # batches are submitted to the infeed pool. Presumably this relies on
        # asynchronous dispatch, with the step waiting on infeed; confirm.
        for _ in range(steps_per_eval):
            metrics = p_host_loop_eval_step(unbroadcast(optimizer.target),
                                            unbroadcast(state), metrics)
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(eval_input_shape, model_dtype)
                    labels = np.zeros((device_eval_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(eval_iter))
                assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
                    'images.shape={}'.format(images.shape)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
        metrics = p_allreduce_metrics(metrics)
        if jax.host_id() == 0:
            summary_thread.submit(
                partial(write_summary, summary_writer, metrics, 'eval', epoch))
    # Wait until computations are done before exiting
    # (the all-reduce result is discarded; it is used only to block).
    p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        if not DONE:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Exemple #20
0
def main(argv):
    """Train and evaluate a ResNet on ImageNet with JAX/Flax.

    Hyperparameters are loaded from ``FLAGS.hparams_config_dict``; when that
    dict holds a ``configs`` collection, entry ``FLAGS.config_idx`` is used.
    Training metrics, throughput, state-dict summaries and per-epoch eval
    metrics are written to TensorBoard on host 0 only. A checkpoint is saved
    every 10 epochs and at the final step.

    Args:
      argv: Command-line arguments left after flag parsing; only the program
        name is allowed.

    Raises:
      app.UsageError: If extra positional arguments are passed.
      ValueError: If the batch size is not divisible by the device count, no
        hparams config dict is provided, or the optimizer is unsupported.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    # Only host 0 writes summaries; every later use of `summary_writer` is
    # guarded by the same host check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # The global batch is split across hosts for input and across devices
    # for the per-device model.
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    # Half precision: bfloat16 on TPU; float16 elsewhere, which also needs
    # dynamic loss scaling.
    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                        FLAGS.data_dir,
                                                        image_size,
                                                        input_dtype,
                                                        train=True,
                                                        cache=FLAGS.cache)
    eval_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                       FLAGS.data_dir,
                                                       image_size,
                                                       input_dtype,
                                                       train=False,
                                                       cache=FLAGS.cache)

    # Create the hyperparameter object
    if FLAGS.hparams_config_dict:
        # In this case, there are multiple training configs defined in the config
        # dict, so we pull out the one this training run should use.
        if 'configs' in FLAGS.hparams_config_dict:
            hparams_config_dict = FLAGS.hparams_config_dict.configs[
                FLAGS.config_idx]
        else:
            hparams_config_dict = FLAGS.hparams_config_dict
        hparams = os_hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            hparams_config_dict)
    else:
        raise ValueError('Please provide a base config dict.')

    os_hparams_utils.write_hparams_to_file_with_host_id_check(
        hparams, FLAGS.model_dir)

    # get num_epochs from hparam instead of FLAGS
    num_epochs = hparams.lr_scheduler.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    # Estimate compute / memory costs (announce the work before doing it).
    if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
        logging.info(
            'Writing training HLO and estimating compute/memory costs.')
        estimate_compute_and_memory_cost(image_size=image_size,
                                         model_dir=FLAGS.model_dir,
                                         hparams=hparams)

    model, variables = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=hparams.model_hparams,
        train=True)
    model_state, params = variables.pop('params')
    if hparams.optimizer == 'sgd':
        optimizer = optim.Momentum(beta=hparams.momentum,
                                   nesterov=True).create(params)
    elif hparams.optimizer == 'adam':
        optimizer = optim.Adam(beta1=hparams.adam.beta1,
                               beta2=hparams.adam.beta2).create(params)
    else:
        raise ValueError('Optimizer type is not supported.')
    state = imagenet_train_utils.TrainState(step=0,
                                            optimizer=optimizer,
                                            model_state=model_state,
                                            dynamic_scale=dynamic_scale)
    del params, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    # Linear learning-rate scaling relative to a reference batch of 256.
    base_learning_rate = hparams.base_learning_rate * batch_size / 256.
    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               hparams.lr_scheduler)

    # hparams (arg 2) and update_bounds (arg 3) are passed as static,
    # broadcast (non-sharded) arguments.
    p_train_step = jax.pmap(functools.partial(
        imagenet_train_utils.train_step,
        model,
        learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            static_broadcasted_argnums=(2, 3))
    p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step,
                                             model),
                           axis_name='batch')

    epoch_metrics = []
    state_dict_summary_all = []
    state_dict_keys = _get_state_dict_keys_from_flags()
    # Throughput covers training steps only. The timer and its step counter
    # are restarted after eval/checkpointing at every epoch boundary, and the
    # counter starts at step_offset so a mid-epoch restart is measured
    # correctly.
    t_loop_start = time.time()
    loop_start_step = step_offset
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
            break
        update_bounds = train_utils.should_update_bounds(
            hparams.activation_bound_update_freq,
            hparams.activation_bound_start_step, step)
        state, metrics = p_train_step(state, batch, hparams, update_bounds)

        state_dict_summary = summary_utils.get_state_dict_summary(
            state.model_state, state_dict_keys)
        state_dict_summary_all.append(state_dict_summary)

        epoch_metrics.append(metrics)
        end_of_epoch = (step + 1) % steps_per_epoch == 0
        if end_of_epoch:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = (step + 1 - loop_start_step) / (time.time() -
                                                            t_loop_start)

            # Write to TensorBoard
            state_dict_summary_all = common_utils.get_metrics(
                state_dict_summary_all)
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    # Per-step values are logged at their own step numbers.
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

                summary_utils.write_state_dict_summaries_to_tb(
                    state_dict_summary_all, summary_writer,
                    FLAGS.state_dict_summary_freq, step)

            state_dict_summary_all = []
            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = imagenet_train_utils.sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = imagenet_train_utils.sync_batch_stats(state)
            save_checkpoint(state)
        if end_of_epoch:
            # Restart the throughput window only after eval/checkpointing so
            # the next epoch's steps/sec reflects training alone.
            t_loop_start = time.time()
            loop_start_step = step + 1

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Exemple #21
0
def main(argv):
    """Train a model on ``FLAGS.task_name``, then evaluate it on the test split.

    All hyperparameters come from ``FLAGS.config``. Convolutional models
    (wideresnet / resnet / simple_cnn) get normalized image input. Every other
    model type is treated as a transformer, and its single-channel input is
    flattened to a token sequence.

    Args:
      argv: Leftover command-line arguments; only the program name is allowed.

    Raises:
      app.UsageError: If extra positional arguments are passed.
      ValueError: If the batch size is not divisible by the device count.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    cfg = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(cfg)
    batch_size = cfg.batch_size
    learning_rate = cfg.learning_rate
    num_train_steps = cfg.num_train_steps
    num_eval_steps = cfg.num_eval_steps
    eval_freq = cfg.eval_frequency
    random_seed = cfg.random_seed
    model_type = cfg.model_type

    # Only the first process writes TensorBoard summaries.
    summary_writer = None
    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() != 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    logging.info('Training on %s', FLAGS.task_name)

    is_conv_model = model_type in ('wideresnet', 'resnet', 'simple_cnn')
    load_task_data = task_registry.TASK_DATA_DICT[FLAGS.task_name]
    (train_ds, eval_ds, test_ds, num_classes, vocab_size,
     input_shape) = load_task_data(n_devices=jax.local_device_count(),
                                   batch_size=batch_size,
                                   normalize=is_conv_model)
    train_iter = iter(train_ds)

    if is_conv_model:
        flatten_input = False
        model_kwargs = {'num_classes': num_classes}
    else:
        # Transformer models consume a flattened, single-channel sequence.
        flatten_input = True
        batch_dim, height, width, channels = input_shape
        assert channels == 1
        input_shape = (batch_dim, height * width * channels)
        model_kwargs = {
            'vocab_size': vocab_size,
            'max_len': input_shape[1],
            'classifier': True,
            'num_classes': num_classes,
        }
    model_kwargs.update(cfg.model)

    rng = jax.random.fold_in(random.PRNGKey(random_seed),
                             jax.process_index())
    rng, init_rng = random.split(rng)
    # Only the initial dropout keys are made here; the pmapped train step
    # produces the updated keys afterwards.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model, state = train_utils.get_model(model_type, create_model,
                                         model_kwargs, init_rng, input_shape)

    optimizer = create_optimizer(model, learning_rate, cfg.weight_decay)
    del model  # Don't keep a copy of the initial model.

    start_step = 0
    if cfg.restore_checkpoints:
        # Restore the unreplicated optimizer and model state, then resume
        # from the step stored in the optimizer.
        optimizer, state = checkpoints.restore_checkpoint(
            FLAGS.model_dir, (optimizer, state))
        start_step = int(optimizer.state.step)

    # Copy optimizer and model state onto every local device.
    optimizer = jax_utils.replicate(optimizer)
    state = jax_utils.replicate(state)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=cfg.factors,
        base_learning_rate=learning_rate,
        warmup_steps=cfg.warmup,
        steps_per_cycle=cfg.get('steps_per_cycle', None),
    )
    train_step_fn = functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        num_classes=num_classes,
        grad_clip_norm=cfg.get('grad_clip_norm', None),
        flatten_input=flatten_input)
    p_train_step = jax.pmap(train_step_fn, axis_name='batch')

    eval_step_fn = functools.partial(eval_step,
                                     num_classes=num_classes,
                                     flatten_input=flatten_input)
    p_eval_step = jax.pmap(eval_step_fn, axis_name='batch')

    optimizer, state, step = train_loop(cfg, dropout_rngs, eval_ds, eval_freq,
                                        num_eval_steps, num_train_steps,
                                        optimizer, state, p_eval_step,
                                        p_train_step, start_step, train_iter,
                                        summary_writer)

    logging.info('Starting testing')
    logging.info('====================')
    test(optimizer, state, p_eval_step, step, test_ds, summary_writer,
         FLAGS.model_dir)
                                               imagenet_eval, output_dir,
                                               metrics)

    save_model(model, output_dir, method, use_tpu, task_number)


def main(unused_argv):
    """Invoke ``run`` using settings taken from command-line flags.

    ``%task%`` in ``FLAGS.output_dir`` is replaced by the task number. A
    top-5 accuracy metric is appended when ``FLAGS.eval_top_5_accuracy`` is
    set. ``FLAGS.test_level`` > 0 requests fake training and > 1 also fake
    data.
    """
    logging.info('Base LR: %s.', learning_rate_lib.BASE_LEARNING_RATE)
    logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)

    extra_metrics = ([sparse_top_k_categorical_accuracy]
                     if FLAGS.eval_top_5_accuracy else [])
    metrics = ['sparse_categorical_crossentropy'] + extra_metrics

    task = FLAGS.task
    output_dir = FLAGS.output_dir.replace('%task%', str(task))
    test_level = FLAGS.test_level
    run(FLAGS.method,
        output_dir,
        task_number=task,
        use_tpu=FLAGS.use_tpu,
        tpu=FLAGS.tpu,
        metrics=metrics,
        fake_data=test_level > 1,
        fake_training=test_level > 0)


if __name__ == '__main__':
    # Script entry point: switch TF to v2 behavior, set the log level,
    # declare the command-line flags, then hand control to `app.run(main)`.

    tf.enable_v2_behavior()  # Required due to b/128610213.
    # NOTE(review): `tf.logging` is a TF1-era API; this presumes `tf` is bound
    # to a module that still provides it (e.g. tensorflow.compat.v1) — confirm
    # against the file's imports.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Flags are declared here, before `app.run` is invoked.
    _declare_flags()
    app.run(main)
Exemple #23
0
def main(argv):
    """Train a deterministic ResNet-50 on ImageNet, evaluating every epoch.

    Clean ImageNet validation is evaluated every epoch. When
    ``FLAGS.corruptions_interval`` is positive, all corrupted test sets are
    also evaluated every ``FLAGS.corruptions_interval`` epochs. Checkpoints
    are saved to ``FLAGS.output_dir`` every ``FLAGS.checkpoint_interval``
    epochs (when positive) and once more after training. Per-epoch scalars go
    to ``FLAGS.output_dir/summaries``. Training resumes from the latest
    checkpoint in ``FLAGS.output_dir`` if one exists.

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # unused arg
    tf.enable_v2_behavior()
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
    steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    imagenet_train = utils.ImageNetInput(is_training=True,
                                         data_dir=FLAGS.data_dir,
                                         batch_size=FLAGS.per_core_batch_size,
                                         use_bfloat16=FLAGS.use_bfloat16)
    imagenet_eval = utils.ImageNetInput(is_training=False,
                                        data_dir=FLAGS.data_dir,
                                        batch_size=FLAGS.per_core_batch_size,
                                        use_bfloat16=FLAGS.use_bfloat16)
    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            imagenet_eval.input_fn)
    }
    if FLAGS.corruptions_interval > 0:
        corruption_types, max_intensity = utils.load_corrupted_test_info()
        for name in corruption_types:
            for intensity in range(1, max_intensity + 1):
                dataset_name = '{0}_{1}'.format(name, intensity)
                corrupt_input_fn = utils.corrupt_test_input_fn(
                    batch_size=FLAGS.per_core_batch_size,
                    corruption_name=name,
                    corruption_intensity=intensity,
                    use_bfloat16=FLAGS.use_bfloat16)
                test_datasets[dataset_name] = (
                    strategy.experimental_distribute_datasets_from_function(
                        corrupt_input_fn))

    train_dataset = strategy.experimental_distribute_datasets_from_function(
        imagenet_train.input_fn)

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                             num_classes=NUM_CLASSES)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Scale learning rate and decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 256
        learning_rate = utils.LearningRateSchedule(steps_per_epoch, base_lr,
                                                   FLAGS.train_epochs,
                                                   _LR_SCHEDULE)
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        ed.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 to variables whose names contain 'kernel' or
                    # 'bias'; anything else (e.g. BN scale/offset, presumably
                    # named gamma/beta) is excluded. Depends on the naming
                    # scheme.
                    if 'kernel' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                # tf.nn.l2_loss computes sum(x**2) / 2, so the factor 2 yields
                # FLAGS.l2 * sum(x**2).
                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                loss = negative_log_likelihood + l2_loss
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)

            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits, from_logits=True))
            probs = tf.nn.softmax(logits)
            # `dataset_name` is a Python string, so this branch is resolved
            # at trace time.
            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    max_steps = steps_per_epoch * FLAGS.train_epochs
    # Steps completed before a restart must not count toward this run's
    # throughput, otherwise steps/s is inflated and the ETA is too short.
    initial_step = initial_epoch * steps_per_epoch
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            if step % 20 == 0:
                current_step = epoch * steps_per_epoch + (step + 1)
                time_elapsed = time.time() - start_time
                steps_per_sec = (float(current_step - initial_step) /
                                 time_elapsed)
                eta_seconds = (max_steps - current_step) / steps_per_sec
                message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                           'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                               current_step / max_steps, epoch + 1,
                               FLAGS.train_epochs, steps_per_sec,
                               eta_seconds / 60, time_elapsed / 60))
                logging.info(message)

        evaluate_corruptions = (FLAGS.corruptions_interval > 0 and
                                (epoch + 1) % FLAGS.corruptions_interval == 0)
        if evaluate_corruptions:
            datasets_to_evaluate = test_datasets
        else:
            datasets_to_evaluate = {'clean': test_datasets['clean']}
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if evaluate_corruptions:
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity,
                FLAGS.alexnet_errors_path)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()
        if evaluate_corruptions:
            # Corrupted-test metrics live outside `metrics`. Reset them too so
            # each corruption evaluation reports only its own epoch's results
            # instead of a running total since training began.
            for metric in corrupt_metrics.values():
                metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
from absl.testing import parameterized
from jax import random as jax_random
import numpy as np
import tensorflow.compat.v2 as real_tf
import tensorflow_probability as tfp

from discussion import fun_mcmc
from discussion.fun_mcmc import backend
from tensorflow_probability.python.internal import test_util as tfp_test_util

# Module-level aliases. `tf` and `util` are taken from fun_mcmc's `backend`
# module rather than from TensorFlow directly.
# NOTE(review): presumably the backend can swap in a non-TF implementation
# (the file also imports jax.random) — confirm in discussion.fun_mcmc.backend.
tf = backend.tf
tfb = tfp.bijectors
tfd = tfp.distributions
util = backend.util

# The real TensorFlow module (imported as `real_tf`) is switched to v2
# behavior at import time, independently of the `tf` alias above.
real_tf.enable_v2_behavior()


def _test_seed():
  """Return the TFP test seed reduced modulo ``2**32 - 1``."""
  modulus = 2**32 - 1
  seed = tfp_test_util.test_seed()
  return seed % modulus


def _no_compile(fn):
  return fn


def _fwd_mclachlan_optimal_4th_order_step(*args, **kwargs):
  """Call `fun_mcmc.mclachlan_optimal_4th_order_step` with `forward=True`.

  All other positional and keyword arguments are forwarded unchanged.
  """
  step_fn = fun_mcmc.mclachlan_optimal_4th_order_step
  return step_fn(*args, forward=True, **kwargs)

Exemple #25
0
def test_main():
  """Entrypoint for tests.

  Enables TensorFlow v2 behavior first, then hands control to
  `tf.test.main()` to run the test cases.
  """
  tf.enable_v2_behavior()
  tf.test.main()
Exemple #26
0
def main(argv):
    """Enable TF2 behavior and call ``run`` on ``FLAGS.prediction_path``.

    Raises:
      app.UsageError: If positional arguments beyond the program name are given.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError('Too many command-line arguments.')
    tf.enable_v2_behavior()
    prediction_path = FLAGS.prediction_path
    run(prediction_path)
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparmaters
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target,
                                inputs,
                                outputs,
                                cache,
                                eos_token,
                                programs.shape[-1],
                                beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def main(argv):
    """Enables TF2 behavior when the installed TensorFlow offers it, then runs tests.

    Args:
        argv: Positional arguments forwarded by the launcher; ignored.
    """
    del argv  # Not needed by this entry point.
    try:
        enable_v2 = tf.enable_v2_behavior
    except AttributeError:
        # Builds without the TF1 compatibility switch need no action here.
        pass
    else:
        enable_v2()
    tf.test.main()
Exemple #29
0
def main(argv):
    """Trains and evaluates a Wide ResNet-28-10 classifier under tf.distribute.

    Builds the distributed train / clean-test datasets (plus corrupted-test
    datasets when `FLAGS.corruptions_interval > 0`), constructs the model, SGD
    optimizer and metrics inside `strategy.scope()`, restores the latest
    checkpoint found in `FLAGS.output_dir` if any, and then runs the epoch loop:
    train, evaluate, write TensorBoard summaries, reset metrics, and
    periodically save a checkpoint.

    Args:
        argv: Positional command-line arguments; unused.
    """
    del argv  # unused arg
    tf.enable_v2_behavior()
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    train_input_fn = utils.load_input_fn(split=tfds.Split.TRAIN,
                                         name=FLAGS.dataset,
                                         batch_size=FLAGS.per_core_batch_size,
                                         use_bfloat16=FLAGS.use_bfloat16)
    clean_test_input_fn = utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)
    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            clean_test_input_fn),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_input_fn = utils.load_cifar10_c_input_fn
        else:
            load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                                path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                input_fn = load_c_input_fn(
                    corruption_name=corruption,
                    corruption_intensity=intensity,
                    batch_size=FLAGS.per_core_batch_size,
                    use_bfloat16=FLAGS.use_bfloat16)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_datasets_from_function(
                        input_fn))

    ds_info = tfds.builder(FLAGS.dataset).info
    # Batch size summed over all cores; per-epoch step counts derive from it.
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_epoch = ds_info.splits['train'].num_examples // batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = wide_resnet(input_shape=ds_info.features['image'].shape,
                            depth=28,
                            width_multiplier=10,
                            num_classes=num_classes,
                            l2=FLAGS.l2,
                            version=2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        # int() is required: absl list flags hold strings when set on the
        # command line, and `'60' * epochs` would be string repetition.
        lr_decay_epochs = [(int(start_epoch) * FLAGS.train_epochs) // 200
                           for start_epoch in FLAGS.lr_decay_epochs]
        lr_schedule = utils.LearningRateSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                                num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                                num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        ed.metrics.ExpectedCalibrationError(
                            num_classes=num_classes, num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

            # `dataset_name` is a Python string, so this branch is resolved at
            # trace time (one trace per dataset name).
            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    # Steps completed by an earlier run (non-zero only when resuming). They
    # must not count toward this process's throughput / ETA estimate.
    resumed_steps = initial_epoch * steps_per_epoch
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            if step % 20 == 0:
                current_step = epoch * steps_per_epoch + (step + 1)
                time_elapsed = time.time() - start_time
                steps_per_sec = (
                    float(current_step - resumed_steps) / time_elapsed)
                eta_seconds = (max_steps - current_step) / steps_per_sec
                message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                           'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                               current_step / max_steps, epoch + 1,
                               FLAGS.train_epochs, steps_per_sec,
                               eta_seconds / 60, time_elapsed / 60))
                logging.info(message)

        # Corrupted test sets are only evaluated every `corruptions_interval`
        # epochs; the clean test set is evaluated every epoch.
        evaluate_corruptions = (
            FLAGS.corruptions_interval > 0
            and (epoch + 1) % FLAGS.corruptions_interval == 0)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if evaluate_corruptions:
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if evaluate_corruptions:
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()
        # Corruption metrics must be cleared too; otherwise every corruption
        # evaluation would also include the results of all earlier ones.
        if FLAGS.corruptions_interval > 0:
            for metric in corrupt_metrics.values():
                metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
      sequence.append(current_word)

    return sequence


def main(argv):
  """Fits a toy TextRnnModel on two sentences and exports it as a SavedModel.

  Runs 100 training steps, calls `decode_greedy` once, then saves the module
  to `FLAGS.export_dir`.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv

  vocabulary = [
      "<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
  ]
  corpus = ["<S> hello there <E>", "<S> how are you doing today <E>"]

  rnn = TextRnnModel(vocab=vocabulary, emb_dim=10, buckets=100, state_size=128)

  corpus_tensor = tf.constant(corpus)
  for _ in range(100):
    rnn.train(corpus_tensor)

  # decode_greedy has no input_signature on its @tf.function, so it must be
  # invoked explicitly here for it to be traced and included in the export.
  greedy_tokens = rnn.decode_greedy(
      sequence_length=10, first_word=tf.constant("<S>"))
  for token in greedy_tokens:
    token.numpy()

  tf.saved_model.save(rnn, FLAGS.export_dir)


if __name__ == "__main__":
  # Switch to TF2 behavior before handing control to absl's app runner.
  tf.enable_v2_behavior()
  app.run(main=main)
def main(argv):
    """Samples a Stan target and writes its ground-truth statistics to a module.

    Runs `FLAGS.stan_chains` chains of the Stan model named by `FLAGS.target`.
    Each of the model's `extract_fns` is applied one chain at a time. The
    per-chain mean, standard deviation and ESS are then reduced across chains,
    and the results are written via `ground_truth_encoding` to
    `<output_directory>/<FLAGS.target>.py`.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    stan_model = getattr(targets, FLAGS.target)()

    with stan_model.sample_fn(sampling_iters=FLAGS.stan_samples,
                              chains=FLAGS.stan_chains,
                              show_progress=True) as mcmc_output:
        summary = mcmc_output.summary()
        if FLAGS.print_summary:
            pd.set_option('display.max_rows', sys.maxsize)
            pd.set_option('display.max_columns', sys.maxsize)
            print(mcmc_output.diagnose())
            print(summary)

        array_strs = []
        for name, fn in sorted(stan_model.extract_fns.items()):
            # We handle one chain at a time to reduce memory usage.
            chain_means = []
            chain_stds = []
            chain_esss = []
            for chain_id in range(FLAGS.stan_chains):
                # TODO(https://github.com/stan-dev/cmdstanpy/issues/218): This step is
                # very slow and wastes memory. Consider reading the CSV files ourselves.

                # sample shape is [num_samples, num_chains, num_columns]
                chain = mcmc_output.sample[:, chain_id, :]
                dataframe = pd.DataFrame(chain,
                                         columns=mcmc_output.column_names)

                transformed_samples = fn(dataframe)

                # We reduce over the samples dimension. Transformations can return
                # nested outputs.
                chain_means.append(
                    tf.nest.map_structure(lambda s: s.mean(0),
                                          transformed_samples))
                chain_stds.append(
                    tf.nest.map_structure(lambda s: s.std(0),
                                          transformed_samples))
                chain_esss.append(
                    tf.nest.map_structure(get_ess, transformed_samples))

            # Now we reduce across chains. `std` must be reduced before `sem`
            # is derived from it; previously `sem` silently used the last
            # chain's std left over from the loop above.
            ess = tf.nest.map_structure(lambda *s: np.sum(s, 0), *chain_esss)
            mean = tf.nest.map_structure(lambda *s: np.mean(s, 0),
                                         *chain_means)
            std = tf.nest.map_structure(lambda *s: np.mean(s, 0), *chain_stds)
            sem = tf.nest.map_structure(
                lambda std_part, ess_part: std_part / np.sqrt(ess_part),
                std, ess)

            for (tuple_path, mean_part), sem_part, std_part in zip(
                    nest.flatten_with_tuple_paths(mean), tf.nest.flatten(sem),
                    tf.nest.flatten(std)):
                array_strs.extend(
                    ground_truth_encoding.save_ground_truth_part(
                        name=name,
                        tuple_path=tuple_path,
                        mean=mean_part,
                        sem=sem_part,
                        std=std_part,
                        sestd=None,
                    ))

    # Record the invocation (flags only) so the generated file documents how
    # it can be regenerated.
    argv_str = '\n'.join(['  {} \\'.format(arg) for arg in sys.argv[1:]])
    command_str = (
        """bazel run //tools/inference_gym_ground_truth:get_ground_truth -- \
{argv_str}""".format(argv_str=argv_str))

    file_str = ground_truth_encoding.get_ground_truth_module_source(
        target_name=FLAGS.target,
        command_str=command_str,
        array_strs=array_strs)

    if FLAGS.output_directory is None:
        file_basedir = os.path.dirname(os.path.realpath(__file__))
        output_directory = os.path.join(
            file_basedir, '../../spinoffs/inference_gym/targets/ground_truth')
    else:
        output_directory = FLAGS.output_directory
    file_path = os.path.join(output_directory, '{}.py'.format(FLAGS.target))
    print('Writing ground truth values to: {}'.format(file_path))
    with open(file_path, 'w') as f:
        f.write(file_str)