def main(_):
  """Script entry point: configures TensorFlow/Keras and runs `train()`.

  Args:
    _: Positional argument supplied by the launcher; ignored.
  """
  tf.enable_v2_behavior()
  # Learning phase 1 is passed to the Keras backend before training starts.
  tf.keras.backend.set_learning_phase(1)
  train()
one_hot_inputs, vocab_size) inv_inputs = tf.argmax(one_hot_inv, axis=-1) inputs_inv_inputs = tf.math.floormod(inputs * inv_inputs, vocab_size) self.assertAllEqual(inputs_inv_inputs, np.ones((batch_size, length))) def testApproximatelyStochastic(self): rng = np.random.RandomState(0) tf.random.set_seed(1) for dims in [2, 5, 10]: for batch_size in [1, 2, 10]: log_alpha = rng.randn(batch_size, dims, dims) result = ed.layers.utils.sinkhorn(log_alpha) self.assertAllClose(np.sum(result, 1), np.tile([1.0], (batch_size, dims)), atol=1e-3) self.assertAllClose(np.sum(result, 2), np.tile([1.0], (batch_size, dims)), atol=1e-3) def testSoftToHardPermutation(self): """The solution of the matching for the identity matrix is range(N).""" dims = 10 identity = tf.eye(dims) result_matching = ed.layers.utils.soft_to_hard_permutation(identity) self.assertAllEqual(result_matching[0], np.eye(dims)) if __name__ == '__main__': tf.enable_v2_behavior() tf.test.main()
def main(_):
  """Trains an image-captioning encoder/decoder on real + synthetic data.

  The function loads transition arrays and their captions (real and
  synthetic), encodes the captions with a lookup table, builds the
  encoder/decoder pair, and then runs a teacher-forced training loop that
  mixes real and synthetic examples in every batch. Every 5 epochs (when
  `FLAGS.save_dir` is set) a validation loss is computed and the model is
  checkpointed if it improved; every 20 epochs a test loss over the real test
  split is printed.

  Args:
    _: Positional argument supplied by the launcher; ignored.
  """
  tf.enable_v2_behavior()
  ##############################################################################
  ######################### Data loading and processing ########################
  ##############################################################################
  print('Loading data')

  # NOTE(review): the files below are read with mode 'r'; np.load/pickle.load
  # need bytes, so under Python 3 these probably have to be 'rb' -- confirm.
  with gfile.GFile(_TRANSITION_PATH, 'r') as f:
    transitions = np.load(f)
  # Values above 1.0 are rescaled by 255 (presumably uint8 image data).
  if np.max(transitions) > 1.0:
    transitions = transitions / 255.0
  with gfile.GFile(_SYNTHETIC_TRANSITION_PATH, 'r') as f:
    # Fixed: the target used to be misspelt `synthetic_tran_sitions`, which
    # left `synthetic_transitions` undefined on the next line.
    synthetic_transitions = np.load(f)
  if np.max(synthetic_transitions) > 1.0:
    synthetic_transitions = synthetic_transitions / 255.0

  with gfile.GFile(transition_label_path, 'r') as f:
    captions = pickle.load(f)
  with gfile.GFile(_SYNTHETIC_TRANSITION_LABEL_PATH, 'r') as f:
    synthetic_captions = pickle.load(f)

  with gfile.GFile(vocab_path, 'r') as f:
    vocab_list = f.readlines()
  # Drop the trailing newline of every vocabulary line.
  vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
  # Index 0 is 'eos' and index 1 is 'sos'.
  vocab_list = ['eos', 'sos'] + vocab_list

  v2i, i2v = wv.create_look_up_table(vocab_list)
  encode_fn = wv.encode_text_with_lookup_table(v2i)
  decode_fn = wv.decode_with_lookup_table(i2v)

  # Flatten the per-observation caption lists into one list of encoded
  # captions, each wrapped in 'sos' ... 'eos'.
  encoded_captions = []
  for all_cp in captions:
    for cp in all_cp:
      cp = 'sos ' + cp + ' eos'
      encoded_captions.append(np.array(encode_fn(cp)))

  synthetic_encoded_captions = []
  for all_cp in synthetic_captions:
    for cp in all_cp:
      cp = 'sos ' + cp + ' eos'
      synthetic_encoded_captions.append(np.array(encode_fn(cp)))

  all_caption_n = len(encoded_captions)
  all_synthetic_caption_n = len(synthetic_encoded_captions)

  encoded_captions = np.array(encoded_captions)
  encoded_captions = pad_to_max_length(encoded_captions, max_l=15)

  synthetic_encoded_captions = np.array(synthetic_encoded_captions)
  synthetic_encoded_captions = pad_to_max_length(
      synthetic_encoded_captions, max_l=15)

  # obs_idx[k] / caption_idx[k] map the k-th (observation, caption) pair to
  # its row in `transitions` and in `encoded_captions` respectively.
  obs_idx, caption_idx = [], []
  curr_caption_idx = 0
  for i, _ in enumerate(transitions):
    for cp in captions[i]:
      obs_idx.append(i)
      caption_idx.append(curr_caption_idx)
      curr_caption_idx += 1
  assert curr_caption_idx == all_caption_n

  synthetic_obs_idx, synthetic_caption_idx = [], []
  curr_caption_idx = 0
  for i, _ in enumerate(synthetic_transitions):
    for cp in synthetic_captions[i]:
      synthetic_obs_idx.append(i)
      synthetic_caption_idx.append(curr_caption_idx)
      curr_caption_idx += 1
  assert curr_caption_idx == all_synthetic_caption_n

  # 80/20 split over (observation, caption) pairs, in file order.
  obs_idx = np.array(obs_idx)
  caption_idx = np.array(caption_idx)
  all_idx = np.arange(len(caption_idx))
  train_idx = all_idx[:int(len(all_idx) * 0.8)]
  test_idx = all_idx[int(len(all_idx) * 0.8):]
  print('Number of training examples: {}'.format(len(train_idx)))
  print('Number of test examples: {}\n'.format(len(test_idx)))

  synthetic_obs_idx = np.array(synthetic_obs_idx)
  synthetic_caption_idx = np.array(synthetic_caption_idx)
  synthetic_all_idx = np.arange(len(synthetic_caption_idx))
  synthetic_train_idx = synthetic_all_idx[:int(len(synthetic_all_idx) * 0.8)]
  synthetic_test_idx = synthetic_all_idx[int(len(synthetic_all_idx) * 0.8):]
  print('Number of synthetic training examples: {}'.format(
      len(synthetic_train_idx)))
  print('Number of synthetic test examples: {}\n'.format(
      len(synthetic_test_idx)))

  ##############################################################################
  ############################# Training Setup #################################
  ##############################################################################
  embedding_dim = 32
  units = 64
  vocab_size = len(vocab_list)
  batch_size = 64
  max_sequence_length = 15

  encoder_config = {'name': 'image', 'embedding_dim': 32}
  decoder_config = {
      'name': 'attention',
      'word_embedding_dim': 64,
      'hidden_units': 256,
      'vocab_size': len(vocab_list),
  }

  encoder = get_captioning_encoder(encoder_config)
  decoder = get_captioning_decoder(decoder_config)
  optimizer = tf.keras.optimizers.Adam()
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')

  def loss_function(real, pred, sos_symbol=1):
    """Mean cross entropy, with positions equal to `sos_symbol` zeroed out."""
    mask = tf.math.logical_not(tf.math.equal(real, sos_symbol))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)

  @tf.function
  def train_step(input_tensor, target):
    """Training on a batch of data.

    Returns:
      A tuple (summed per-step loss, loss divided by the caption length).
    """
    loss = 0
    # initializing the hidden state for each batch
    # because the captions are not related from image to image
    hidden = decoder.reset_state(batch_size=target.shape[0])
    # Every sequence starts from token id 1 ('sos').
    dec_input = tf.expand_dims([1] * target.shape[0], 1)
    with tf.GradientTape() as tape:
      features = encoder(input_tensor, training=True)
      for i in range(1, target.shape[1]):
        # passing the features through the decoder
        predictions, hidden, _ = decoder(dec_input, features, hidden,
                                         training=True)
        loss += loss_function(target[:, i], predictions)
        # using teacher forcing
        dec_input = tf.expand_dims(target[:, i], 1)
    total_loss = (loss / int(target.shape[1]))
    trainable_variables = (
        encoder.trainable_variables + decoder.trainable_variables)
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss, total_loss

  @tf.function
  def evaluate_batch(input_tensor, target):
    """Evaluate loss on a batch of data."""
    loss = 0
    # initializing the hidden state for each batch
    # because the captions are not related from image to image
    hidden = decoder.reset_state(batch_size=target.shape[0])
    dec_input = tf.expand_dims([1] * target.shape[0], 1)
    features = encoder(input_tensor, training=False)
    for i in range(1, target.shape[1]):
      # passing the features through the decoder
      predictions, hidden, _ = decoder(dec_input, features, hidden,
                                       training=False)
      loss += loss_function(target[:, i], predictions)
      # using teacher forcing
      dec_input = tf.expand_dims(target[:, i], 1)
    total_loss = (loss / int(target.shape[1]))
    return total_loss

  ##############################################################################
  ############################# Training Loop ##################################
  ##############################################################################
  print('Start training...\n')
  start_epoch = 0
  if FLAGS.save_dir:
    checkpoint_path = FLAGS.save_dir
    ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder,
                               optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path,
                                              max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
      start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
      # Fixed: actually load the saved weights. Previously only the epoch
      # counter was advanced, so a "resumed" run continued from scratch.
      ckpt.restore(ckpt_manager.latest_checkpoint)

  epochs = 400
  step_per_epoch = int(len(captions) / batch_size) * 10
  previous_best = 100.

  # Each training batch holds 2 * batch_size examples, `mixing_ratio` of which
  # are synthetic.
  mixing_ratio = 0.4
  syn_bs = int(batch_size * 2 * mixing_ratio)
  true_bs = int(batch_size * 2 * (1 - mixing_ratio))

  for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    for batch in range(step_per_epoch):
      batch_idx = np.random.choice(train_idx, size=true_bs)
      synthetic_batch_idx = np.random.choice(synthetic_train_idx, size=syn_bs)
      input_tensor = transitions[obs_idx[batch_idx], :]
      synthetic_input_tensor = synthetic_transitions[
          synthetic_obs_idx[synthetic_batch_idx], :]
      input_tensor = np.concatenate(
          [input_tensor, synthetic_input_tensor], axis=0)
      input_tensor = encoder.preprocess(input_tensor)
      target = encoded_captions[caption_idx[batch_idx]]
      synthetic_target = synthetic_encoded_captions[
          synthetic_caption_idx[synthetic_batch_idx]]
      target = np.concatenate([target, synthetic_target], axis=0)
      batch_loss, t_loss = train_step(input_tensor, target)
      total_loss += t_loss

      if batch % 100 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(
            epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))

    if epoch % 5 == 0 and FLAGS.save_dir:
      # NOTE(review): unlike the training loop, evaluation inputs are not
      # passed through encoder.preprocess -- confirm this is intended.
      test_total_loss = 0
      for batch in range(3):
        # NOTE(review): 196 is a hard-coded upper bound on the test index;
        # presumably len(test_idx) - 1 for the original dataset -- confirm.
        batch_idx = np.clip(
            np.arange(true_bs) + batch * true_bs, 0, 196)
        idx = test_idx[batch_idx]
        input_tensor = transitions[obs_idx[idx], :]
        target = encoded_captions[caption_idx[idx]]
        t_loss = evaluate_batch(input_tensor, target)
        test_total_loss += t_loss

        batch_idx = np.arange(syn_bs) + batch * syn_bs
        idx = synthetic_test_idx[batch_idx]
        input_tensor = synthetic_transitions[synthetic_obs_idx[idx], :]
        target = synthetic_encoded_captions[synthetic_caption_idx[idx]]
        t_loss = evaluate_batch(input_tensor, target)
        test_total_loss += t_loss

      # 3 iterations x (1 real + 1 synthetic) evaluation batches.
      test_total_loss /= 6.
      if test_total_loss < previous_best:
        previous_best = test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)

    print('Epoch {} | Loss {:.6f} | Val loss {:.6f}'.format(
        epoch + 1, total_loss / step_per_epoch, previous_best))
    print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() - start))

    if epoch % 20 == 0:
      total_loss = 0
      for batch in range(len(test_idx) // batch_size):
        batch_idx = np.arange(batch_size) + batch * batch_size
        idx = test_idx[batch_idx]
        input_tensor = transitions[obs_idx[idx], :]
        target = encoded_captions[caption_idx[idx]]
        # input_tensor = input_tensor[:, 0] - input_tensor[:, 1]
        t_loss = evaluate_batch(input_tensor, target)
        total_loss += t_loss

      print('====================================================')
      print('Test Loss {:.6f}'.format(total_loss /
                                      (len(test_idx) // batch_size)))
      print('====================================================\n')
def main(_):
  """Script entry point: enables TF v2 behavior and visualizes TFRecords.

  Forwards the `path_to_tfrecord`, `num_vids` and `num_skip_frames` flags to
  `visualize_tfrecords`.

  Args:
    _: Positional argument supplied by the launcher; ignored.
  """
  tf.enable_v2_behavior()
  visualize_tfrecords(FLAGS.path_to_tfrecord, FLAGS.num_vids,
                      FLAGS.num_skip_frames)
def main(argv):
  """Evaluates an ensemble of ResNet checkpoints on the train and test splits.

  Every checkpoint found under `FLAGS.output_dir` is loaded into the same
  model in turn; its logits over both splits are collected and the ensemble's
  NLL, Gibbs cross entropy and accuracy are logged.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # unused arg
  tf.enable_v2_behavior()

  train_data, ds_info = utils.load_dataset(tfds.Split.TRAIN, with_info=True)
  test_data = utils.load_dataset(tfds.Split.TEST)
  train_data = train_data.batch(FLAGS.batch_size)
  test_data = test_data.batch(FLAGS.batch_size)
  # Ordered (split name, dataset) pairs: train is always processed first.
  splits = (('train', train_data), ('test', test_data))

  image_shape = ds_info.features['image'].shape
  num_classes = ds_info.features['label'].num_classes
  model = deterministic.resnet_v1(
      input_shape=image_shape, depth=20, num_classes=num_classes, l2=0.)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Checkpoints are located through their index file; the '.index' suffix is
  # stripped to obtain the checkpoint prefix.
  index_pattern = os.path.join(FLAGS.output_dir, '**/*.ckpt.index')
  ensemble_filenames = [
      index_file[:-len('.index')]
      for index_file in tf.io.gfile.glob(index_pattern)
  ]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))

  # Per split: one logits tensor per ensemble member, plus the labels, which
  # are gathered only while the first member is being evaluated.
  # TODO(trandustin): Refactor data loader so you can get the full dataset in
  # memory without looping.
  split_logits = {'train': [], 'test': []}
  split_labels = {'train': [], 'test': []}
  for member, checkpoint_prefix in enumerate(ensemble_filenames):
    model.load_weights(checkpoint_prefix)
    is_first_member = member == 0
    for split_name, dataset in splits:
      batch_logits = []
      for features, labels in dataset:
        batch_logits.append(model(features, training=False))
        if is_first_member:
          split_labels[split_name].append(labels)
      split_logits[split_name].append(tf.concat(batch_logits, axis=0))
      if is_first_member:
        split_labels[split_name] = tf.concat(split_labels[split_name], axis=0)
    logging.info('Predictions completed for checkpoint %s', checkpoint_prefix)

  logits_train, logits_test = split_logits['train'], split_logits['test']
  labels_train, labels_test = split_labels['train'], split_labels['test']

  # Per-example ensemble NLL and Gibbs cross entropy, averaged over the data.
  metrics = {}
  metrics['train_nll'] = tf.reduce_mean(
      ensemble_negative_log_likelihood(labels_train, logits_train))
  metrics['test_nll'] = tf.reduce_mean(
      ensemble_negative_log_likelihood(labels_test, logits_test))
  metrics['train_gibbs_cross_entropy'] = tf.reduce_mean(
      gibbs_cross_entropy(labels_train, logits_train))
  metrics['test_gibbs_cross_entropy'] = tf.reduce_mean(
      gibbs_cross_entropy(labels_test, logits_test))

  # Average the members' softmax probabilities (leading axis indexes the
  # ensemble member), then score accuracy against the labels.
  mean_probs_train = tf.reduce_mean(tf.nn.softmax(logits_train), axis=0)
  mean_probs_test = tf.reduce_mean(tf.nn.softmax(logits_test), axis=0)
  metrics['train_accuracy'] = tf.reduce_mean(
      tf.keras.metrics.sparse_categorical_accuracy(labels_train,
                                                   mean_probs_train))
  metrics['test_accuracy'] = tf.reduce_mean(
      tf.keras.metrics.sparse_categorical_accuracy(labels_test,
                                                   mean_probs_test))
  logging.info('Metrics: %s', metrics)
def main(argv):
  """Trains and evaluates a dual-encoder transformer on a matching task.

  Reads hyperparameters from `FLAGS.config`, builds the matching datasets and
  the `TransformerDualEncoder` model, optionally restores a checkpoint, and
  then either (a) writes test metrics to `results.json` and returns when
  `FLAGS.test_only` is set, or (b) runs the pmap'd training loop with
  periodic checkpointing, train-metric summaries, and eval/test evaluation.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra command-line arguments are given.
    ValueError: If the batch size is not divisible by the device count, or if
      `config.model_type` is not 'transformer'.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  config = FLAGS.config
  logging.info('===========Config Dict============')
  logging.info(config)
  batch_size = config.batch_size
  learning_rate = config.learning_rate
  num_train_steps = config.num_train_steps
  num_eval_steps = config.num_eval_steps
  eval_freq = config.eval_frequency
  random_seed = config.random_seed
  model_type = config.model_type
  max_length = config.max_length

  # Only host 0 creates a summary writer; every later use of `summary_writer`
  # is guarded by the same host check.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'summary'))

  if batch_size % jax.device_count() > 0:
    raise ValueError(
        'Batch size must be divisible by the number of devices')

  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
      n_devices=jax.local_device_count(),
      task_name=FLAGS.task_name,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      fixed_vocab=None,
      max_length=max_length,
      tokenizer=config.tokenizer,
      vocab_file_path=FLAGS.vocab_file_path)

  vocab_size = encoder.vocab_size
  logging.info('Vocab Size: %d', vocab_size)

  # The training set is repeated so the step loop below is bounded only by
  # num_train_steps.
  train_ds = train_ds.repeat()

  train_iter = iter(train_ds)
  input_shape = (batch_size, max_length)

  model_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': config.emb_dim,
      'num_heads': config.num_heads,
      'num_layers': config.num_layers,
      'qkv_dim': config.qkv_dim,
      'mlp_dim': config.mlp_dim,
      'max_len': max_length,
      'classifier': True,
      'num_classes': 2,
      'classifier_pool': config.pooling_mode
  }

  # Fold the host id into the seed so each host gets a distinct RNG stream.
  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = random.split(rng)
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  if model_type == 'transformer':
    # The same input shape is passed twice, once per encoder input.
    model = create_model(init_rng, transformer.TransformerDualEncoder,
                         input_shape, input_shape, model_kwargs)
  else:
    raise ValueError('Model type not supported.')

  optimizer = create_optimizer(model, learning_rate,
                               weight_decay=FLAGS.config.weight_decay)
  del model  # Don't keep a copy of the initial model.
  start_step = 0
  if config.restore_checkpoints or FLAGS.test_only:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors=config.factors,
      base_learning_rate=learning_rate,
      warmup_steps=config.warmup)
  p_train_step = jax.pmap(functools.partial(
      train_step, learning_rate_fn=learning_rate_fn), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  # p_pred_step = jax.pmap(predict_step, axis_name='batch')

  def run_eval(eval_ds, num_eval_steps=-1):
    """Evaluates the current `optimizer.target` on `eval_ds`.

    Args:
      eval_ds: Dataset to iterate over.
      num_eval_steps: Maximum number of batches to evaluate; -1 means iterate
        until the dataset is exhausted.

    Returns:
      Dict of summed metrics divided by the summed 'denominator' entry, plus
      a 'perplexity' entry (exp of the loss, clipped at 1e4).
    """
    eval_metrics = []
    eval_iter = iter(eval_ds)
    if num_eval_steps == -1:
      num_iter = itertools.count()
    else:
      num_iter = range(num_eval_steps)
    # zip stops at whichever of the step budget / dataset ends first.
    for _, eval_batch in zip(num_iter, eval_iter):
      # pylint: disable=protected-access
      eval_batch = common_utils.shard(
          jax.tree_map(lambda x: x._numpy(), eval_batch))
      # pylint: enable=protected-access
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    # Calculate (clipped) perplexity after averaging log-perplexities:
    eval_summary['perplexity'] = jnp.clip(jnp.exp(eval_summary['loss']),
                                          a_max=1.0e4)
    return eval_summary

  if FLAGS.test_only:
    # Evaluate the restored model on the full test set, dump the metrics as
    # JSON into the model directory, and skip training entirely.
    with tf.io.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.json'),
                           'w') as f:
      test_summary = run_eval(test_ds)
      json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
    return

  metrics_all = []
  tick = time.time()
  logging.info('Starting training')
  logging.info('====================')

  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    # logging.info(batch)
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)
    logging.info('train in step: %d', step)

    # Save a Checkpoint
    if ((step % config.checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and config.save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(FLAGS.model_dir,
                                    jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                   summary['loss'], summary['accuracy'])
      if jax.host_id() == 0:
        # Throughput is measured over the eval_freq steps since the last tick.
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(f'train_{key}', val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_summary = run_eval(eval_ds, num_eval_steps)
      logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          summary_writer.scalar(f'eval_{key}', val, step)
        summary_writer.flush()

      # Test eval
      # Eval Metrics
      logging.info('Testing...')
      test_summary = run_eval(test_ds, num_eval_steps)
      logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                   test_summary['loss'], test_summary['accuracy'])
      if jax.host_id() == 0:
        for key, val in test_summary.items():
          summary_writer.scalar(f'test_{key}', val, step)
        summary_writer.flush()
def main(argv):
  """Evaluates a ResNet-50 checkpoint ensemble on clean and corrupted ImageNet.

  All members' logits are held in memory. For the clean test set the NLL,
  Gibbs cross entropy, accuracy and ECE are reported; for every corrupted
  dataset the NLL, accuracy and ECE are accumulated and then aggregated by
  `utils.aggregate_corrupt_metrics`.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If more than one accelerator core is requested.
  """
  del argv  # unused arg
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)

  clean_input = utils.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=False)
  test_datasets = {'clean': clean_input.input_fn()}

  model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                       num_classes=NUM_CLASSES)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Checkpoints are found through their index file; dropping the '.index'
  # suffix yields the checkpoint prefix.
  index_files = tf.io.gfile.glob(os.path.join(FLAGS.output_dir, '**/*.index'))
  ensemble_filenames = [path[:-len('.index')] for path in index_files]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Per dataset: a list with one logits tensor per ensemble member, and the
  # labels (gathered while the first member is evaluated).
  logits_test = {'clean': []}
  labels_test = {'clean': []}
  corruption_types, max_intensity = utils.load_corrupted_test_info()
  for corruption in corruption_types:
    for level in range(1, max_intensity + 1):
      corrupted_name = '{0}_{1}'.format(corruption, level)
      logits_test[corrupted_name] = []
      labels_test[corrupted_name] = []
      test_datasets[corrupted_name] = utils.load_corrupted_test_dataset(
          name=corruption,
          intensity=level,
          batch_size=FLAGS.per_core_batch_size,
          drop_remainder=True,
          use_bfloat16=False)

  for member, checkpoint_prefix in enumerate(ensemble_filenames):
    checkpoint.restore(checkpoint_prefix)
    logging.info('Working on test data for ensemble member %s', member)
    collect_labels = member == 0
    for dataset_name, dataset in test_datasets.items():
      member_logits = []
      for features, batch_labels in dataset:
        member_logits.append(model(features, training=False))
        if collect_labels:
          labels_test[dataset_name].append(batch_labels)

      logits_test[dataset_name].append(tf.concat(member_logits, axis=0))
      if collect_labels:
        labels_test[dataset_name] = tf.concat(labels_test[dataset_name],
                                              axis=0)
      logging.info('Finished testing on %s', dataset_name)

  metrics = {
      'test/ece': ed.metrics.ExpectedCalibrationError(num_classes=NUM_CLASSES,
                                                      num_bins=15)
  }
  corrupt_metrics = {}
  for dataset_name in test_datasets:
    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
        ed.metrics.ExpectedCalibrationError(num_classes=NUM_CLASSES,
                                            num_bins=15))
    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
        tf.keras.metrics.Mean())
    corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
        tf.keras.metrics.Mean())

  for dataset_name in test_datasets:
    raw_labels = labels_test[dataset_name]
    raw_logits = logits_test[dataset_name]
    nll_test = ensemble_negative_log_likelihood(raw_labels, raw_logits)
    gibbs_ce_test = gibbs_cross_entropy(raw_labels, raw_logits)

    labels = tf.cast(raw_labels, tf.int32)
    # Average the members' softmax outputs over the leading (member) axis.
    probs = tf.reduce_mean(
        tf.nn.softmax(tf.convert_to_tensor(raw_logits)), axis=0)
    accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels, probs)

    if dataset_name != 'clean':
      corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
          tf.reduce_mean(nll_test))
      corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
          tf.reduce_mean(accuracy))
      corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
          labels, probs)
      continue

    metrics['test/negative_log_likelihood'] = tf.reduce_mean(nll_test)
    metrics['test/gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)
    metrics['test/accuracy'] = tf.reduce_mean(accuracy)
    metrics['test/ece'].update_state(labels, probs)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity)
  # Replace the ECE metric object with its final value before reporting.
  metrics['test/ece'] = metrics['test/ece'].result()
  total_results = dict(metrics)
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Evaluates a ResNet-50 ensemble on ImageNet, caching logits on disk.

  Phase one writes each member's logits for each dataset to
  `<output_dir>/<dataset>_<member>.npy`, skipping files that already exist.
  Phase two reloads those files, walks the datasets again for the labels, and
  accumulates NLL, Gibbs cross entropy (clean only), accuracy and ECE.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If GPU use is disabled or more than one core is requested.
  """
  del argv  # unused arg
  if not FLAGS.use_gpu:
    raise ValueError('Only GPU is currently supported.')
  if FLAGS.num_cores > 1:
    raise ValueError('Only a single accelerator is currently supported.')
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)
  tf.io.gfile.makedirs(FLAGS.output_dir)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  def _logits_path(dataset_name, member):
    """Returns the .npy path holding `member`'s logits on `dataset_name`."""
    basename = '{dataset}_{member}.npy'.format(dataset=dataset_name,
                                               member=member)
    return os.path.join(FLAGS.output_dir, basename)

  clean_input = utils.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=False)
  test_datasets = {'clean': clean_input.input_fn()}
  corruption_types, max_intensity = utils.load_corrupted_test_info()
  for corruption in corruption_types:
    for level in range(1, max_intensity + 1):
      corrupted_name = '{0}_{1}'.format(corruption, level)
      test_datasets[corrupted_name] = utils.load_corrupted_test_dataset(
          name=corruption,
          intensity=level,
          batch_size=FLAGS.per_core_batch_size,
          drop_remainder=True,
          use_bfloat16=False)

  model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                       num_classes=NUM_CLASSES)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  # Checkpoints are found through their index file; dropping the '.index'
  # suffix yields the checkpoint prefix.
  index_files = tf.io.gfile.glob(
      os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
  ensemble_filenames = [path[:-len('.index')] for path in index_files]
  ensemble_size = len(ensemble_filenames)
  logging.info('Ensemble size: %s', ensemble_size)
  logging.info('Ensemble number of weights: %s',
               ensemble_size * model.count_params())
  logging.info('Ensemble filenames: %s', str(ensemble_filenames))
  checkpoint = tf.train.Checkpoint(model=model)

  # Phase 1: write model predictions to files.
  num_datasets = len(test_datasets)
  for member, checkpoint_prefix in enumerate(ensemble_filenames):
    checkpoint.restore(checkpoint_prefix)
    for position, (dataset_name, dataset) in enumerate(test_datasets.items()):
      output_path = _logits_path(dataset_name, member)
      if not tf.io.gfile.exists(output_path):
        batches = iter(dataset)
        member_logits = []
        for _ in range(steps_per_eval):
          features, _ = next(batches)  # pytype: disable=attribute-error
          member_logits.append(model(features, training=False))

        member_logits = tf.concat(member_logits, axis=0)
        with tf.io.gfile.GFile(output_path, 'w') as f:
          np.save(f, member_logits.numpy())
      done = member * num_datasets + (position + 1)
      percent = done / (ensemble_size * num_datasets)
      message = ('{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                 'Dataset {:d}/{:d}'.format(percent,
                                            member + 1,
                                            ensemble_size,
                                            position + 1,
                                            num_datasets))
      logging.info(message)

  metrics = {
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
  }
  corrupt_metrics = {}
  for dataset_name in test_datasets:
    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
        tf.keras.metrics.Mean())
    corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
        tf.keras.metrics.SparseCategoricalAccuracy())
    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
        ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

  # Phase 2: evaluate model predictions.
  for position, (dataset_name, dataset) in enumerate(test_datasets.items()):
    cached_logits = []
    for member in range(ensemble_size):
      with tf.io.gfile.GFile(_logits_path(dataset_name, member), 'rb') as f:
        cached_logits.append(np.load(f))

    # Stack along a new leading axis that indexes the ensemble member.
    cached_logits = tf.convert_to_tensor(cached_logits)
    batches = iter(dataset)
    for step in range(steps_per_eval):
      _, labels = next(batches)  # pytype: disable=attribute-error
      begin = step * batch_size
      end = (step + 1) * batch_size
      logits = cached_logits[:, begin:end]
      labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
      negative_log_likelihood = tf.reduce_mean(
          ensemble_negative_log_likelihood(labels, logits))
      probs = tf.reduce_mean(tf.nn.softmax(logits), axis=0)

      if dataset_name != 'clean':
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)
        continue

      gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
      metrics['test/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
      metrics['test/accuracy'].update_state(labels, probs)
      metrics['test/ece'].update_state(labels, probs)

    message = ('{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
        (position + 1) / num_datasets, position + 1, num_datasets))
    logging.info(message)

  corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                    corruption_types,
                                                    max_intensity,
                                                    FLAGS.alexnet_errors_path)
  total_results = {name: metric.result() for name, metric in metrics.items()}
  total_results.update(corrupt_results)
  logging.info('Metrics: %s', total_results)
def main(argv):
  """Trains a Transformer on WMT and periodically evaluates/translates.

  Validates flags, builds the input pipeline, initialises the model and an
  Adam optimizer (optionally restoring the latest checkpoint from
  `FLAGS.model_dir`), replicates the optimizer across local devices and runs
  the training loop. The loop covers `range(start_step, FLAGS.num_train_steps)`
  or stops earlier if the training iterator is exhausted. On step 0 and on
  every step that is a multiple of `FLAGS.eval_frequency` it also reports
  training metrics, runs evaluation, translates the prediction dataset and
  logs a BLEU score. Summaries are written by host 0 only.

  Args:
    argv: Positional command-line arguments; anything beyond the program name
      is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `FLAGS.batch_size` is not a multiple of the local device
      count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # This seems to be necessary even when importing TF2?
  tf.enable_v2_behavior()

  # Devices attached to this host; every batch is split across them.
  n_devices = jax.local_device_count()

  # Only the first host creates the model directory and writes summaries.
  is_lead_host = jax.host_id() == 0
  if is_lead_host:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))

  if FLAGS.batch_size % n_devices != 0:
    raise ValueError(
        'Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  vocab_dir, _ = os.path.split(vocab_path)
  tf.io.gfile.makedirs(vocab_dir)

  # ---------------------------------------------------------------------------
  # Datasets
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenizes `toks` up to and including the first EOS token."""
    cutoff = np.argmax(toks == eos_id) + 1
    valid_toks = toks[:cutoff].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  def _to_numpy(x):
    """Converts one TF eager tensor leaf to a numpy array."""
    return x._numpy()  # pylint: disable=protected-access

  logging.info('Initializing model, optimizer, and step functions.')

  # ---------------------------------------------------------------------------
  # Model and optimizer
  # ---------------------------------------------------------------------------
  compute_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
  longest_sequence = max(FLAGS.max_target_length,
                         FLAGS.max_eval_target_length)
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=compute_dtype,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=longest_sequence,
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  # The parameter tree is produced by a jitted init call on dummy inputs.
  @jax.jit
  def initialize_variables(init_key):
    dummy_inputs = jnp.ones(input_shape, jnp.float32)
    dummy_targets = jnp.ones(target_shape, jnp.float32)
    return models.Transformer(eval_config).init(
        init_key, dummy_inputs, dummy_targets)

  initial_variables = initialize_variables(init_rng)

  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['param'])

  # From here on the parameters are reached only through optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    # Restore the unreplicated optimizer + model state from the last
    # checkpoint and resume from the step stored in it.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    start_step = int(optimizer.state.step)

  # One copy of the optimizer per local device.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps)

  # Multi-device (pmap'd) versions of the train/eval/predict steps and of the
  # decoding-cache initializer.
  train_step_fn = functools.partial(
      train_step,
      config=train_config,
      learning_rate_fn=learning_rate_fn,
      label_smoothing=FLAGS.label_smoothing)
  p_train_step = jax.pmap(
      train_step_fn, axis_name='batch', donate_argnums=(0,))

  eval_step_fn = functools.partial(
      eval_step, config=eval_config, label_smoothing=FLAGS.label_smoothing)
  p_eval_step = jax.pmap(eval_step_fn, axis_name='batch')

  init_cache_fn = functools.partial(
      initialize_cache,
      max_decode_len=FLAGS.max_predict_length,
      config=predict_config)
  p_init_cache = jax.pmap(init_cache_fn, axis_name='batch')

  predict_step_fn = functools.partial(
      predict_step, config=predict_config, beam_size=FLAGS.beam_size)
  p_pred_step = jax.pmap(
      predict_step_fn,
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # ---------------------------------------------------------------------------
  # Main train loop
  # ---------------------------------------------------------------------------
  # The first set of dropout PRNG keys is created here; subsequent keys are
  # produced inside the pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info('Starting training loop.')
  pending_train_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(_to_numpy, batch))
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    pending_train_metrics.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    should_checkpoint = (
        FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0)
    if should_checkpoint:
      checkpoints.save_checkpoint(
          FLAGS.model_dir, jax_utils.unreplicate(optimizer), step)

    # Metrics are only handled on step 0 and on multiples of eval_frequency.
    is_report_step = step % FLAGS.eval_frequency == 0 or step <= 0
    if not is_report_step:
      continue

    # ----- Training metrics -----
    logging.info('Gathering training metrics.')
    gathered = common_utils.get_metrics(pending_train_metrics)
    lr = gathered.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, gathered)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    steps_per_eval = 1 if step == 0 else FLAGS.eval_frequency
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if is_lead_host:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    pending_train_metrics = []
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

    # ----- Evaluation metrics -----
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = common_utils.shard(jax.tree_map(_to_numpy, eval_batch))
      eval_metrics.append(p_eval_step(optimizer.target, eval_batch))
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if is_lead_host:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # ----- Translation and BLEU score -----
    logging.info('Translating evaluation dataset.')
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(_to_numpy, pred_batch)
      # A final odd-sized batch is padded up to a multiple of the device count
      # rather than dropped.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch['inputs'])
      predicted = tohost(
          p_pred_step(pred_batch['inputs'], optimizer.target, cache, eos_id,
                      FLAGS.max_predict_length))
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Only the first cur_pred_batch_size rows are real (non-padding).
      for row, hypothesis in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[row]))
        references.append(decode_tokens(targets[row]))
        predictions.append(decode_tokens(hypothesis))
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))
    logging.info('Translation time: %.4f s step %d.',
                 time.time() - t_inference_start, step)

    # BLEU statistics are computed per host and summed across hosts.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)

    # Eight randomly drawn translations are kept as tensorboard text samples.
    sample_ids = np.random.choice(np.arange(len(predictions)), 8)
    exemplars = ''.join(
        f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        for n in sample_ids)
    if is_lead_host:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()
    logging.info('Translation BLEU Score %.4f', bleu_score)
def main(argv):
  """Trains and evaluates a BatchEnsemble ResNet-32 under a tf.distribute strategy.

  Sets up a MirroredStrategy (GPU) or TPUStrategy, builds per-core train and
  test datasets, creates the model, optimizer and metrics inside the strategy
  scope, restores the latest checkpoint from `FLAGS.output_dir` if one exists,
  and then alternates one epoch of training with one full pass over the test
  set. Summaries are written once per epoch (with step `epoch + 1`) and a
  checkpoint is saved every 20 epochs.

  Args:
    argv: Unused positional command-line arguments.
  """
  del argv  # unused arg
  tf.enable_v2_behavior()
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  if not FLAGS.use_gpu:
    tpu_address = FLAGS.tpu if FLAGS.tpu is not None else 'local'
    logging.info('Use TPU at %s', tpu_address)
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
  else:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()

  # Each core receives per_core_batch_size // num_models examples.
  per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.num_models

  def make_input_fn(split, **loader_kwargs):
    """Returns an input_fn that builds the local (per-core) dataset."""

    def input_fn(ctx):
      """Sets up local (per-core) dataset batching."""
      dataset = utils.load_distributed_dataset(
          split=split,
          name=FLAGS.dataset,
          batch_size=per_core_batch_size,
          drop_remainder=True,
          use_bfloat16=FLAGS.use_bfloat16,
          **loader_kwargs)
      if ctx and ctx.num_input_pipelines > 1:
        dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset

    return input_fn

  train_dataset = strategy.experimental_distribute_datasets_from_function(
      make_input_fn(tfds.Split.TRAIN, proportion=FLAGS.train_proportion))
  # Whatever the training proportion is, evaluation always uses the full test
  # dataset, so no proportion is passed here.
  test_dataset = strategy.experimental_distribute_datasets_from_function(
      make_input_fn(tfds.Split.TEST))

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = per_core_batch_size * FLAGS.num_cores
  # train_proportion is a float, so the quotient is converted back to int.
  num_train_examples = (
      ds_info.splits['train'].num_examples * FLAGS.train_proportion)
  steps_per_epoch = int(num_train_examples // batch_size)
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  with strategy.scope():
    logging.info('Building Keras ResNet-32 model')
    model = batchensemble_model.ensemble_resnet_v1(
        input_shape=ds_info.features['image'].shape,
        depth=32,
        num_classes=ds_info.features['label'].num_classes,
        width_multiplier=4,
        num_models=FLAGS.num_models,
        random_sign_init=FLAGS.random_sign_init,
        dropout_rate=FLAGS.dropout_rate,
        l2=FLAGS.l2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_schedule = utils.ResnetLearningRateSchedule(
        steps_per_epoch, base_lr, _LR_SCHEDULE)
    optimizer = tf.keras.optimizers.SGD(
        lr_schedule, momentum=0.9, nesterov=True)

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_nll = tf.keras.metrics.Mean('train_nll', dtype=tf.float32)
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'train_accuracy', dtype=tf.float32)
    test_nll = tf.keras.metrics.Mean('test_nll', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    # One (nll, accuracy) metric pair per ensemble member.
    member_metrics = [
        (tf.keras.metrics.Mean('test_nll_{}'.format(i), dtype=tf.float32),
         tf.keras.metrics.SparseCategoricalAccuracy(
             'test_accuracy_{}'.format(i), dtype=tf.float32))
        for i in range(FLAGS.num_models)
    ]
    test_nlls = [nll_metric for nll_metric, _ in member_metrics]
    test_accs = [acc_metric for _, acc_metric in member_metrics]

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries/'))

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      if FLAGS.version2:
        images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
        labels = tf.tile(labels, [FLAGS.num_models])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        per_example_nll = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True)
        negative_log_likelihood = tf.reduce_mean(per_example_nll)
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # The strategy sum-reduces gradients across replicas, so the loss that
        # is differentiated is divided by the replica count.
        scaled_loss = loss / strategy.num_replicas_in_sync

      variables = model.trainable_variables
      grads = tape.gradient(scaled_loss, variables)

      # Separate learning rate implementation.
      multiplier = FLAGS.fast_weight_lr_multiplier
      if multiplier != 1.0:
        # Fast-weight parameters get a scaled gradient. They are identified
        # purely by name: anything that is neither batch norm nor a kernel,
        # so this depends on the naming scheme.
        def is_fast_weight(var):
          return not ('batch_norm' in var.name or 'kernel' in var.name)

        grads_and_vars = [
            (grad * multiplier, var) if is_fast_weight(var) else (grad, var)
            for grad, var in zip(grads, variables)
        ]
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, variables))

      train_loss.update_state(loss)
      train_nll.update_state(negative_log_likelihood)
      train_accuracy.update_state(labels, logits)

    strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      probs = tf.nn.softmax(logits)
      per_probs = tf.split(
          probs, num_or_size_splits=FLAGS.num_models, axis=0)

      # Per-member metrics.
      for member_probs, member_nll, member_acc in zip(
          per_probs, test_nlls, test_accs):
        member_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, member_probs)
        member_nll.update_state(member_loss)
        member_acc.update_state(labels, member_probs)

      # Ensemble metrics use the mean of the member probabilities.
      probs = tf.reduce_mean(per_probs, axis=0)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      test_nll.update_state(negative_log_likelihood)
      test_accuracy.update_state(labels, probs)

    strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    summary_step = epoch + 1
    with summary_writer.as_default():
      for step in range(steps_per_epoch):
        train_step(train_iterator)

        current_step = epoch * steps_per_epoch + (step + 1)
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        progress_template = (
            '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
            'ETA: {:.0f} min. Time elapsed: {:.0f} min')
        message = progress_template.format(
            current_step / max_steps,
            epoch + 1,
            FLAGS.train_epochs,
            steps_per_sec,
            eta_seconds / 60,
            time_elapsed / 60)
        if step % 20 == 0:
          logging.info(message)

      train_summaries = (
          ('train/loss', train_loss),
          ('train/negative_log_likelihood', train_nll),
          ('train/accuracy', train_accuracy),
      )
      for tag, metric in train_summaries:
        tf.summary.scalar(tag, metric.result(), step=summary_step)
      logging.info('Train Loss: %s, Accuracy: %s%%',
                   round(float(train_loss.result()), 4),
                   round(float(train_accuracy.result() * 100), 2))
      for _, metric in train_summaries:
        metric.reset_states()

      test_iterator = iter(test_dataset)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_step(test_iterator)
      tf.summary.scalar('test/negative_log_likelihood',
                        test_nll.result(),
                        step=summary_step)
      tf.summary.scalar('test/accuracy',
                        test_accuracy.result(),
                        step=summary_step)
      logging.info('Test NLL: %s, Accuracy: %s%%',
                   round(float(test_nll.result()), 4),
                   round(float(test_accuracy.result() * 100), 2))
      test_nll.reset_states()
      test_accuracy.reset_states()

      for i, (member_nll, member_acc) in enumerate(zip(test_nlls, test_accs)):
        tf.summary.scalar('test/ensemble_nll_member{}'.format(i),
                          member_nll.result(),
                          step=summary_step)
        tf.summary.scalar('test/ensemble_accuracy_member{}'.format(i),
                          member_acc.result(),
                          step=summary_step)
        logging.info('Member %d Test loss: %s, accuracy: %s%%', i,
                     round(float(member_nll.result()), 4),
                     round(float(member_acc.result() * 100), 2))
        member_nll.reset_states()
        member_acc.reset_states()

    if (epoch + 1) % 20 == 0:
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
def main(unused_argv):
  """Trains and evaluates a Keras ResNet-50 on ImageNet with a TPUStrategy.

  Connects to the TPU, builds the distributed train/eval datasets, creates
  the model, SGD optimizer and metrics inside the strategy scope, restores
  the latest checkpoint from the model directory if there is one, and then
  runs `FLAGS.num_epochs` epochs. Each epoch consists of `steps_per_epoch`
  training steps followed by `steps_per_eval` evaluation steps; summaries are
  written and a checkpoint is saved after every epoch.

  Note on reported losses: the loss that is differentiated is divided by the
  replica count because the strategy sum-reduces gradients, but the loss fed
  to the `Mean` metrics is the unscaled per-replica loss. Keras metrics
  already aggregate across replicas, so feeding the scaled value would
  under-report the loss by a factor of `strategy.num_replicas_in_sync`.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.enable_v2_behavior()
  num_workers = 1
  job_name = 'worker'
  primary_cpu_task = '/job:%s' % job_name

  # With num_workers fixed to 1 above this is always False; the pod branch
  # below is kept for when the worker count is raised.
  is_tpu_pod = num_workers > 1
  model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
  batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
  steps_per_epoch = FLAGS.steps_per_epoch or (int(
      APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
  steps_per_eval = int(1.0 * math.ceil(IMAGENET_VALIDATION_IMAGES / batch_size))

  logging.info('Saving checkpoints at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, job_name=job_name)
  tf.config.experimental_connect_to_host(resolver.master())  # pylint: disable=line-too-long
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.experimental.TPUStrategy(resolver)

  with tf.device(primary_cpu_task):
    # TODO(b/130307853): In TPU Pod, we have to use
    # `strategy.experimental_distribute_datasets_from_function` instead of
    # `strategy.experimental_distribute_dataset` because dataset cannot be
    # cloned in eager mode. And when using
    # `strategy.experimental_distribute_datasets_from_function`, we should use
    # per core batch size instead of global batch size, because no re-batch is
    # happening in this case.
    if is_tpu_pod:
      imagenet_train = imagenet_input.ImageNetInput(
          is_training=True,
          data_dir=FLAGS.data,
          batch_size=PER_CORE_BATCH_SIZE,
          use_bfloat16=_USE_BFLOAT16)
      imagenet_eval = imagenet_input.ImageNetInput(
          is_training=False,
          data_dir=FLAGS.data,
          batch_size=PER_CORE_BATCH_SIZE,
          use_bfloat16=_USE_BFLOAT16)
      train_dataset = strategy.experimental_distribute_datasets_from_function(
          imagenet_train.input_fn)
      test_dataset = strategy.experimental_distribute_datasets_from_function(
          imagenet_eval.input_fn)
    else:
      imagenet_train = imagenet_input.ImageNetInput(
          is_training=True,
          data_dir=FLAGS.data,
          batch_size=batch_size,
          use_bfloat16=_USE_BFLOAT16)
      imagenet_eval = imagenet_input.ImageNetInput(
          is_training=False,
          data_dir=FLAGS.data,
          batch_size=batch_size,
          use_bfloat16=_USE_BFLOAT16)
      train_dataset = strategy.experimental_distribute_dataset(
          imagenet_train.input_fn())
      test_dataset = strategy.experimental_distribute_dataset(
          imagenet_eval.input_fn())

    with strategy.scope():
      logging.info('Building Keras ResNet-50 model')
      model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=_BASE_LEARNING_RATE, momentum=0.9, nesterov=True)
      training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
      training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
          'training_accuracy', dtype=tf.float32)
      test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
      test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
          'test_accuracy', dtype=tf.float32)
      logging.info('Finished building Keras ResNet-50 model')

      checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
      latest_checkpoint = tf.train.latest_checkpoint(model_dir)
      initial_epoch = 0
      if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        logging.info('Loaded checkpoint %s', latest_checkpoint)
        # Resume from the epoch implied by the restored optimizer step count.
        initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    # Create summary writers
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries/train'))
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries/test'))

    @tf.function
    def train_step(iterator):
      """Training StepFn."""

      def step_fn(inputs):
        """Per-Replica StepFn."""
        images, labels = inputs
        with tf.GradientTape() as tape:
          logits = model(images, training=True)

          # Loss calculations.
          #
          # Part 1: Prediction loss.
          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits)
          loss1 = tf.reduce_mean(prediction_loss)
          # Part 2: Model weights regularization
          loss2 = tf.reduce_sum(model.losses)

          loss = loss1 + loss2
          # Scale the loss given the TPUStrategy will reduce sum all gradients.
          scaled_loss = loss / strategy.num_replicas_in_sync

        grads = tape.gradient(scaled_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Track the unscaled loss: the metric already averages over replicas.
        training_loss.update_state(loss)
        training_accuracy.update_state(labels, logits)

      strategy.experimental_run_v2(step_fn, args=(next(iterator),))

    @tf.function
    def test_step(iterator):
      """Evaluation StepFn."""

      def step_fn(inputs):
        """Per-Replica StepFn."""
        images, labels = inputs
        logits = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        # No division by the replica count here: nothing is sum-reduced in
        # evaluation and the metric averages across replicas on its own.
        loss = tf.reduce_mean(loss)
        test_loss.update_state(loss)
        test_accuracy.update_state(labels, logits)

      strategy.experimental_run_v2(step_fn, args=(next(iterator),))

    train_iterator = iter(train_dataset)
    for epoch in range(initial_epoch, FLAGS.num_epochs):
      logging.info('Starting to run epoch: %s', epoch)
      with train_summary_writer.as_default():
        for step in range(steps_per_epoch):
          # The schedule is driven by a fractional epoch index, 1-based.
          learning_rate = compute_learning_rate(
              epoch + 1 + (float(step) / steps_per_epoch))
          optimizer.lr = learning_rate
          if step % 20 == 0:
            logging.info('Learning rate at step %s in epoch %s is %s',
                         step, epoch, optimizer.lr.numpy())
          train_step(train_iterator)
        tf.summary.scalar('loss', training_loss.result(),
                          step=optimizer.iterations)
        tf.summary.scalar('accuracy', training_accuracy.result(),
                          step=optimizer.iterations)
        # Convert to Python floats before rounding so that plain numbers are
        # logged instead of relying on round() support for tensors.
        logging.info('Training loss: %s, accuracy: %s%%',
                     round(float(training_loss.result()), 4),
                     round(float(training_accuracy.result() * 100), 2))
        training_loss.reset_states()
        training_accuracy.reset_states()

      with test_summary_writer.as_default():
        test_iterator = iter(test_dataset)
        for step in range(steps_per_eval):
          if step % 20 == 0:
            logging.info('Starting to run eval step %s of epoch: %s',
                         step, epoch)
          test_step(test_iterator)
        tf.summary.scalar('loss', test_loss.result(),
                          step=optimizer.iterations)
        tf.summary.scalar('accuracy', test_accuracy.result(),
                          step=optimizer.iterations)
        logging.info('Test loss: %s, accuracy: %s%%',
                     round(float(test_loss.result()), 4),
                     round(float(test_accuracy.result() * 100), 2))
        test_loss.reset_states()
        test_accuracy.reset_states()

      checkpoint_name = checkpoint.save(
          os.path.join(model_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
def main(_):
  """Runs imitation learning (ValueDICE, DAC or BC) from expert demonstrations.

  Seeds the RNGs, loads and subsamples expert trajectories from
  `FLAGS.expert_dir`, optionally normalizes states, builds the training and
  evaluation environments, fills a replay buffer with the expert transitions
  and then interacts with the environment for `FLAGS.max_timesteps` steps.
  The policy is evaluated every `FLAGS.eval_interval` steps; evaluation
  returns are saved with `np.save` under `<save_dir>/logs` and summaries are
  written under `<save_dir>/tb`. Which learner is built and updated depends on
  the substrings 'dac', 'value_dice' and 'bc' found in `FLAGS.algo`.

  Args:
    _: Unused positional command-line arguments.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  filename = os.path.join(FLAGS.expert_dir, FLAGS.env_name + '.npz')
  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.load_expert_data(filename)

  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.subsample_trajectories(expert_states,
                                                     expert_actions,
                                                     expert_next_states,
                                                     expert_dones,
                                                     FLAGS.num_trajectories)
  print('# of demonstraions: {}'.format(expert_states.shape[0]))

  if FLAGS.normalize_states:
    # The same shift/scale is handed to the environments below.
    shift = -np.mean(expert_states, 0)
    scale = 1.0 / (np.std(expert_states, 0) + 1e-3)
    expert_states = (expert_states + shift) * scale
    expert_next_states = (expert_next_states + shift) * scale
  else:
    shift = None
    scale = None

  env = wrappers.create_il_env(FLAGS.env_name, FLAGS.seed, shift, scale)

  eval_env = wrappers.create_il_env(FLAGS.env_name, FLAGS.seed + 1, shift,
                                    scale)

  # Walk down the wrapper chain; if an action-normalizing wrapper is found,
  # map the expert actions through its reverse_action and stop.
  unwrap_env = env

  while hasattr(unwrap_env, 'env'):
    if isinstance(unwrap_env, wrappers.NormalizeBoxActionWrapper):
      expert_actions = unwrap_env.reverse_action(expert_actions)
      break
    unwrap_env = unwrap_env.env

  (expert_states, expert_actions, expert_next_states,
   expert_dones) = data_utils.add_absorbing_states(expert_states,
                                                   expert_actions,
                                                   expert_next_states,
                                                   expert_dones, env)

  spec = (
      tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                             'observation'),
      tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32, 'action'),
      tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                             'next_observation'),
      tensor_spec.TensorSpec([1], tf.float32, 'reward'),
      tensor_spec.TensorSpec([1], tf.float32, 'mask'),
  )

  # We need to store at most twice more transition due to
  # an extra absorbing to itself transition.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_timesteps * 2)

  for i in range(expert_states.shape[0]):
    # Overwrite rewards for safety. We still have to add them to the replay
    # buffer to maintain the same interface. Also always use a zero mask
    # since we need to always bootstrap for imitation learning.
    add_samples_to_replay_buffer(replay_buffer, expert_states[i],
                                 expert_actions[i], expert_next_states[i])

  replay_buffer_iter = iter(
      replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size))

  policy_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_timesteps * 2)

  policy_replay_buffer_iter = iter(
      policy_replay_buffer.as_dataset(
          sample_batch_size=FLAGS.sample_batch_size))

  expert_states = tf.Variable(expert_states, dtype=tf.float32)
  expert_actions = tf.Variable(expert_actions, dtype=tf.float32)
  expert_next_states = tf.Variable(expert_next_states, dtype=tf.float32)
  expert_dones = tf.Variable(expert_dones, dtype=tf.float32)

  expert_dataset = tf.data.Dataset.from_tensor_slices(
      (expert_states, expert_actions, expert_next_states))
  expert_dataset = expert_dataset.repeat().shuffle(
      expert_states.shape[0]).batch(
          FLAGS.sample_batch_size, drop_remainder=True)

  expert_dataset_iter = iter(expert_dataset)

  # Summary and log paths are keyed by a sorted "k=v" hyper-parameter string.
  hparam_str_dict = dict(
      seed=FLAGS.seed, algo=FLAGS.algo, env_name=FLAGS.env_name)
  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  log_dir = os.path.join(FLAGS.save_dir, 'logs')
  log_filename = os.path.join(log_dir, hparam_str)
  # exist_ok avoids the check-then-create race when several runs (e.g.
  # different seeds) start at the same time with the same save_dir.
  os.makedirs(log_dir, exist_ok=True)

  if 'dac' in FLAGS.algo:
    imitator = gail.RatioGANGP(env.observation_space.shape[0],
                               env.action_space.shape[0], FLAGS.log_interval)
  elif 'value_dice' in FLAGS.algo:
    imitator = value_dice.ValueDICE(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        nu_lr=FLAGS.nu_lr,
        actor_lr=FLAGS.actor_lr,
        alpha_init=FLAGS.sac_alpha,
        hidden_size=FLAGS.hidden_size,
        log_interval=FLAGS.log_interval)

  def get_imitation_learning_rewards(states, actions, _):
    """Returns the imitator's log occupancy ratio as the reward signal."""
    return imitator.get_log_occupancy_ratio(states, actions)

  if 'value_dice' in FLAGS.algo:
    # For ValueDICE the imitator object is also the one whose actor is used.
    sac = imitator
  else:
    sac = twin_sac.SAC(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        FLAGS.log_interval,
        actor_lr=FLAGS.actor_lr,
        critic_lr=FLAGS.critic_lr,
        learn_alpha=FLAGS.learn_alpha,
        alpha_init=FLAGS.sac_alpha,
        rewards_fn=get_imitation_learning_rewards)

  episode_return = 0
  episode_timesteps = 0
  # Starting with done=True forces an env.reset() on the first iteration.
  done = True

  total_timesteps = 0
  previous_time = time.time()

  eval_returns = []
  with tqdm(total=FLAGS.max_timesteps, desc='') as pbar:
    while total_timesteps < FLAGS.max_timesteps:
      _update_pbar_msg(pbar, total_timesteps)

      if total_timesteps % FLAGS.eval_interval == 0:
        logging.info('Performing policy eval.')
        average_returns, evaluation_timesteps = evaluate(sac.actor, eval_env)

        eval_returns.append(average_returns)
        # The full history of evaluation returns is rewritten every time.
        np.save(log_filename, np.array(eval_returns))

        with summary_writer.as_default():
          tf.summary.scalar(
              'eval gym/average returns', average_returns, step=total_timesteps)
          tf.summary.scalar(
              'eval gym/average episode length',
              evaluation_timesteps,
              step=total_timesteps)
        logging.info('Eval: ave returns=%f, ave episode length=%f',
                     average_returns, evaluation_timesteps)

      if done:
        if episode_timesteps > 0:
          current_time = time.time()
          with summary_writer.as_default():
            tf.summary.scalar(
                'train gym/returns', episode_return, step=total_timesteps)
            tf.summary.scalar(
                'train gym/FPS',
                episode_timesteps / (current_time - previous_time),
                step=total_timesteps)

        obs = env.reset()
        episode_return = 0
        episode_timesteps = 0
        previous_time = time.time()

      if total_timesteps < FLAGS.num_random_actions:
        action = env.action_space.sample()
      else:
        if 'dac' in FLAGS.algo:
          _, sampled_action, _ = sac.actor(np.array([obs]))
          action = sampled_action[0].numpy()
        else:
          # Mean action plus Gaussian exploration noise, clipped to [-1, 1].
          mean_action, _, _ = sac.actor(np.array([obs]))
          action = mean_action[0].numpy()
          action = (action + np.random.normal(
              0, 0.1, size=action.shape)).clip(-1, 1)

      next_obs, reward, done, _ = env.step(action)

      # done caused by episode truncation.
      truncated_done = done and episode_timesteps + 1 == env._max_episode_steps  # pylint: disable=protected-access

      if done and not truncated_done:
        next_obs = env.get_absorbing_state()

      # Overwrite rewards for safety. We still have to add them to the replay
      # buffer to maintain the same interface. Also always use a zero mask
      # since we need to always bootstrap for imitation learning.
      add_samples_to_replay_buffer(replay_buffer, obs, action, next_obs)
      add_samples_to_replay_buffer(policy_replay_buffer, obs, action, next_obs)
      if done and not truncated_done:
        # Add several absorbing-state to absorbing-state transitions.
        for abs_i in range(FLAGS.absorbing_per_episode):
          if abs_i + episode_timesteps < env._max_episode_steps:  # pylint: disable=protected-access
            obs = env.get_absorbing_state()
            action = env.action_space.sample()
            next_obs = env.get_absorbing_state()

            add_samples_to_replay_buffer(replay_buffer, obs, action, next_obs)
            add_samples_to_replay_buffer(policy_replay_buffer, obs, action,
                                         next_obs)

      episode_return += reward
      episode_timesteps += 1
      total_timesteps += 1
      pbar.update(1)

      obs = next_obs

      if total_timesteps >= FLAGS.start_training_timesteps:
        with summary_writer.as_default():
          for _ in range(FLAGS.updates_per_step):
            if 'dac' in FLAGS.algo:
              imitator.update(expert_dataset_iter, policy_replay_buffer_iter)
            elif 'value_dice' in FLAGS.algo:
              imitator.update(
                  expert_dataset_iter,
                  policy_replay_buffer_iter,
                  FLAGS.discount,
                  replay_regularization=FLAGS.replay_regularization)

            if 'bc' in FLAGS.algo:
              sac.train_bc(expert_dataset_iter)
            elif 'dac' in FLAGS.algo:
              sac.train(
                  replay_buffer_iter,
                  discount=FLAGS.discount,
                  tau=FLAGS.tau,
                  target_entropy=-env.action_space.shape[0],
                  actor_update_freq=FLAGS.actor_update_freq)
def main(argv):
  """Runs the MLPerf Transformer benchmark: setup, precompile, train, predict.

  The function logs MLPerf metadata, builds the model and an Adam optimizer,
  wraps the train/eval/predict steps in `jax.pmap` (and `sharded_jit` when
  `num_partitions > 1`), optionally precompiles them on fake data, and then
  loops over `FLAGS.num_epochs` epochs of training followed by translation
  of the eval set and BLEU summarization.

  Args:
    argv: Command-line arguments left over after flag parsing; anything
      beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  global BLEU_THRESHOLD_REACHED
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # MLPerf logging: 'init_start' is closed by 'init_stop' right before the
  # epoch loop below.
  init_mllogger()
  mllogger.event('cache_clear')
  mllogger.start('init_start')
  mllogger.event('submission_org', 'Google')
  mllogger.event('submission_platform', 'TPUv3-{}'.format(jax.device_count()))
  mllogger.event('submission_division', 'closed')
  mllogger.event('submission_status', 'research')
  mllogger.event('submission_benchmark', 'transformer')
  mllogger.event('train_samples', input_pipeline.N_TRAIN)
  mllogger.event('eval_samples', input_pipeline.N_EVAL)

  tf.enable_v2_behavior()

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if FLAGS.hardware_rng:
    models.set_hardware_bernoulli()

  num_partitions = FLAGS.num_partitions
  batch_size = FLAGS.batch_size
  if batch_size is None:
    # Default global batch: 16 per replica group, capped at 2048.
    batch_size = min(16 * jax.device_count() // num_partitions, 2048)
  mllogger.event('global_batch_size', batch_size)

  num_eval_steps = FLAGS.num_eval_steps
  max_target_length = FLAGS.max_target_length
  max_eval_target_length = FLAGS.max_eval_target_length
  max_length = max(max_target_length, max_eval_target_length)
  mllogger.event('max_sequence_length', max_length,
                 metadata={'method': 'discard'})
  if FLAGS.random_seed is not None:
    seed = FLAGS.random_seed
  else:
    # Only host 0 contributes a non-zero value, so the per-host sum hands the
    # same time-derived seed to every host.
    seed = np.int32(time.time() if jax.host_id() == 0 else 0)
    seed = per_host_sum_pmap(seed)
  mllogger.event('seed', int(seed))
  steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size))
  logging.info('steps per epoch: %d', steps_per_epoch)

  # A "replica" is a group of `num_partitions` local devices.
  num_replicas = jax.local_device_count() // num_partitions
  device_train_input_shape = (batch_size // (num_replicas * jax.host_count()),
                              max_target_length)
  # This is per-host; in principle 64/replica or more should fit
  eval_batch_size = min(
      32 * num_replicas,
      int(
          math.ceil(input_pipeline.N_EVAL /
                    (num_replicas * jax.host_count()))) * num_replicas)
  logging.info('eval batch size: %d', eval_batch_size)
  # Number of prediction batches each host pulls from its eval iterator.
  pred_batches = int(
      math.ceil(input_pipeline.N_EVAL / (jax.host_count() * eval_batch_size)))
  logging.info('pred batches: %d', pred_batches)
  broadcast = functools.partial(_broadcast,
                                num_replicas=num_replicas,
                                num_partitions=num_partitions)

  # Only host 0 owns summary writers; other hosts pass None downstream.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None
  # Write summaries in background thread to avoid blocking on device sync
  summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  def maybe_start_xprof(seconds):
    """Starts an Xprof session on host 0 and ends it after `seconds`."""
    if jax.host_id() == 0 and FLAGS.xprof:
      xprof = xprof_session.XprofSession()
      xprof.start_session('REDACTED', True, 2)

      def sleep_and_end_xprof():
        """Sleeps for the profiling window, then logs the Xprof URL."""
        time.sleep(seconds)
        logging.info(
            'Xprof URL: %s',
            xprof.end_session_and_get_url(
                tag=
                'flax transformer, {} devices, {}-way, batch {} per replica'
                .format(jax.device_count(), num_partitions,
                        device_train_input_shape[0])))

      # Run the wait in its own thread so the caller is not blocked.
      thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

  # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset:
  #   Shared 32K subword tokenization
  #   256-length packed training examples from WMT17
  #   97-length unpacked evaluation examples from WMT14
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  # The same subword encoder is used for both source and target text.
  encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path)
  input_encoder = encoder
  target_encoder = encoder
  vocab_size = input_encoder.vocab_size
  output_vocab_size = target_encoder.vocab_size

  input_shape = (batch_size, max_target_length)
  target_shape = (batch_size, max_target_length)

  transformer_kwargs = {
      'vocab_size': vocab_size,
      'output_vocab_size': output_vocab_size,
      'emb_dim': 1024,
      'num_heads': 16,
      'num_layers': 6,
      'qkv_dim': 1024,
      'mlp_dim': 4096,
      'max_len': max_length,
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
      'num_partitions': num_partitions,
  }

  rng = random.PRNGKey(seed)
  rng, init_rng = random.split(rng)
  model, cache_def = create_model(init_rng, tuple(input_shape),
                                  tuple(target_shape), transformer_kwargs)
  mllogger.event('opt_name', 'adam')
  # Optimizer hyperparameters keyed on the global batch size.
  if batch_size < 1024:
    learning_rate = 4.0  # 0.0625
    warmup_steps = 1000
    beta1 = 0.9
    beta2 = 0.98
  # NOTE(review): this is `if`, not `elif`, so every value assigned in the
  # `batch_size < 1024` branch above is overwritten here and that branch is
  # effectively dead. Confirm whether `elif` was intended.
  if batch_size < 2048:
    learning_rate = 2.0
    warmup_steps = 500  # ??
    beta1 = 0.9  # ??
    beta2 = 0.98  # ??
  else:
    learning_rate = 3.3092157691415953
    warmup_steps = 664
    beta1 = 0.9086575725261137
    beta2 = 0.9198719118104947
  epsilon = 1e-9
  # An explicit flag overrides the batch-size-derived learning rate.
  if FLAGS.learning_rate is not None:
    learning_rate = FLAGS.learning_rate
  mllogger.event('opt_adam_beta_1', beta1)
  mllogger.event('opt_adam_beta_2', beta2)
  mllogger.event('opt_adam_epsilon', epsilon)
  optimizer_def = optim.Adam(learning_rate,
                             beta1=beta1,
                             beta2=beta2,
                             eps=epsilon,
                             weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(model)
  del model  # don't keep a copy of the initial model

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  partition_rules = [
      (('encoder', 'posembed_input'), partitions.empty_dict),
      (('decoder', 'posembed_targets'), partitions.empty_dict),
      (('embedding', ), partitions.spec(num_partitions, 1)),
      ((r'LayerNorm_\d+', '(bias|scale)'), None),
      ((r'encoder(decoder)?_norm', '(bias|scale)'), None),
      ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'),
       partitions.spec(1, num_partitions, 1)),
      ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'),
       partitions.spec(num_partitions, 1, 1)),
      ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None),
      ((r'MlpBlock_\d+', 'Dense_0', 'kernel'),
       partitions.spec(1, num_partitions)),
      ((r'MlpBlock_\d+', 'Dense_1', 'kernel'),
       partitions.spec(num_partitions, 1)),
      (('state', 'step'), None),
  ]
  optimizer_partitions = optimizer.restore_state(
      partitions.set_partitions(partition_rules, optimizer.state_dict()))

  optimizer = broadcast(optimizer)
  empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0})

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate,
      warmup_steps=warmup_steps,
      hidden_size=transformer_kwargs['qkv_dim'])

  # Per-step training function used when infeed is disabled.
  p_train_step = jax.pmap(functools.partial(
      train_step, learning_rate_fn=learning_rate_fn),
                          axis_name='batch',
                          in_axes=(None, 0, 0, 0))
  if num_partitions > 1:
    sharded_predict_step = sharded_jit(
        predict_step,
        in_parts=(None, optimizer_partitions.target, None),
        out_parts=None)
  else:
    sharded_predict_step = predict_step
  # `p_eval_step` only exists when extra eval metrics are requested; every
  # use below is guarded by the same flag.
  if FLAGS.extra_eval_metrics:
    p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0))
  p_pred_step = jax.pmap(sharded_predict_step,
                         axis_name='batch',
                         in_axes=(0, None, None))
  p_allreduce_metrics = jax.pmap(functools.partial(lax.psum,
                                                   axis_name='batch'),
                                 axis_name='batch')

  def device_train_loop_cond(args):
    """Loop condition: keep stepping while `step` is still inside `epoch`."""
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    """Loop body: infeeds one batch and runs one `train_step` on it."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # One infeed array per entry of `train_keys`, all with the same
    # per-device shape and int32 dtype.
    input_data, token = lax.infeed(token,
                                   shape=tuple([
                                       jax.ShapedArray(
                                           device_train_input_shape, jnp.int32)
                                       for _ in train_keys
                                   ]))
    batch = {k: v for k, v in zip(train_keys, input_data)}
    optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                  batch,
                                                  metrics,
                                                  learning_rate_fn,
                                                  dropout_rng=dropout_rngs)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    """Runs `lax.while_loop` over training steps until the epoch is done."""
    token = lax.create_token(step)
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(device_train_loop,
                                    in_parts=(optimizer_partitions, None, None,
                                              None, None),
                                    out_parts=(optimizer_partitions, None,
                                               None, None))
  # Whole-epoch training function used when infeed is enabled.
  p_train_epoch = jax.pmap(device_train_loop,
                           axis_name='batch',
                           in_axes=(None, 0, 0, None, None))

  p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch')
  if num_partitions > 1:
    p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train,
                                            in_parts=None,
                                            out_parts=None,
                                            num_partitions=num_partitions)
  p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train,
                                       axis_name='batch')

  # Precompile all needed computations with fake data so as not to include
  # compilation time in MLPerf metrics.
  if FLAGS.precompile:
    logging.info('precompiling step/epoch functions')
    if FLAGS.infeed:
      # the device training loop condition will immediately be false, but
      # the optimizer tree will be resharded here
      optimizer, *_ = p_train_epoch(unbroadcast(optimizer),
                                    random.split(rng, num_replicas),
                                    empty_metrics,
                                    jnp.array(0, dtype=jnp.int32), 1)
    else:
      metrics = empty_metrics
      train_input_shape = (num_replicas, batch_size // num_replicas,
                           input_pipeline.MAX_TRAIN_LEN)
      fake_batch = {
          k: jnp.ones(train_input_shape, jnp.int32)
          for k in train_keys
      }
      # Result is discarded; the call is made only for compilation.
      p_train_step(unbroadcast(optimizer),
                   fake_batch,
                   metrics,
                   dropout_rng=random.split(rng, num_replicas))
    eval_input_shape = (num_replicas, eval_batch_size // num_replicas,
                        input_pipeline.MAX_EVAL_LEN)
    fake_eval_batch = {
        'inputs': jnp.ones(eval_input_shape, jnp.int32),
        'targets': jnp.ones(eval_input_shape, jnp.int32),
    }
    if FLAGS.extra_eval_metrics:
      p_eval_step(unbroadcast(optimizer.target), fake_eval_batch)
    fake_cache = cache_def.initialize_cache(
        (eval_input_shape[1], FLAGS.max_predict_length))
    maybe_start_xprof(20)
    p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target),
                fake_cache)
    # Matches the 20-second window passed to maybe_start_xprof above.
    time.sleep(20)
    sync_devices()
    # Fake inputs for per_host_sum_pmap: two shape-(4,) and two scalar int32
    # arrays. NOTE(review): presumably mirrors the tuple used by the real
    # BLEU reduction - confirm against write_predict_summary.
    fake_bleu_1 = np.zeros((4, ), dtype=np.int32)
    fake_bleu_2 = np.zeros((), dtype=np.int32)
    per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2))
    sync_devices()
    p_allreduce_metrics_train(empty_metrics)
    sync_devices()
    logging.info('finished precompiling step/epoch functions')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, num_replicas)

  # Record time-0 metrics for proper tensorboard plot x-axis scaling.
  if jax.host_id() == 0:
    if FLAGS.compute_train_metrics:
      train_summary_writer.scalar('loss', 9.999, 0)
      train_summary_writer.scalar('accuracy', 0.0, 0)
      train_summary_writer.flush()
    eval_summary_writer.scalar('bleu', 0.0, 0)
    eval_summary_writer.flush()

  train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size //
                                            jax.host_count(),
                                            train=True)
  eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size,
                                           train=False)
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)
  local_devices = jax.local_devices()
  maybe_start_xprof(max(30, 60 / (jax.device_count() / 2048)))
  # `host_step` counts batches fed by this host; `device_step` is the
  # broadcast step counter passed to the device-side training loop.
  host_step, device_step = 0, broadcast(0)
  gc.disable()
  mllogger.end('init_stop')
  if jax.host_id() == 0:
    mllogger.start('run_start')
  for epoch in range(FLAGS.num_epochs):
    if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED:
      mllogger.start('block_start',
                     metadata={
                         'first_epoch_num': epoch + 1,
                         'epoch_count': 1
                     })
    metrics = empty_metrics
    if FLAGS.infeed:
      # Launch the device-side epoch loop first; the host loop below then
      # supplies its batches through the infeed.
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          unbroadcast(optimizer), dropout_rngs, metrics,
          unbroadcast(device_step), epoch)
    while int(host_step // steps_per_epoch) == epoch:
      # pylint: disable=protected-access
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))
      # Shard data to devices and do a training step.
      batch = jax.tree_map(
          lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch)
      if FLAGS.infeed:
        for i, device in enumerate(local_devices):
          # All devices of one partition group receive the same replica shard.
          replica_id = i // num_partitions
          input_tuple = tuple(
              [batch[k][replica_id] for k in train_keys])
          assert input_tuple[0].shape == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (input_tuple[0].shape, device_train_input_shape))
          assert input_tuple[0].dtype == jnp.int32, (
              'infeed dtype error %s != %s' %
              (input_tuple[0].dtype, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      else:
        optimizer, metrics, dropout_rngs = p_train_step(
            unbroadcast(optimizer),
            batch,
            metrics,
            dropout_rng=dropout_rngs)
      host_step += 1

    if FLAGS.compute_train_metrics:
      metrics = p_allreduce_metrics_train(metrics)
      # Schedule training metric handling.
      summary_thread.submit(
          functools.partial(write_train_summary, metrics,
                            train_summary_writer, host_step))

    # Optional, extra evaluation metrics.
    if FLAGS.extra_eval_metrics:
      eval_metrics = []
      # A fresh iterator is created here; the prediction loop below then
      # continues reading from this same iterator.
      eval_iter = iter(eval_ds)
      for _, eval_batch in zip(range(num_eval_steps), eval_iter):
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(unbroadcast(optimizer.target), eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = p_allreduce_metrics(eval_metrics)
      # Schedule metric summarization/logging.
      summary_thread.submit(
          functools.partial(write_eval_summary, eval_metrics,
                            eval_summary_writer, host_step))

    # Translation and BLEU Score.
    all_predicted, all_targets, all_bs = [], [], []
    for i in range(pred_batches):
      # pylint: disable=protected-access
      pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
      logging.info('Predicting on input of shape %s.',
                   str(pred_batch['inputs'].shape))
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size != eval_batch_size:
        logging.info('Translation: uneven batch size %d.',
                     cur_pred_batch_size)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, eval_batch_size), pred_batch)
      pred_batch = jax.tree_map(
          lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache = cache_def.initialize_cache(
          (per_device_batchsize, FLAGS.max_predict_length))
      all_predicted.append(
          p_pred_step(pred_batch['inputs'], unbroadcast(optimizer.target),
                      cache))
      all_targets.append(pred_batch['targets'])
      # Record the unpadded batch size alongside each prediction batch.
      all_bs.append(cur_pred_batch_size)

    # Schedule BLEU calculation and summarization/logging.
    # We use the ICI as part of BLEU score computation, so we call this from the
    # main thread so the BLEU pmap runs before the next train epoch pmap
    write_predict_summary(all_predicted, all_targets, all_bs, target_encoder,
                          eval_summary_writer, epoch, host_step,
                          summary_thread)

  # Wait until computations are done before exiting
  sync_devices()
  if jax.host_id() == 0:
    summary_thread.shutdown()
    # If the BLEU threshold flag was never set, the run is logged as aborted.
    if not BLEU_THRESHOLD_REACHED:
      mllogger.end('run_stop', metadata={'status': 'aborted'})
def main(_):
  """Builds the dataset source and model from flags, then launches training.

  The output directory is `FLAGS.output_dir` extended with the learning
  rate, weight decay, SAM rho and run seed so that each grid-search
  configuration writes to its own directory.

  Args:
    _: Unused positional command-line arguments.

  Raises:
    ValueError: If `FLAGS.dataset` is not one of the recognized datasets.
  """
  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Performance gains on TPU by switching to hardware bernoulli.
  def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
    """Bernoulli sampler built on `jax.lax.rng_uniform`."""
    lax_key = jax.lax.tie_in(rng_key, 0.0)
    return jax.lax.rng_uniform(lax_key, 1.0, shape) < p

  def set_hardware_bernoulli():
    """Replaces `jax.random.bernoulli` with the hardware-backed version."""
    jax.random.bernoulli = hardware_bernoulli

  set_hardware_bernoulli()

  # As we gridsearch the weight decay and the learning rate, we add them to the
  # output directory path so that each model has its own directory to save the
  # results in. We also add the `run_seed` which is "gridsearched" on to
  # replicate an experiment several times.
  output_dir_suffix = os.path.join(
      'lr_' + str(FLAGS.learning_rate),
      'wd_' + str(FLAGS.weight_decay),
      'rho_' + str(FLAGS.sam_rho),
      'seed_' + str(FLAGS.run_seed))

  output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

  if not gfile.exists(output_dir):
    gfile.makedirs(output_dir)

  num_devices = jax.local_device_count() * jax.host_count()
  assert FLAGS.batch_size % num_devices == 0
  local_batch_size = FLAGS.batch_size // num_devices
  info = 'Total batch size: {} ({} x {} replicas)'.format(
      FLAGS.batch_size, local_batch_size, num_devices)
  logging.info(info)

  # Bind `image_size` up front. Only the CIFAR branches used to assign it, so
  # the SVHN path hit `image_size is None` on an unbound local below and
  # raised UnboundLocalError.
  image_size = None

  if FLAGS.dataset == 'cifar10':
    if FLAGS.from_pretrained_checkpoint:
      image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    else:
      image_size = None
    dataset_source = dataset_source_lib.Cifar10(
        FLAGS.batch_size // jax.host_count(),
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations,
        image_size=image_size)
  elif FLAGS.dataset == 'cifar100':
    if FLAGS.from_pretrained_checkpoint:
      image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    else:
      image_size = None
    dataset_source = dataset_source_lib.Cifar100(
        FLAGS.batch_size // jax.host_count(),
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations,
        image_size=image_size)
  # NOTE(review): unlike the CIFAR/ImageNet sources, the two sources below
  # receive the global FLAGS.batch_size rather than the per-host batch size -
  # confirm this is intended for multi-host runs.
  elif FLAGS.dataset == 'fashion_mnist':
    dataset_source = dataset_source_lib.FashionMnist(
        FLAGS.batch_size,
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset == 'svhn':
    dataset_source = dataset_source_lib.SVHN(
        FLAGS.batch_size,
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset == 'imagenet':
    imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    dataset_source = dataset_source_imagenet.Imagenet(
        FLAGS.batch_size // jax.host_count(), imagenet_image_size,
        FLAGS.image_level_augmentations)
  else:
    raise ValueError('Dataset not recognized.')

  # Derive the model input geometry and label count from the dataset name.
  if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset:
    # CIFAR keeps a checkpoint-derived image size when one was set above;
    # SVHN always uses 32.
    if image_size is None or 'svhn' in FLAGS.dataset:
      image_size = 32
    num_channels = 3
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
  elif FLAGS.dataset == 'fashion_mnist':
    image_size = 28  # For Fashion Mnist
    num_channels = 1
    num_classes = 10
  elif FLAGS.dataset == 'imagenet':
    image_size = imagenet_image_size
    num_channels = 3
    num_classes = 1000
  else:
    raise ValueError('Dataset not recognized.')

  # Try the ImageNet model loader first and fall back to the generic loader
  # when it does not recognize the model name.
  try:
    model, state = load_imagenet_model.get_model(FLAGS.model_name,
                                                 local_batch_size, image_size,
                                                 num_classes)
  except load_imagenet_model.ModelNameError:
    model, state = load_model.get_model(FLAGS.model_name,
                                        local_batch_size, image_size,
                                        num_classes, num_channels)

  # Learning rate will be overwritten by the lr schedule, we set it to zero.
  optimizer = flax_training.create_optimizer(model, 0.0)

  flax_training.train(optimizer, state, dataset_source, output_dir,
                      FLAGS.num_epochs)
def main(argv):
  """Trains and evaluates a rank-1 ResNet-32 under a distribution strategy.

  The function builds the input pipelines (clean test set plus optional
  corrupted test sets), the model, optimizer and metrics inside the strategy
  scope, restores the latest checkpoint if one exists, and then alternates
  one epoch of training with evaluation, summary writing and optional
  checkpointing.

  Args:
    argv: Unused positional command-line arguments.
  """
  del argv  # Unused arg.
  tf.enable_v2_behavior()
  tf.random.set_seed(FLAGS.seed)

  # The per-core batch is divided by the number of samples (and, in version 2,
  # by the ensemble size) because the inputs are tiled by those factors inside
  # the step functions.
  if FLAGS.version2:
    per_core_bs_train = FLAGS.per_core_batch_size // (
        FLAGS.ensemble_size * FLAGS.num_train_samples)
    per_core_bs_eval = FLAGS.per_core_batch_size // (
        FLAGS.ensemble_size * FLAGS.num_eval_samples)
  else:
    per_core_bs_train = FLAGS.per_core_batch_size // FLAGS.num_train_samples
    per_core_bs_eval = FLAGS.per_core_batch_size // FLAGS.num_eval_samples
  batch_size_train = per_core_bs_train * FLAGS.num_cores
  batch_size_eval = per_core_bs_eval * FLAGS.num_cores

  logging.info('Saving checkpoints at %s', FLAGS.output_dir)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  train_input_fn = utils.load_input_fn(split=tfds.Split.TRAIN,
                                       name=FLAGS.dataset,
                                       batch_size=per_core_bs_train,
                                       use_bfloat16=FLAGS.use_bfloat16,
                                       normalize=False)
  clean_test_input_fn = utils.load_input_fn(split=tfds.Split.TEST,
                                            name=FLAGS.dataset,
                                            batch_size=per_core_bs_eval,
                                            use_bfloat16=FLAGS.use_bfloat16,
                                            normalize=False)

  train_dataset = strategy.experimental_distribute_datasets_from_function(
      train_input_fn)
  test_datasets = {
      'clean':
          strategy.experimental_distribute_datasets_from_function(
              clean_test_input_fn),
  }
  # One extra test dataset per (corruption, intensity) pair, keyed
  # '<corruption>_<intensity>'.
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar10':
      load_c_input_fn = utils.load_cifar10_c_input_fn
    else:
      load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                          path=FLAGS.cifar100_c_path)
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for corruption in corruption_types:
      for intensity in range(1, max_intensity + 1):
        input_fn = load_c_input_fn(corruption_name=corruption,
                                   corruption_intensity=intensity,
                                   batch_size=per_core_bs_eval,
                                   use_bfloat16=FLAGS.use_bfloat16,
                                   normalize=False)
        test_datasets['{0}_{1}'.format(corruption, intensity)] = (
            strategy.experimental_distribute_datasets_from_function(
                input_fn))

  ds_info = tfds.builder(FLAGS.dataset).info
  train_dataset_size = ds_info.splits['train'].num_examples
  test_dataset_size = ds_info.splits['test'].num_examples
  num_classes = ds_info.features['label'].num_classes

  steps_per_epoch = train_dataset_size // batch_size_train
  steps_per_eval = test_dataset_size // batch_size_eval

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building Keras ResNet-32 model')
    model = resnet_cifar_model.rank1_resnet_v1(
        input_shape=ds_info.features['image'].shape,
        depth=32,
        num_classes=num_classes,
        width_multiplier=4,
        alpha_initializer=FLAGS.alpha_initializer,
        gamma_initializer=FLAGS.gamma_initializer,
        alpha_regularizer=FLAGS.alpha_regularizer,
        gamma_regularizer=FLAGS.gamma_regularizer,
        use_additive_perturbation=FLAGS.use_additive_perturbation,
        ensemble_size=FLAGS.ensemble_size,
        random_sign_init=FLAGS.random_sign_init,
        dropout_rate=FLAGS.dropout_rate)
    logging.info(model.summary())
    # The base learning rate is scaled linearly with the global batch size
    # (relative to 128); decay epochs are rescaled from a 200-epoch schedule.
    base_lr = FLAGS.base_learning_rate * batch_size_train / 128
    lr_decay_epochs = [(start_epoch * FLAGS.train_epochs) // 200
                       for start_epoch in FLAGS.lr_decay_epochs]
    lr_schedule = utils.LearningRateSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=0.9,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
        'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss':
            tf.keras.metrics.Mean(),
        'train/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
        'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece':
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        'test/loss':
            tf.keras.metrics.Mean(),
    }
    if FLAGS.corruptions_interval > 0:
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    # Per-member and diversity metrics only exist for a real ensemble.
    test_diversity = {}
    training_diversity = {}
    if FLAGS.ensemble_size > 1:
      for i in range(FLAGS.ensemble_size):
        metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
        metrics['test/accuracy_member_{}'.format(i)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
      test_diversity = {
          'test/disagreement': tf.keras.metrics.Mean(),
          'test/average_kl': tf.keras.metrics.Mean(),
          'test/cosine_similarity': tf.keras.metrics.Mean(),
      }
      training_diversity = {
          'train/disagreement': tf.keras.metrics.Mean(),
          'train/average_kl': tf.keras.metrics.Mean(),
          'train/cosine_similarity': tf.keras.metrics.Mean(),
      }

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      # Resume from the epoch implied by the restored optimizer step count.
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      if FLAGS.version2 and FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
        # Labels are tiled too unless the member predictions are later reduced
        # to a single one (member sampling or expected probabilities).
        if not (FLAGS.member_sampling or FLAGS.expected_probs):
          labels = tf.tile(labels, [FLAGS.ensemble_size])

      if FLAGS.num_train_samples > 1:
        images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        probs = tf.nn.softmax(logits)
        # Diversity evaluation.
        if FLAGS.version2 and FLAGS.ensemble_size > 1:
          per_probs = tf.reshape(
              probs,
              tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          diversity_results = ed.metrics.average_pairwise_diversity(
              per_probs, FLAGS.ensemble_size)

        # Average the predictions over the training samples.
        if FLAGS.num_train_samples > 1:
          probs = tf.reshape(
              probs,
              tf.concat([[FLAGS.num_train_samples, -1], probs.shape[1:]], 0))
          probs = tf.reduce_mean(probs, 0)

        if (FLAGS.member_sampling and FLAGS.version2 and
            FLAGS.ensemble_size > 1):
          # Select one ensemble member uniformly at random with a one-hot
          # matmul over the member axis.
          idx = tf.random.uniform([],
                                  maxval=FLAGS.ensemble_size,
                                  dtype=tf.int64)
          idx_one_hot = tf.expand_dims(
              tf.one_hot(idx, FLAGS.ensemble_size, dtype=probs.dtype), 0)
          probs_shape = probs.shape
          probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
          probs = tf.matmul(idx_one_hot, probs)
          probs = tf.reshape(probs, tf.concat([[-1], probs_shape[1:]], 0))
        elif (FLAGS.expected_probs and FLAGS.version2 and
              FLAGS.ensemble_size > 1):
          # Average the predictions over ensemble members.
          probs = tf.reshape(
              probs,
              tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          probs = tf.reduce_mean(probs, 0)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the slow weights and bias terms. This excludes BN
          # parameters and fast weight approximate posterior/prior parameters,
          # but pay caution to their naming scheme.
          if 'kernel' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        kl = sum(model.losses) / train_dataset_size
        # Linearly anneal the KL weight from ~0 to 1 over kl_annealing_steps.
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= FLAGS.kl_annealing_steps
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        loss = negative_log_likelihood + l2_loss + kl_loss
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)

      # Separate learning rate implementation.
      grad_list = []
      if FLAGS.fast_weight_lr_multiplier != 1.0:
        grads_and_vars = list(zip(grads, model.trainable_variables))
        for vec, var in grads_and_vars:
          # Apply different learning rate on the fast weight approximate
          # posterior/prior parameters. This excludes BN and slow weights,
          # but pay caution to the naming scheme.
          if ('batch_norm' not in var.name and 'kernel' not in var.name):
            grad_list.append((vec * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grad_list.append((vec, var))
        optimizer.apply_gradients(grad_list)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      if FLAGS.version2 and FLAGS.ensemble_size > 1:
        for k, v in diversity_results.items():
          training_diversity['train/' + k].update_state(v)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      if FLAGS.ensemble_size > 1:
        images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
      if FLAGS.num_eval_samples > 1:
        images = tf.tile(images, [FLAGS.num_eval_samples, 1, 1, 1])
      logits = model(images, training=False)
      probs = tf.nn.softmax(logits)

      # Average the predictions over the evaluation samples.
      if FLAGS.num_eval_samples > 1:
        probs = tf.reshape(
            probs,
            tf.concat([[FLAGS.num_eval_samples, -1], probs.shape[1:]], 0))
        probs = tf.reduce_mean(probs, 0)

      if FLAGS.ensemble_size > 1:
        per_probs = tf.split(probs,
                             num_or_size_splits=FLAGS.ensemble_size,
                             axis=0)
        # Diversity is only tracked on the clean test set.
        if dataset_name == 'clean':
          per_probs_tensor = tf.reshape(
              probs,
              tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
          diversity_results = ed.metrics.average_pairwise_diversity(
              per_probs_tensor, FLAGS.ensemble_size)
          for k, v in diversity_results.items():
            test_diversity['test/' + k].update_state(v)

        for i in range(FLAGS.ensemble_size):
          member_probs = per_probs[i]
          member_nll = tf.keras.losses.sparse_categorical_crossentropy(
              labels, member_probs)
          metrics['test/nll_member_{}'.format(i)].update_state(member_nll)
          metrics['test/accuracy_member_{}'.format(i)].update_state(
              labels, member_probs)

        # Ensemble prediction is the mean over member predictions.
        probs = tf.reduce_mean(per_probs, axis=0)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      filtered_variables = []
      for var in model.trainable_variables:
        if 'kernel' in var.name or 'bias' in var.name:
          filtered_variables.append(tf.reshape(var, (-1,)))

      kl = sum(model.losses) / test_dataset_size
      l2_loss = kl + FLAGS.l2 * 2 * tf.nn.l2_loss(
          tf.concat(filtered_variables, axis=0))
      loss = negative_log_likelihood + l2_loss
      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
        metrics['test/loss'].update_state(loss)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(
            dataset_name)].update_state(labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      current_step = epoch * steps_per_epoch + (step + 1)
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      work_unit.set_notes(message)
      if step % 20 == 0:
        logging.info(message)

    # Corrupted datasets are only evaluated every `corruptions_interval`
    # epochs; the clean test set is evaluated every epoch.
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if (FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0):
      corrupt_results = utils.aggregate_corrupt_metrics(
          corrupt_metrics, corruption_types, max_intensity)
      # NOTE(review): `corrupt_metrics` are never reset, so later corruption
      # evaluations average over all earlier ones as well - confirm whether a
      # reset is wanted here.

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    # Per-member metrics are only registered when ensemble_size > 1; without
    # this guard a single-member run raised KeyError on 'test/nll_member_0'.
    if FLAGS.ensemble_size > 1:
      for i in range(FLAGS.ensemble_size):
        logging.info(
            'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
            metrics['test/nll_member_{}'.format(i)].result(),
            metrics['test/accuracy_member_{}'.format(i)].result() * 100)
    # Materialize the chain: it is iterated twice (results, then reset) and a
    # bare itertools.chain is exhausted after the first pass, which silently
    # skipped the reset loop below.
    total_metrics = list(
        itertools.chain(metrics.items(), training_diversity.items(),
                        test_diversity.items()))
    total_results = {name: metric.result() for name, metric in total_metrics}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for name, result in total_results.items():
      name = name.replace('/', '_')
      if 'negative_log_likelihood' in name:
        # Plots sort WIDs from high-to-low so look at maximization objectives.
        name = name.replace('negative_log_likelihood', 'log_likelihood')
        result = -result
      objective = work_unit.get_measurement_series(name)
      objective.create_measurement(result, epoch + 1)

    # Start the next epoch's metrics from a clean state.
    for _, metric in total_metrics:
      metric.reset_states()
    summary_writer.flush()
    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
def main(_):
  """Train the answering model on (observation, caption) pairs.

  Loads real and synthetic transitions together with their captions, encodes
  the captions with the vocabulary lookup table, and trains an encoder /
  attention decoder / sigmoid projection stack with binary cross-entropy.
  Each batch holds `batch_size` observations paired with one of their own
  captions (target 1) followed by the same observations paired with a
  randomly drawn caption (target 0). Training and evaluation batches are
  always drawn from the synthetic data.

  Args:
    _: Unused positional argument.
  """
  tf.enable_v2_behavior()
  ##############################################################################
  ######################### Data loading and processing ########################
  ##############################################################################
  print('Loading data')
  with gfile.GFile(transition_path, 'r') as f:
    transitions = np.load(f)
  # Rescale to [0, 1] when the stored values exceed 1.0.
  if np.max(transitions) > 1.0:
    transitions = transitions / 255.0
  with gfile.GFile(synthetic_transition_path, 'r') as f:
    synthetic_transitions = np.load(f)
  if np.max(synthetic_transitions) > 1.0:
    synthetic_transitions = synthetic_transitions / 255.0

  # NOTE(review): pickle.load executes arbitrary code from the file; only
  # point these paths at trusted data.
  with gfile.GFile(transition_label_path, 'r') as f:
    captions = pickle.load(f)
  with gfile.GFile(synthetic_transition_label_path, 'r') as f:
    synthetic_captions = pickle.load(f)

  with gfile.GFile(vocab_path, 'r') as f:
    vocab_list = f.readlines()
  # Drop the trailing newline of every vocabulary line.
  vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
  vocab_list = ['eos', 'sos'] + vocab_list

  v2i, i2v = wv.create_look_up_table(vocab_list)
  encode_fn = wv.encode_text_with_lookup_table(v2i)
  decode_fn = wv.decode_with_lookup_table(i2v)  # pylint: disable=unused-variable

  # Flatten the per-observation caption lists into one list of encoded
  # captions; the position in this list is the "global caption number".
  encoded_captions = []
  for all_cp in captions:
    for cp in all_cp:
      cp = 'sos ' + cp + ' eos'
      encoded_captions.append(np.array(encode_fn(cp)))

  synthetic_encoded_captions = []
  for all_cp in synthetic_captions:
    for cp in all_cp:
      cp = 'sos ' + cp + ' eos'
      synthetic_encoded_captions.append(np.array(encode_fn(cp)))

  all_caption_n = len(encoded_captions)
  all_synthetic_caption_n = len(synthetic_encoded_captions)

  encoded_captions = np.array(encoded_captions)
  encoded_captions = pad_to_max_length(encoded_captions, max_l=15)

  synthetic_encoded_captions = np.array(synthetic_encoded_captions)
  synthetic_encoded_captions = pad_to_max_length(synthetic_encoded_captions,
                                                 max_l=15)

  # obs_idx[k] is the observation that global caption k belongs to.
  # caption_idx / negative_caption_idx hold the global caption numbers of the
  # captions without / with the word 'nothing'.
  obs_idx, caption_idx, negative_caption_idx = [], [], []
  curr_caption_idx = 0
  for i, _ in enumerate(transitions):
    for cp in captions[i]:
      obs_idx.append(i)
      if 'nothing' not in cp:
        caption_idx.append(curr_caption_idx)
      else:
        negative_caption_idx.append(curr_caption_idx)
      curr_caption_idx += 1
  assert curr_caption_idx == all_caption_n

  synthetic_obs_idx, synthetic_caption_idx = [], []
  synthetic_negative_caption_idx = []
  curr_caption_idx = 0
  for i, _ in enumerate(synthetic_transitions):
    for cp in synthetic_captions[i]:
      synthetic_obs_idx.append(i)
      if 'nothing' not in cp:
        synthetic_caption_idx.append(curr_caption_idx)
      else:
        synthetic_negative_caption_idx.append(curr_caption_idx)
      curr_caption_idx += 1
  assert curr_caption_idx == all_synthetic_caption_n

  obs_idx = np.array(obs_idx)
  caption_idx = np.array(caption_idx)
  negative_caption_idx = np.array(negative_caption_idx)
  # all_idx enumerates positions inside caption_idx (not global caption
  # numbers); the first 80% of positions are train, the rest test.
  all_idx = np.arange(len(caption_idx))
  train_idx = all_idx[:int(len(all_idx) * 0.8)]
  test_idx = all_idx[int(len(all_idx) * 0.8):]
  print('Number of training examples: {}'.format(len(train_idx)))
  print('Number of test examples: {}\n'.format(len(test_idx)))

  synthetic_obs_idx = np.array(synthetic_obs_idx)
  synthetic_caption_idx = np.array(synthetic_caption_idx)
  synthetic_negative_caption_idx = np.array(synthetic_negative_caption_idx)
  synthetic_all_idx = np.arange(len(synthetic_caption_idx))
  synthetic_train_idx = synthetic_all_idx[:int(len(synthetic_all_idx) * 0.8)]
  synthetic_test_idx = synthetic_all_idx[int(len(synthetic_all_idx) * 0.8):]
  print('Number of synthetic training examples: {}'.format(
      len(synthetic_train_idx)))
  print('Number of synthetic test examples: {}\n'.format(
      len(synthetic_test_idx)))

  def sample_batch(data_type, batch_size, mode='train'):
    """Sample a batch of matched and mismatched observation/caption pairs.

    Args:
      data_type: 'synthetic' selects the synthetic arrays; any other value
        selects the real ones.
      batch_size: Number of matched pairs; the returned tensors hold
        2 * batch_size rows (matched first, then mismatched).
      mode: 'train' draws from the train split; any other value draws from
        the test split.

    Returns:
      Tuple (input_tensor, caption_tensor, target_tensor) where the first
      batch_size targets are 1 and the remaining batch_size targets are 0.
    """
    is_synthetic = data_type == 'synthetic'
    transitions_s = synthetic_transitions if is_synthetic else transitions
    encoded_captions_s = (
        synthetic_encoded_captions if is_synthetic else encoded_captions)
    obs_idx_s = synthetic_obs_idx if is_synthetic else obs_idx
    caption_idx_s = synthetic_caption_idx if is_synthetic else caption_idx
    train_idx_s = synthetic_train_idx if is_synthetic else train_idx
    test_idx_s = synthetic_test_idx if is_synthetic else test_idx

    split_idx_s = train_idx_s if mode == 'train' else test_idx_s
    batch_idx_s = np.random.choice(split_idx_s, size=batch_size)

    # batch_idx_s are positions inside caption_idx_s; translate them to global
    # caption numbers before looking up the owning observation. obs_idx_s is
    # indexed by global caption number, so indexing it with batch_idx_s
    # directly pairs captions with the wrong observation as soon as any
    # 'nothing' caption has been filtered out of caption_idx_s.
    positive_idx = caption_idx_s[batch_idx_s]
    batch_obs = transitions_s[obs_idx_s[positive_idx], 1, :]
    # The same observations are used for the matched and mismatched halves.
    input_tensor = tf.convert_to_tensor(np.concatenate([batch_obs, batch_obs]))

    # Mismatched captions are always drawn from the train split.
    negative_idx = caption_idx_s[np.random.choice(train_idx_s,
                                                  size=batch_size)]
    caption_tensor = tf.convert_to_tensor(
        np.concatenate([
            encoded_captions_s[positive_idx],
            encoded_captions_s[negative_idx]
        ], axis=0))
    target_tensor = tf.convert_to_tensor(
        np.float32(
            np.concatenate([np.ones(batch_size),
                            np.zeros(batch_size)], axis=0)))
    return input_tensor, caption_tensor, target_tensor

  ##############################################################################
  ############################# Training Setup #################################
  ##############################################################################
  vocab_size = len(vocab_list)
  batch_size = 64
  max_sequence_length = 15

  encoder_config = {'name': 'image', 'embedding_dim': 64}
  decoder_config = {
      'name': 'attention',
      'word_embedding_dim': 64,
      'hidden_units': 256,
      'vocab_size': vocab_size,
  }

  encoder = get_answering_encoder(encoder_config)
  decoder = get_answering_decoder(decoder_config)
  projection_layer = tf.keras.layers.Dense(1,
                                           activation='sigmoid',
                                           name='answering_projection')
  optimizer = tf.keras.optimizers.Adam(1e-4)
  bce = tf.keras.losses.BinaryCrossentropy()

  @tf.function
  def compute_loss(obs, instruction, target, training):
    """Return (loss, projection) for one batch.

    The decoder is stepped over the first `max_sequence_length` tokens of
    `instruction`; the final hidden state is projected to one sigmoid output
    per row and scored against `target` with binary cross-entropy.
    """
    print('Build compute loss...')  # Runs only while the function is traced.
    instruction = tf.expand_dims(instruction, axis=-1)
    hidden = decoder.reset_state(batch_size=target.shape[0])
    features = encoder(obs, training=training)
    for i in tf.range(max_sequence_length):
      _, hidden, _ = decoder(instruction[:, i],
                             features,
                             hidden,
                             training=training)
    projection = tf.squeeze(projection_layer(hidden), axis=1)
    loss = bce(target, projection)
    return loss, projection

  @tf.function
  def train_step(obs, instruction, target):
    """Run one optimizer update on a batch and return its loss."""
    print('Build train step...')  # Runs only while the function is traced.
    with tf.GradientTape() as tape:
      loss, _ = compute_loss(obs, instruction, target, True)
    trainable_variables = (encoder.trainable_variables +
                           decoder.trainable_variables +
                           projection_layer.trainable_variables)
    print('num trainable: ', len(trainable_variables))
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss

  ##############################################################################
  ############################# Training Loop ##################################
  ##############################################################################
  print('Start training...\n')
  start_epoch = 0
  if FLAGS.save_dir:
    checkpoint_path = FLAGS.save_dir
    ckpt = tf.train.Checkpoint(encoder=encoder,
                               decoder=decoder,
                               projection_layer=projection_layer,
                               optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
      start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
      # Resume the weights and optimizer state as well as the epoch counter;
      # without this the run continues at `start_epoch` with fresh variables.
      ckpt.restore(ckpt_manager.latest_checkpoint)

  epochs = 400
  # NOTE(review): the step count comes from the real caption count although
  # batches are sampled from the synthetic data -- confirm this is intended.
  step_per_epoch = int(all_caption_n / batch_size)

  previous_best, previous_best_accuracy = 100., 0.0
  for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    for batch in range(step_per_epoch):
      input_tensor, instruction, target = sample_batch(
          'synthetic', batch_size, 'train')
      batch_loss = train_step(input_tensor, instruction, target)
      total_loss += batch_loss

      if batch % 1000 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(
            epoch, batch, batch_loss.numpy()))

    # Quick validation on 10 test batches; checkpoint on a new best accuracy.
    if epoch % 5 == 0 and FLAGS.save_dir:
      test_total_loss = 0
      accuracy = 0
      for batch in range(10):
        input_tensor, instruction, target = sample_batch(
            'synthetic', batch_size, 'test')
        t_loss, prediction = compute_loss(input_tensor, instruction, target,
                                          False)
        test_total_loss += t_loss
        accuracy += np.mean(
            np.float32(np.float32(prediction > 0.5) == target))
      test_total_loss /= 10.
      accuracy /= 10.
      if accuracy > previous_best_accuracy:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)

    print('\nEpoch {} | Loss {:.6f} | Val loss {:.6f} | Accuracy {:.3f}'.
          format(epoch + 1, total_loss / step_per_epoch, previous_best,
                 previous_best_accuracy))
    print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() - start))

    # Larger evaluation pass every 10 epochs.
    # NOTE(review): the batch count uses the real test split size while the
    # batches are synthetic, and it is zero (division by zero below) when
    # len(test_idx) < batch_size -- confirm against the intended data sizes.
    if epoch % 10 == 0:
      test_total_loss = 0
      accuracy = 0
      for batch in range(len(test_idx) // batch_size):
        input_tensor, instruction, target = sample_batch(
            'synthetic', batch_size, 'test')
        t_loss, prediction = compute_loss(input_tensor,
                                          instruction,
                                          target,
                                          training=False)
        test_total_loss += t_loss
        accuracy += np.mean(
            np.float32(np.float32(prediction > 0.5) == target))
      test_total_loss /= (len(test_idx) // batch_size)
      accuracy /= (len(test_idx) // batch_size)
      if accuracy > previous_best_accuracy and FLAGS.save_dir:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)
      print('\n====================================================')
      print('Test Loss {:.6f} | Test Accuracy {:.3f}'.format(
          test_total_loss, accuracy))
      print('====================================================\n')
def main(unused_argv):
  """Train and evaluate a Keras ResNet-50 on ImageNet with a TPU strategy.

  Connects to the TPU named by --tpu, builds and compiles the model inside
  the distribution strategy scope, fits it with a per-batch learning-rate
  scheduler and a TensorBoard callback, then saves the model.

  Args:
    unused_argv: Unused command-line arguments.
  """
  assert FLAGS.data is not None, 'Provide training data path via --data.'
  tf.enable_v2_behavior()

  global_batch = FLAGS.num_cores * PER_CORE_BATCH_SIZE
  if FLAGS.steps_per_epoch:
    train_steps = FLAGS.steps_per_epoch
  else:
    train_steps = int(APPROX_IMAGENET_TRAINING_IMAGES // global_batch)
  # Round up so the final partial validation batch is still evaluated.
  eval_steps = int(math.ceil(1.0 * IMAGENET_VALIDATION_IMAGES / global_batch))
  output_dir = FLAGS.model_dir or DEFAULT_MODEL_DIR

  logging.info('Saving tensorboard summaries at %s', output_dir)
  tpu_name = 'local' if FLAGS.tpu is None else FLAGS.tpu
  logging.info('Use TPU at %s', tpu_name)

  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu)
  tf.config.experimental_connect_to_host(tpu_resolver.master())
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu_resolver)

  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', global_batch)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

  with tf.device('/job:worker'):
    with tpu_strategy.scope():
      logging.info('Building Keras ResNet-50 model')
      model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

      logging.info('Compiling model.')
      eval_metrics = ['sparse_categorical_accuracy']
      if FLAGS.eval_top_5_accuracy:
        eval_metrics += [sparse_top_k_categorical_accuracy]

      sgd = tf.keras.optimizers.SGD(learning_rate=BASE_LEARNING_RATE,
                                    momentum=0.9,
                                    nesterov=True)
      model.compile(optimizer=sgd,
                    loss='sparse_categorical_crossentropy',
                    metrics=eval_metrics)

    # Both input pipelines share everything except the training switch.
    input_kwargs = dict(data_dir=FLAGS.data,
                        batch_size=global_batch,
                        use_bfloat16=USE_BFLOAT16)
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  **input_kwargs)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 **input_kwargs)

    callbacks = [
        LearningRateBatchScheduler(
            schedule=learning_rate_schedule_wrapper(train_steps)),
        tf.keras.callbacks.TensorBoard(log_dir=output_dir),
    ]

    model.fit(imagenet_train.input_fn(),
              epochs=FLAGS.num_epochs,
              steps_per_epoch=train_steps,
              callbacks=callbacks,
              validation_data=imagenet_eval.input_fn(),
              validation_steps=eval_steps,
              validation_freq=5)

    model_saving_utils.save_model(model, output_dir, WEIGHTS_TXT)
def main(argv):
  """Parse flags and launch MoCo pre-training plus classifier training.

  Selects the ResNet variant named by --arch, builds the stepped
  learning-rate schedule factories for the MoCo and classifier phases
  (from --lr_moco_sched_steps / --lr_clf_sched_steps when given, otherwise
  from built-in defaults) and hands everything to `train`.

  Args:
    argv: Command-line arguments left over after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If --arch is not one of the supported architectures.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  emb_size = FLAGS.emb_size

  supported_archs = {
      'resnet50': model_resnet.ResNet50,
      'resnet101': model_resnet.ResNet101,
      'resnet152': model_resnet.ResNet152,
  }
  if FLAGS.arch not in supported_archs:
    # Same exception type as before, but now says what was wrong.
    raise ValueError('Unsupported --arch {!r}; expected one of {}.'.format(
        FLAGS.arch, sorted(supported_archs)))
  model_def = supported_archs[FLAGS.arch].partial(num_outputs=emb_size)
  # Every supported architecture uses the same feature size.
  feature_size = 64 * 8 * 4

  # Schedules are lists of [epoch, scale] pairs; the flags carry them as
  # Python literals.
  if FLAGS.lr_moco_sched_steps:
    lr_moco_sched_steps = ast.literal_eval(FLAGS.lr_moco_sched_steps)
  else:
    lr_moco_sched_steps = [[120, 0.1], [160, 0.01]]

  if FLAGS.lr_clf_sched_steps:
    lr_clf_sched_steps = ast.literal_eval(FLAGS.lr_clf_sched_steps)
  else:
    lr_clf_sched_steps = [[60, 0.2], [75, 0.04], [90, 0.008]]

  def make_moco_lr_fun(base_lr, steps_per_epoch):
    """Build the stepped learning-rate schedule for the MoCo phase."""
    return lr_schedule.create_stepped_learning_rate_schedule(
        base_lr,
        steps_per_epoch,
        lr_moco_sched_steps,
        warmup_length=FLAGS.lr_moco_sched_warmup)

  def make_clf_lr_fun(base_lr, steps_per_epoch):
    """Build the stepped learning-rate schedule for the classifier phase."""
    return lr_schedule.create_stepped_learning_rate_schedule(
        base_lr,
        steps_per_epoch,
        lr_clf_sched_steps,
        warmup_length=FLAGS.lr_clf_sched_warmup)

  train(model_def,
        model_dir=FLAGS.model_dir,
        batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        num_moco_epochs=FLAGS.num_moco_epochs,
        num_clf_epochs=FLAGS.num_clf_epochs,
        moco_learning_rate=FLAGS.moco_learning_rate,
        clf_learning_rate=FLAGS.clf_learning_rate,
        sgd_momentum=FLAGS.sgd_momentum,
        sgd_nesterov=FLAGS.sgd_nesterov,
        make_moco_lr_fun=make_moco_lr_fun,
        make_clf_lr_fun=make_clf_lr_fun,
        moco_l2_reg=FLAGS.moco_l2_reg,
        clf_l2_reg=FLAGS.clf_l2_reg,
        feature_size=feature_size,
        moco_momentum=FLAGS.moco_momentum,
        emb_size=emb_size,
        moco_temperature=FLAGS.moco_temperature,
        dictionary_size=FLAGS.dictionary_size,
        run_seed=FLAGS.rng)
def main(argv):
  """Train and evaluate a ResNet with pmap'd steps fed through device infeed.

  Derives batch size, epoch count, momentum, weight decay and input layout
  from flags (falling back to batch-size dependent defaults), builds the
  model and a two-part optimizer (weights vs. bias/scale parameters), then
  alternates training loops and evaluation passes. Input batches are pushed
  to every local device through a background infeed thread pool; summaries
  are written from host 0 on a separate thread.

  Args:
    argv: Unused command-line arguments.
  """
  del argv
  # BEGIN GOOGLE-INTERNAL
  xm.setup_work_unit()
  # END GOOGLE-INTERNAL
  tf.enable_v2_behavior()
  init_mllogger()
  mllogger.event('cache_clear')
  mllogger.start('init_start')
  mllogger.event('submission_org', 'Google')
  mllogger.event('submission_platform', 'TPUv3-{}'.format(jax.device_count()))
  mllogger.event('submission_division', 'closed')
  mllogger.event('submission_status', 'research')
  mllogger.event('submission_benchmark', 'resnet')
  mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
  mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)
  # Only host 0 owns a summary writer and its worker thread.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  # Infeed is currently synchronous, so do it in a background thread too
  infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')
  if FLAGS.seed is not None:
    seed = FLAGS.seed
  else:
    # Host 0 contributes a time-based seed and every other host contributes 0.
    # NOTE(review): presumably per_host_sum_pmap sums across hosts so that all
    # hosts end up with host 0's seed -- confirm against its definition.
    seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
    seed = per_host_sum_pmap(seed)
  mllogger.event('seed', int(seed))
  key = random.PRNGKey(seed)
  # --batch_size=-1 selects a default that depends on the device count.
  batch_size = FLAGS.batch_size
  if batch_size == -1:
    if jax.device_count() > 4096:
      batch_size = 65536
    else:
      batch_size = min(128 * jax.device_count(), 32768)
  mllogger.event('global_batch_size', batch_size)
  eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
  device_batch_size = batch_size // jax.device_count()
  device_eval_batch_size = int(
      math.ceil(eval_batch_size / jax.device_count()))
  model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
  input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32
  # Unset --num_epochs falls back to a batch-size dependent default.
  num_epochs = FLAGS.num_epochs
  if num_epochs is None:
    if batch_size < 32768:
      num_epochs = 56
    elif batch_size < 65536:
      num_epochs = 64
    else:
      num_epochs = 92
  # steps_per_epoch is a float here (true division).
  steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
  # match TF submission behavior (round steps per loop up)
  steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
  # also apply rounding loop up to next step to "epochs" in LR schedule
  steps_per_epoch *= steps_per_loop / (steps_per_epoch * FLAGS.epochs_per_loop)
  steps_per_eval = int(
      math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))
  # Base learning rate is scaled by batch_size / 256.
  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  # Unset --momentum falls back to a batch-size dependent default.
  beta = FLAGS.momentum
  if beta is None:
    if batch_size < 32768:
      beta = 0.9
    elif batch_size < 65536:
      beta = 0.929
    else:
      beta = 0.9537213777059405
  weight_decay = FLAGS.weight_decay
  if weight_decay is None:
    weight_decay = 2e-4 if batch_size < 32768 else 1e-4
  # Space-to-depth defaults to on for per-device batches of at most 8.
  space_to_depth = FLAGS.space_to_depth
  if space_to_depth is None:
    space_to_depth = device_batch_size <= 8
  image_format = FLAGS.image_format
  if image_format is None:
    if space_to_depth and device_batch_size <= 8:
      image_format = 'HWNC'
    else:
      image_format = 'HWCN'
  image_size = input_pipeline.IMAGE_SIZE
  # Shapes are first written batch-first (N, H, W, C); with space-to-depth the
  # spatial dims are halved and the channel dim is 12 instead of 3.
  if space_to_depth:
    train_input_shape = (device_batch_size, image_size // 2, image_size // 2,
                         12)
    eval_input_shape = (device_eval_batch_size, image_size // 2,
                        image_size // 2, 12)
  else:
    train_input_shape = (device_batch_size, image_size, image_size, 3)
    eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
  # Permute the batch-first shapes into the selected layout.
  if image_format == 'HWCN':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
  elif image_format == 'HWNC':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])
  model, state = create_model(key, device_batch_size, image_size, model_dtype,
                              space_to_depth)
  # Two optimizer definitions are built: one for weights (with weight decay)
  # and one for the remaining parameters (weight_decay=0).
  if FLAGS.lars:
    mllogger.event('opt_name', 'lars')
    mllogger.event('lars_opt_weight_decay', weight_decay)
    mllogger.event('lars_opt_momentum', beta)
    mllogger.event('lars_epsilon', 0)
    weight_opt_def = optim.LARS(base_learning_rate,
                                beta,
                                weight_decay=weight_decay)
    other_opt_def = optim.Momentum(base_learning_rate,
                                   beta,
                                   weight_decay=0,
                                   nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(batch_size, steps_per_epoch,
                                                   num_epochs)
  else:
    mllogger.event('opt_name', 'sgd')
    mllogger.event('sgd_opt_momentum', beta)
    weight_opt_def = optim.Momentum(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay,
                                    nesterov=True)
    other_opt_def = optim.Momentum(base_learning_rate,
                                   beta,
                                   weight_decay=0,
                                   nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                  steps_per_epoch, num_epochs)

  def filter_weights(key, _):
    # Selects parameters whose path mentions neither 'bias' nor 'scale'.
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    # Complement of filter_weights: paths mentioning 'bias' or 'scale'.
    return 'bias' in key or 'scale' in key

  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  del model  # do not keep a copy of the initial model
  optimizer = broadcast(optimizer)
  state = broadcast(state)
  empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})
  p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')
  p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

  def host_loop_train_step(optimizer, state, metrics):
    # One training step driven from the host: pull a (images, labels) batch
    # from the device infeed, then run train_step on it.
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(token,
                              shape=(jax.ShapedArray(train_input_shape,
                                                     model_dtype),
                                     jax.ShapedArray((device_batch_size, ),
                                                     jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics

  # in_axes: the optimizer is passed unmapped; state and metrics are mapped
  # over their leading axis.
  p_host_loop_train_step = jax.pmap(host_loop_train_step,
                                    axis_name='batch',
                                    in_axes=(None, 0, 0))

  def host_loop_eval_step(model, state, metrics):
    # One evaluation step: pull a batch from the infeed and run eval_step.
    token = lax.create_token(metrics['samples'])
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(eval_input_shape, model_dtype),
               jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
    metrics = eval_step(model, state, batch, metrics, image_format,
                        space_to_depth)
    return metrics

  p_host_loop_eval_step = jax.pmap(host_loop_eval_step,
                                   axis_name='batch',
                                   in_axes=(None, None, 0))

  def device_train_loop_cond(args):
    # Keep looping while `step` still falls inside loop number `loop`.
    _, _, _, _, step, loop = args
    return step // steps_per_loop == loop

  def device_train_loop_body(args):
    # Same work as host_loop_train_step, but the infeed token is threaded
    # through the while_loop carry and the step counter is advanced.
    optimizer, state, metrics, token, step, loop = args
    batch, token = lax.infeed(token,
                              shape=(jax.ShapedArray(train_input_shape,
                                                     model_dtype),
                                     jax.ShapedArray((device_batch_size, ),
                                                     jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    step += 1
    return optimizer, state, metrics, token, step, loop

  def device_train_loop(optimizer, state, metrics, step, loop):
    # Runs all steps of one loop in a lax.while_loop, then syncs batch-norm
    # statistics and all-reduces the metrics once at the end.
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, state, metrics, token, step, loop))
    state = sync_batchnorm_stats(state)
    metrics = allreduce_metrics(metrics)
    return optimizer, state, metrics, step

  p_train_loop = jax.pmap(device_train_loop,
                          axis_name='batch',
                          in_axes=(None, None, 0, None, None))

  # BEGIN GOOGLE-INTERNAL
  def maybe_start_xprof(seconds):
    # On host 0 with --xprof set: start a profiling session and end it from a
    # background thread after `seconds` seconds, logging the resulting URL.
    if jax.host_id() == 0 and FLAGS.xprof:
      xprof = xprof_session.XprofSession()
      xprof.start_session('REDACTED', True, 2)

      def sleep_and_end_xprof():
        time.sleep(seconds)
        logging.info(
            'Xprof URL: %s',
            xprof.end_session_and_get_url(
                tag='flax resnet, {} devices, batch {} per device'.format(
                    jax.device_count(), device_batch_size)))

      thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

  # END GOOGLE-INTERNAL
  if FLAGS.precompile:
    # Call each pmap'd function once, feeding all-zero batches where a call
    # consumes an infeed batch.
    logging.info('precompiling step/loop functions')
    if FLAGS.device_loop:
      # the device training loop condition will immediately be false
      p_train_loop(unbroadcast(optimizer), unbroadcast(state), empty_metrics,
                   jnp.array(0, dtype=jnp.int32), 1)
    else:
      for device in jax.local_devices():
        images = np.zeros(train_input_shape, model_dtype)
        labels = np.zeros((device_batch_size, ), np.int32)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      p_host_loop_train_step(unbroadcast(optimizer), state, empty_metrics)
      p_sync_batchnorm_stats(state)
    for device in jax.local_devices():
      images = np.zeros(eval_input_shape, model_dtype)
      labels = np.zeros((device_eval_batch_size, ), np.int32)
      infeed_pool.submit(
          partial(device.transfer_to_infeed, (images, labels)))
    p_host_loop_eval_step(unbroadcast(optimizer.target), unbroadcast(state),
                          empty_metrics)
    p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
    logging.info('finished precompiling')
  # BEGIN GOOGLE-INTERNAL
  maybe_start_xprof(20)
  # END GOOGLE-INTERNAL
  if not FLAGS.fake_data:
    logging.info('constructing datasets')
    # pylint: disable=g-complex-comprehension
    train_ds, eval_ds = [
        input_pipeline.load_split(
            device_batch_size if train else device_eval_batch_size,
            dtype=input_dtype,
            train=train,
            image_format=image_format,
            space_to_depth=space_to_depth,
            cache_uncompressed=jax.device_count() > 64)
        for train in (True, False)
    ]
    logging.info('constructing dataset iterators')
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)
  local_devices = jax.local_devices()
  # host_step counts batches enqueued by the host; device_step is the step
  # counter handed to the on-device training loop.
  host_step, device_step = 0, broadcast(0)
  mllogger.end('init_stop')
  mllogger.start('run_start')
  mllogger.start('block_start',
                 metadata={
                     'first_epoch_num': 1,
                     'epoch_count': FLAGS.epochs_per_loop
                 })
  # The loop count is the number of epoch groups rounded up, plus 2.
  for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
    # BEGIN GOOGLE-INTERNAL
    if loop == 10:
      maybe_start_xprof(1)
    # END GOOGLE-INTERNAL
    metrics = empty_metrics
    if FLAGS.device_loop:
      # Dispatch the whole on-device loop first; the host then feeds it below.
      optimizer, state, metrics, device_step = p_train_loop(
          unbroadcast(optimizer), unbroadcast(state), metrics,
          unbroadcast(device_step), loop)
    while int(host_step // steps_per_loop) == loop:
      if not FLAGS.device_loop:
        optimizer, state, metrics = p_host_loop_train_step(
            unbroadcast(optimizer), state, metrics)
      # Throttle: wait while more than 100 infeed transfers are queued.
      # pylint: disable=protected-access
      while infeed_pool._work_queue.qsize() > 100:
        time.sleep(0.01)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(train_input_shape, model_dtype)
          labels = np.zeros((device_batch_size, ), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(),
                                        next(train_iter))
        assert images.shape == train_input_shape and labels.dtype == jnp.int32
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      host_step += 1
    epoch = (loop + 1) * FLAGS.epochs_per_loop
    if FLAGS.train_metrics:
      # The device loop already all-reduced its metrics.
      if not FLAGS.device_loop:
        metrics = p_allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(
            partial(write_summary, summary_writer, metrics, 'train', epoch))
    # The device loop already synced batch-norm statistics.
    if not FLAGS.device_loop:
      state = p_sync_batchnorm_stats(state)
    metrics = empty_metrics
    for _ in range(steps_per_eval):
      # The eval step is dispatched before its batch is enqueued.
      metrics = p_host_loop_eval_step(unbroadcast(optimizer.target),
                                      unbroadcast(state), metrics)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(eval_input_shape, model_dtype)
          labels = np.zeros((device_eval_batch_size, ), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
        assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
            'images.shape={}'.format(images.shape)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
    metrics = p_allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(
          partial(write_summary, summary_writer, metrics, 'eval', epoch))
  # Wait until computations are done before exiting
  p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
  if jax.host_id() == 0:
    summary_thread.shutdown()
    # DONE is defined outside this function; if it is still falsy here the
    # run is logged as aborted.
    if not DONE:
      mllogger.end('run_stop', metadata={'status': 'aborted'})
def main(argv):
  """Train and evaluate a ResNet on ImageNet using hparams from a config dict.

  Builds the input iterators, loads hyperparameters from
  FLAGS.hparams_config_dict, creates the model and optimizer, restores the
  latest checkpoint, then runs the pmap'd training loop. At every epoch
  boundary it logs train metrics, runs evaluation, and writes summaries on
  host 0; checkpoints are saved every 10 epochs and at the final step.

  Args:
    argv: Command-line arguments left over after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If the batch size is not divisible by the device count, no
      hparams config dict is provided, or the optimizer type is unsupported.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')
  # Only host 0 creates (and later uses) the summary writer.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)
  rng = random.PRNGKey(0)
  image_size = 224
  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError(
        'Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()
  platform = jax.local_devices()[0].platform
  # Half precision uses bfloat16 on TPU; elsewhere float16 together with a
  # DynamicScale object.
  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32
  train_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                      FLAGS.data_dir,
                                                      image_size,
                                                      input_dtype,
                                                      train=True,
                                                      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                     FLAGS.data_dir,
                                                     image_size,
                                                     input_dtype,
                                                     train=False,
                                                     cache=FLAGS.cache)
  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[
          FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')
  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)
  # get num_epochs from hparam instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  # A checkpoint is written every 10 epochs.
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs
  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(image_size=image_size,
                                     model_dir=FLAGS.model_dir,
                                     hparams=hparams)
    logging.info(
        'Writing training HLO and estimating compute/memory costs.')
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True)
  # Split the trainable 'params' collection from the remaining model state.
  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(beta=hparams.momentum,
                               nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(beta1=hparams.adam.beta1,
                           beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')
  state = imagenet_train_utils.TrainState(step=0,
                                          optimizer=optimizer,
                                          model_state=model_state,
                                          dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model
  state = restore_checkpoint(state)
  step_offset = int(
      state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)
  # Base learning rate is scaled by batch_size / 256.
  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler)
  # Positional arguments 2 and 3 of the pmap'd train step (hparams and
  # update_bounds in the call below) are static and broadcast.
  p_train_step = jax.pmap(functools.partial(
      imagenet_train_utils.train_step,
      model,
      learning_rate_fn=learning_rate_fn),
                          axis_name='batch',
                          static_broadcasted_argnums=(2, 3))
  p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step,
                                           model),
                         axis_name='batch')
  # Per-step accumulators; both are reset at every epoch boundary.
  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # A negative early_stop_steps disables early stopping.
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    state, metrics = p_train_step(state, batch, hparams, update_bounds)
    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)
    epoch_metrics.append(metrics)
    # End of an epoch: report train metrics, evaluate, write summaries.
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(
          state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          # One scalar per step of the epoch, written at that step's index.
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
        summary_utils.write_state_dict_summaries_to_tb(
            state_dict_summary_all, summary_writer,
            FLAGS.state_dict_summary_freq, step)
      # Reset the accumulators on every host, not only on host 0.
      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []
      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)
  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv):
  """Builds the task data, model and optimizer, then trains and tests.

  Reads hyper-parameters from `FLAGS.config`, loads the task named by
  `FLAGS.task_name`, creates the model and optimizer (optionally restoring
  them from `FLAGS.model_dir`), runs `train_loop` and finally `test`.

  Args:
    argv: Positional command-line arguments; anything beyond the program name
      is rejected.

  Raises:
    app.UsageError: If more than one positional argument is present.
    ValueError: If the batch size is not divisible by `jax.device_count()`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  cfg = FLAGS.config
  logging.info('===========Config Dict============')
  logging.info(cfg)
  # Read every required config field up front so a missing field fails early.
  batch_size = cfg.batch_size
  base_learning_rate = cfg.learning_rate
  num_train_steps = cfg.num_train_steps
  num_eval_steps = cfg.num_eval_steps
  eval_freq = cfg.eval_frequency
  seed = cfg.random_seed
  model_type = cfg.model_type

  # Only the first process gets a TensorBoard writer.
  summary_writer = None
  if jax.process_index() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'summary'))

  leftover = batch_size % jax.device_count()
  if leftover > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  logging.info('Training on %s', FLAGS.task_name)

  # Convolutional families get normalized, un-flattened inputs; everything
  # else is treated as a transformer-based model.
  is_conv_model = model_type in ('wideresnet', 'resnet', 'simple_cnn')

  load_task = task_registry.TASK_DATA_DICT[FLAGS.task_name]
  task_data = load_task(
      n_devices=jax.local_device_count(),
      batch_size=batch_size,
      normalize=is_conv_model)
  train_ds, eval_ds, test_ds, num_classes, vocab_size, input_shape = task_data
  train_iter = iter(train_ds)

  if is_conv_model:
    model_kwargs = {'num_classes': num_classes}
  else:
    # Transformer models: collapse (h, w, c) into one sequence axis.
    bs, h, w, c = input_shape
    assert c == 1
    input_shape = (bs, h * w * c)
    model_kwargs = {
        'vocab_size': vocab_size,
        'max_len': input_shape[1],
        'classifier': True,
        'num_classes': num_classes,
    }
  flatten_input = not is_conv_model

  # Config-provided model options override the defaults chosen above.
  model_kwargs.update(cfg.model)

  rng = random.PRNGKey(seed)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, init_rng = random.split(rng)
  # The first set of dropout PRNG keys is created here; later ones are
  # produced inside the pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  model, state = train_utils.get_model(model_type, create_model, model_kwargs,
                                       init_rng, input_shape)

  optimizer = create_optimizer(model, base_learning_rate, cfg.weight_decay)
  del model  # Don't keep a copy of the initial model.

  start_step = 0
  if cfg.restore_checkpoints:
    # Restore the unreplicated optimizer + model state from the last
    # checkpoint and resume from its step counter.
    restored = checkpoints.restore_checkpoint(FLAGS.model_dir,
                                              (optimizer, state))
    optimizer, state = restored
    start_step = int(optimizer.state.step)

  # Replicate optimizer and state across local devices.
  optimizer = jax_utils.replicate(optimizer)
  state = jax_utils.replicate(state)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors=cfg.factors,
      base_learning_rate=base_learning_rate,
      warmup_steps=cfg.warmup,
      steps_per_cycle=cfg.get('steps_per_cycle', None),
  )

  bound_train_step = functools.partial(
      train_step,
      learning_rate_fn=learning_rate_fn,
      num_classes=num_classes,
      grad_clip_norm=cfg.get('grad_clip_norm', None),
      flatten_input=flatten_input)
  p_train_step = jax.pmap(bound_train_step, axis_name='batch')

  bound_eval_step = functools.partial(
      eval_step, num_classes=num_classes, flatten_input=flatten_input)
  p_eval_step = jax.pmap(bound_eval_step, axis_name='batch')

  optimizer, state, step = train_loop(cfg, dropout_rngs, eval_ds, eval_freq,
                                      num_eval_steps, num_train_steps,
                                      optimizer, state, p_eval_step,
                                      p_train_step, start_step, train_iter,
                                      summary_writer)

  logging.info('Starting testing')
  logging.info('====================')
  test(optimizer, state, p_eval_step, step, test_ds, summary_writer,
       FLAGS.model_dir)
imagenet_eval, output_dir, metrics) save_model(model, output_dir, method, use_tpu, task_number)


def main(unused_argv):
  """Logs the run settings, assembles the metric list and calls `run`.

  The metric list always contains 'sparse_categorical_crossentropy';
  `sparse_top_k_categorical_accuracy` is appended when
  `FLAGS.eval_top_5_accuracy` is set.

  Args:
    unused_argv: Command-line arguments; not used.
  """
  logging.info('Base LR: %s.', learning_rate_lib.BASE_LEARNING_RATE)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  metrics = ['sparse_categorical_crossentropy']
  if FLAGS.eval_top_5_accuracy:
    metrics.append(sparse_top_k_categorical_accuracy)
  # The literal '%task%' placeholder in the output directory is replaced by the
  # task number. `test_level` > 0 turns on fake training and > 1 additionally
  # turns on fake data.
  run(FLAGS.method,
      FLAGS.output_dir.replace('%task%', str(FLAGS.task)),
      task_number=FLAGS.task,
      use_tpu=FLAGS.use_tpu,
      tpu=FLAGS.tpu,
      metrics=metrics,
      fake_data=FLAGS.test_level > 1,
      fake_training=FLAGS.test_level > 0)


if __name__ == '__main__':
  tf.enable_v2_behavior()
  # Required due to b/128610213.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Flags are declared here, right before absl parses them in `app.run`.
  _declare_flags()
  app.run(main)
def main(argv):
  """Trains and evaluates a deterministic Keras ResNet-50 on ImageNet.

  Sets up the distribution strategy (GPU mirrored or TPU), the input
  pipelines, the model, the optimizer and the metrics; resumes from the latest
  checkpoint found in `FLAGS.output_dir`; then alternates one epoch of
  training with evaluation on the clean validation data and, every
  `FLAGS.corruptions_interval` epochs, on every corrupted test set. Metrics
  are written as TF summaries once per epoch. Checkpoints are saved every
  `FLAGS.checkpoint_interval` epochs and once more after the final epoch.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # unused arg
  tf.enable_v2_behavior()
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  imagenet_train = utils.ImageNetInput(is_training=True,
                                       data_dir=FLAGS.data_dir,
                                       batch_size=FLAGS.per_core_batch_size,
                                       use_bfloat16=FLAGS.use_bfloat16)
  imagenet_eval = utils.ImageNetInput(is_training=False,
                                      data_dir=FLAGS.data_dir,
                                      batch_size=FLAGS.per_core_batch_size,
                                      use_bfloat16=FLAGS.use_bfloat16)
  test_datasets = {
      'clean':
          strategy.experimental_distribute_datasets_from_function(
              imagenet_eval.input_fn)
  }
  if FLAGS.corruptions_interval > 0:
    corruption_types, max_intensity = utils.load_corrupted_test_info()
    for name in corruption_types:
      for intensity in range(1, max_intensity + 1):
        dataset_name = '{0}_{1}'.format(name, intensity)
        corrupt_input_fn = utils.corrupt_test_input_fn(
            batch_size=FLAGS.per_core_batch_size,
            corruption_name=name,
            corruption_intensity=intensity,
            use_bfloat16=FLAGS.use_bfloat16)
        test_datasets[dataset_name] = (
            strategy.experimental_distribute_datasets_from_function(
                corrupt_input_fn))

  train_dataset = strategy.experimental_distribute_datasets_from_function(
      imagenet_train.input_fn)

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  with strategy.scope():
    logging.info('Building Keras ResNet-50 model')
    model = deterministic_model.resnet50(input_shape=(224, 224, 3),
                                         num_classes=NUM_CLASSES)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Scale learning rate and decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 256
    learning_rate = utils.LearningRateSchedule(steps_per_epoch,
                                               base_lr,
                                               FLAGS.train_epochs,
                                               _LR_SCHEDULE)
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=0.9,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins)
    }
    if FLAGS.corruptions_interval > 0:
      # One NLL/accuracy/ECE triple per (corruption, intensity) test set.
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
    logging.info('Finished building Keras ResNet-50 model')

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  @tf.function
  def train_step(iterator):
    """Runs one training step on every replica."""

    def step_fn(inputs):
      """Per-replica forward/backward pass and train-metric update."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 to every variable whose name contains 'kernel' or 'bias'.
          # NOTE(review): batch-norm parameters are presumably excluded only
          # because their names match neither substring -- pay caution to the
          # naming scheme.
          if 'kernel' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        loss = negative_log_likelihood + l2_loss
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Runs one evaluation step on every replica for `dataset_name`."""

    def step_fn(inputs):
      """Per-replica forward pass and test-metric update."""
      images, labels = inputs
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits, from_logits=True))
      probs = tf.nn.softmax(logits)

      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  max_steps = steps_per_epoch * FLAGS.train_epochs
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      # The progress line is only logged every 20 steps, so only build it then.
      if step % 20 == 0:
        current_step = epoch * steps_per_epoch + (step + 1)
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps,
                       epoch + 1,
                       FLAGS.train_epochs,
                       steps_per_sec,
                       eta_seconds / 60,
                       time_elapsed / 60))
        logging.info(message)

    # Corrupted test sets are only evaluated every `corruptions_interval`
    # epochs; the clean set is evaluated every epoch.
    evaluate_corruptions = (
        FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if evaluate_corruptions:
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if evaluate_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(
          corrupt_metrics, corruption_types, max_intensity,
          FLAGS.alexnet_errors_path)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()
    if evaluate_corruptions:
      # The corrupted-data metrics live in their own dict. Without this reset
      # each evaluation round would also include all earlier rounds.
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
from absl.testing import parameterized
from jax import random as jax_random
import numpy as np
import tensorflow.compat.v2 as real_tf
import tensorflow_probability as tfp

from discussion import fun_mcmc
from discussion.fun_mcmc import backend
from tensorflow_probability.python.internal import test_util as tfp_test_util

# `tf` and `util` come from the fun_mcmc backend module rather than from
# TensorFlow directly; `real_tf` is the actual TensorFlow package.
tf = backend.tf
tfb = tfp.bijectors
tfd = tfp.distributions
util = backend.util

real_tf.enable_v2_behavior()


def _test_seed():
  """Returns the TFP test seed reduced modulo `2**32 - 1`."""
  # NOTE(review): presumably keeps the seed within a 32-bit range -- confirm.
  return tfp_test_util.test_seed() % (2**32 - 1)


def _no_compile(fn):
  """Returns `fn` unchanged."""
  return fn


def _fwd_mclachlan_optimal_4th_order_step(*args, **kwargs):
  """Calls `fun_mcmc.mclachlan_optimal_4th_order_step` with `forward=True`.

  All positional and keyword arguments are passed through unchanged.
  """
  return fun_mcmc.mclachlan_optimal_4th_order_step(
      *args, forward=True, **kwargs)
def test_main():
  """Enables TensorFlow v2 behavior, then hands control to `tf.test.main`."""
  tf.enable_v2_behavior()
  tf.test.main()
def main(argv):
  """Validates the command line and runs on `FLAGS.prediction_path`.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  num_args = len(argv)
  if num_args > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  prediction_path = FLAGS.prediction_path
  run(prediction_path)
def main(_):
  """Trains a `ProgramTransformer`, with periodic evaluation and beam search.

  Seeds the RNGs, builds the DSL token tables and the TFRecord input pipeline,
  creates the model and Adam optimizer (optionally restoring a checkpoint),
  then runs the training loop. Every `FLAGS.log_freq` steps it logs training
  metrics, evaluates on the held-out batches and runs beam search with beam
  sizes 10 and 100; host 0 writes the results to TensorBoard. Checkpoints are
  written by host 0 every `FLAGS.checkpoint_freq` steps and on the last step.

  Args:
    _: Command-line arguments; unused.

  Raises:
    ValueError: If `FLAGS.dataset_filepattern` is empty.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters. Explicit flags win over entries in `xm_parameters`.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  # The sorted "k=v" string names the TensorBoard and checkpoint directories.
  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 owns a summary writer; every later use is guarded the same way.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Character ids start at 1; decoding skips ids <= 0.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: char_id for char_id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decodes I/O example tokens.

    Returns:
      A tuple `(input_strings, output_strings, io_string)` where `io_string`
      joins every pair as 'input < output' separated by ' > '.
    """

    def decode_str(s):
      """Decodes string tokens, skipping non-positive ids."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decodes program tokens.

    Returns:
      `(program_object, program_string)`, or `(None, '')` when decoding fails.
    """
    # Keep everything up to and including the first EOS token.
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    # Best effort: any decoding error means the program does not compile.
    # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
    except Exception:  # pylint: disable=broad-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling: skip step 0 and every step that is not a
    # multiple of `log_freq`.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target,
                                inputs,
                                outputs,
                                cache,
                                eos_token,
                                programs.shape[-1],
                                beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          # Count the example as solved when the score reaches the number of
          # I/O pairs.
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries (8 random samples, drawn
      # with replacement).
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def main(argv):
  """Runs the TensorFlow test suite, enabling v2 behavior when available.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused
  # Some TensorFlow builds do not expose `enable_v2_behavior`; call it only
  # when the attribute is present.
  enable_v2 = getattr(tf, 'enable_v2_behavior', None)
  if enable_v2 is not None:
    enable_v2()
  tf.test.main()
def main(argv):
  """Trains and evaluates a deterministic Wide ResNet 28-10 on CIFAR.

  Sets up the distribution strategy (GPU mirrored or TPU), the input pipelines
  for `FLAGS.dataset`, the model, the optimizer and the metrics; resumes from
  the latest checkpoint found in `FLAGS.output_dir`; then alternates one epoch
  of training with evaluation on the clean test split and, every
  `FLAGS.corruptions_interval` epochs, on every corrupted test set. Metrics
  are written as TF summaries once per epoch and checkpoints are saved every
  `FLAGS.checkpoint_interval` epochs.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # unused arg
  tf.enable_v2_behavior()
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  train_input_fn = utils.load_input_fn(split=tfds.Split.TRAIN,
                                       name=FLAGS.dataset,
                                       batch_size=FLAGS.per_core_batch_size,
                                       use_bfloat16=FLAGS.use_bfloat16)
  clean_test_input_fn = utils.load_input_fn(
      split=tfds.Split.TEST,
      name=FLAGS.dataset,
      batch_size=FLAGS.per_core_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
  train_dataset = strategy.experimental_distribute_datasets_from_function(
      train_input_fn)
  test_datasets = {
      'clean':
          strategy.experimental_distribute_datasets_from_function(
              clean_test_input_fn),
  }
  if FLAGS.corruptions_interval > 0:
    if FLAGS.dataset == 'cifar10':
      load_c_input_fn = utils.load_cifar10_c_input_fn
    else:
      load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                          path=FLAGS.cifar100_c_path)
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for corruption in corruption_types:
      for intensity in range(1, max_intensity + 1):
        input_fn = load_c_input_fn(
            corruption_name=corruption,
            corruption_intensity=intensity,
            batch_size=FLAGS.per_core_batch_size,
            use_bfloat16=FLAGS.use_bfloat16)
        test_datasets['{0}_{1}'.format(corruption, intensity)] = (
            strategy.experimental_distribute_datasets_from_function(input_fn))

  ds_info = tfds.builder(FLAGS.dataset).info
  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = ds_info.splits['train'].num_examples // batch_size
  steps_per_eval = ds_info.splits['test'].num_examples // batch_size
  num_classes = ds_info.features['label'].num_classes

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building ResNet model')
    model = wide_resnet(input_shape=ds_info.features['image'].shape,
                        depth=28,
                        width_multiplier=10,
                        num_classes=num_classes,
                        l2=FLAGS.l2,
                        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())
    # Linearly scale learning rate and the decay epochs by vanilla settings.
    base_lr = FLAGS.base_learning_rate * batch_size / 128
    lr_decay_epochs = [(start_epoch * FLAGS.train_epochs) // 200
                       for start_epoch in FLAGS.lr_decay_epochs]
    lr_schedule = utils.LearningRateSchedule(
        steps_per_epoch,
        base_lr,
        decay_ratio=FLAGS.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=FLAGS.lr_warmup_epochs)
    optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                        momentum=0.9,
                                        nesterov=True)
    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'train/loss': tf.keras.metrics.Mean(),
        'train/ece': ed.metrics.ExpectedCalibrationError(
            num_classes=num_classes, num_bins=FLAGS.num_bins),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_classes=num_classes, num_bins=FLAGS.num_bins),
    }
    if FLAGS.corruptions_interval > 0:
      # One NLL/accuracy/ECE triple per (corruption, intensity) test set.
      corrupt_metrics = {}
      for intensity in range(1, max_intensity + 1):
        for corruption in corruption_types:
          dataset_name = '{0}_{1}'.format(corruption, intensity)
          corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
              tf.keras.metrics.Mean())
          corrupt_metrics['test/accuracy_{}'.format(dataset_name)] = (
              tf.keras.metrics.SparseCategoricalAccuracy())
          corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
              ed.metrics.ExpectedCalibrationError(
                  num_classes=num_classes, num_bins=FLAGS.num_bins))

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  @tf.function
  def train_step(iterator):
    """Runs one training step on every replica."""

    def step_fn(inputs):
      """Per-replica forward/backward pass and train-metric update."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)
        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        # The regularization term is the sum of the model's own losses.
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + l2_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      probs = tf.nn.softmax(logits)
      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, logits)

    strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def test_step(iterator, dataset_name):
    """Runs one evaluation step on every replica for `dataset_name`."""

    def step_fn(inputs):
      """Per-replica forward pass and test-metric update."""
      images, labels = inputs
      logits = model(images, training=False)
      if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
      probs = tf.nn.softmax(logits)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

      if dataset_name == 'clean':
        metrics['test/negative_log_likelihood'].update_state(
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(dataset_name)].update_state(
            labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)

    strategy.run(step_fn, args=(next(iterator),))

  train_iterator = iter(train_dataset)
  max_steps = steps_per_epoch * FLAGS.train_epochs
  start_time = time.time()
  for epoch in range(initial_epoch, FLAGS.train_epochs):
    logging.info('Starting to run epoch: %s', epoch)
    for step in range(steps_per_epoch):
      train_step(train_iterator)

      # The progress line is only logged every 20 steps, so only build it then.
      if step % 20 == 0:
        current_step = epoch * steps_per_epoch + (step + 1)
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps,
                       epoch + 1,
                       FLAGS.train_epochs,
                       steps_per_sec,
                       eta_seconds / 60,
                       time_elapsed / 60))
        logging.info(message)

    # Corrupted test sets are only evaluated every `corruptions_interval`
    # epochs; the clean set is evaluated every epoch.
    evaluate_corruptions = (
        FLAGS.corruptions_interval > 0 and
        (epoch + 1) % FLAGS.corruptions_interval == 0)
    datasets_to_evaluate = {'clean': test_datasets['clean']}
    if evaluate_corruptions:
      datasets_to_evaluate = test_datasets
    for dataset_name, test_dataset in datasets_to_evaluate.items():
      test_iterator = iter(test_dataset)
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

    corrupt_results = {}
    if evaluate_corruptions:
      corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                        corruption_types,
                                                        max_intensity)

    logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                 metrics['train/loss'].result(),
                 metrics['train/accuracy'].result() * 100)
    logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
        tf.summary.scalar(name, result, step=epoch + 1)

    for metric in metrics.values():
      metric.reset_states()
    if evaluate_corruptions:
      # The corrupted-data metrics live in their own dict. Without this reset
      # each evaluation round would also include all earlier rounds.
      for metric in corrupt_metrics.values():
        metric.reset_states()

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)
sequence.append(current_word) return sequence


def main(argv):
  """Trains a small `TextRnnModel` on two sentences and exports it.

  The module is trained for 100 iterations on a fixed pair of sentences,
  `decode_greedy` is called once, and the module is then written to
  `FLAGS.export_dir` as a SavedModel.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv

  sentences = ["<S> hello there <E>", "<S> how are you doing today <E>"]
  vocab = [
      "<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
  ]

  module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128)

  # The value returned by each training call is discarded.
  for _ in range(100):
    _ = module.train(tf.constant(sentences))

  # We have to call this function explicitly if we want it exported, because it
  # has no input_signature in the @tf.function decorator.
  decoded = module.decode_greedy(
      sequence_length=10, first_word=tf.constant("<S>"))
  # Convert every decoded element with `.numpy()`; the results are discarded.
  _ = [d.numpy() for d in decoded]

  tf.saved_model.save(module, FLAGS.export_dir)


if __name__ == "__main__":
  tf.enable_v2_behavior()
  app.run(main)
def main(argv):
  """Samples a Stan target and writes its ground truth statistics to a module.

  The target named by `FLAGS.target` is looked up in `targets`, sampled with
  `FLAGS.stan_samples` iterations over `FLAGS.stan_chains` chains, and each
  entry of the model's `extract_fns` is applied to the samples one chain at a
  time. Per-chain means, standard deviations and effective sample sizes are
  then combined across chains and encoded via `ground_truth_encoding` into a
  Python source file named `<FLAGS.target>.py`.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If more than one element is present in `argv`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  stan_model = getattr(targets, FLAGS.target)()

  with stan_model.sample_fn(
      sampling_iters=FLAGS.stan_samples,
      chains=FLAGS.stan_chains,
      show_progress=True) as mcmc_output:
    summary = mcmc_output.summary()
    if FLAGS.print_summary:
      # Lift pandas' display limits so the whole summary table is printed.
      pd.set_option('display.max_rows', sys.maxsize)
      pd.set_option('display.max_columns', sys.maxsize)
      print(mcmc_output.diagnose())
      print(summary)

    array_strs = []
    for name, fn in sorted(stan_model.extract_fns.items()):
      # We handle one chain at a time to reduce memory usage.
      chain_means = []
      chain_stds = []
      chain_esss = []
      for chain_id in range(FLAGS.stan_chains):
        # TODO(https://github.com/stan-dev/cmdstanpy/issues/218): This step is
        # very slow and wastes memory. Consider reading the CSV files ourselves.

        # sample shape is [num_samples, num_chains, num_columns]
        chain = mcmc_output.sample[:, chain_id, :]
        dataframe = pd.DataFrame(chain, columns=mcmc_output.column_names)

        transformed_samples = fn(dataframe)

        # We reduce over the samples dimension. Transformations can return
        # nested outputs.
        chain_means.append(
            tf.nest.map_structure(lambda s: s.mean(0), transformed_samples))
        chain_stds.append(
            tf.nest.map_structure(lambda s: s.std(0), transformed_samples))
        chain_esss.append(
            tf.nest.map_structure(get_ess, transformed_samples))

      # Now we reduce across chains.
      ess = tf.nest.map_structure(lambda *s: np.sum(s, 0), *chain_esss)
      mean = tf.nest.map_structure(lambda *s: np.mean(s, 0), *chain_means)
      std = tf.nest.map_structure(lambda *s: np.mean(s, 0), *chain_stds)
      # The standard error must be derived from the cross-chain `std` computed
      # just above. Computing it earlier would silently pick up whatever value
      # the per-chain loop left behind for the last chain only.
      sem = tf.nest.map_structure(
          lambda std_part, ess_part: std_part / np.sqrt(ess_part), std, ess)

      for (tuple_path, mean_part), sem_part, std_part in zip(
          nest.flatten_with_tuple_paths(mean),
          tf.nest.flatten(sem),
          tf.nest.flatten(std)):
        array_strs.extend(
            ground_truth_encoding.save_ground_truth_part(
                name=name,
                tuple_path=tuple_path,
                mean=mean_part,
                sem=sem_part,
                std=std_part,
                sestd=None,
            ))

  # Record the invocation so the generated module documents how it was made.
  argv_str = '\n'.join([' {} \\'.format(arg) for arg in sys.argv[1:]])
  command_str = (
      """bazel run //tools/inference_gym_ground_truth:get_ground_truth -- \
{argv_str}""".format(argv_str=argv_str))

  file_str = ground_truth_encoding.get_ground_truth_module_source(
      target_name=FLAGS.target, command_str=command_str, array_strs=array_strs)

  if FLAGS.output_directory is None:
    # Default to a ground truth directory located relative to this script.
    file_basedir = os.path.dirname(os.path.realpath(__file__))
    output_directory = os.path.join(
        file_basedir, '../../spinoffs/inference_gym/targets/ground_truth')
  else:
    output_directory = FLAGS.output_directory
  file_path = os.path.join(output_directory, '{}.py'.format(FLAGS.target))
  print('Writing ground truth values to: {}'.format(file_path))
  with open(file_path, 'w') as f:
    f.write(file_str)