def apply_gradient(self, hyper_params, params, state, grads):
  grads = jax.tree_map(self._cross_replica_mean, grads)
  return self.optimizer_def.apply_gradient(hyper_params, params, state, grads)
def test_frozen_dict_maps(self):
  xs = {'a': 1, 'b': {'c': 2}}
  frozen = FrozenDict(xs)
  frozen2 = jax.tree_map(lambda x: x + x, frozen)
  self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})
def update_fn(updates, state, params=None):
  del params  # unused by the test optimizer
  aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
  updates = jax.tree_map(lambda u: step_size * u, updates)
  return updates, TestOptimizerState(aggregate_grads, is_reset=False)
def beam_search_loop_body_fn(state):
  """Beam search loop state update function."""
  # Collect the current position slice along length to feed the fast
  # autoregressive decoder model. Flatten the beam dimension into batch
  # dimension for feeding into the model.
  # --> [batch * beam, 1]
  flat_ids = flatten_beam_dim(lax.dynamic_slice(
      state.live_seqs,
      (0, 0, state.cur_index),
      (batch_size, beam_size, 1)))
  # Flatten beam dimension into batch to be compatible with model.
  # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
  flat_cache = jax.tree_map(
      lambda x: flatten_beam_dim(x, batch_size), state.cache)

  # Call fast-decoder model on current tokens to get next-position logits.
  # --> [batch * beam, vocab]
  flat_logits, new_flat_cache = tokens_to_logits(
      flat_ids, flat_cache, jax.random.PRNGKey(state.cur_index))

  # unflatten beam dimension
  # [batch * beam, vocab] --> [batch, beam, vocab]
  logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
  # Unflatten beam dimension in attention cache arrays
  # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
  new_cache = jax.tree_map(
      lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

  # Gather log probabilities from logits
  candidate_log_probs = jax.nn.log_softmax(logits)
  # Add new logprobs to existing prefix logprobs.
  # --> [batch, beam, vocab]
  log_probs = (candidate_log_probs +
               jnp.expand_dims(state.live_logprobs, axis=2))

  # We'll need the vocab size, gather it from the log probability dimension.
  vocab_size = log_probs.shape[2]

  # Each item in batch has beam_size * vocab_size candidate sequences.
  # For each item, get the top 2*k candidates with the highest log-
  # probabilities. We gather the top 2*K beams here so that even if the best
  # K sequences reach EOS simultaneously, we have another K sequences
  # remaining to continue the live beam search.
  beams_to_keep = 2 * beam_size
  # Flatten beam and vocab dimensions.
  flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
  # Gather the top 2*K scores from _all_ beams.
  # --> [batch, 2*beams], [batch, 2*beams]
  topk_log_probs, topk_indices = top_k(flat_log_probs, k=beams_to_keep)
  # Recover the beam index by floor division.
  topk_beam_indices = topk_indices // vocab_size
  # Gather 2*k top beams and beam-associated caches.
  # --> [batch, 2*beams, length], {[batch, 2*beams, ...], ...}
  topk_seq, new_cache = gather_beams([state.live_seqs, new_cache],
                                     topk_beam_indices,
                                     batch_size, beams_to_keep)

  # Append the most probable 2*K token IDs to the top 2*K sequences
  # Recover token id by modulo division and expand Id array for broadcasting.
  # --> [batch, 2*beams, 1]
  topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
  # Update sequences for the 2*K top-k new sequences.
  # --> [batch, 2*beams, length]
  topk_seq = lax.dynamic_update_slice(
      topk_seq, topk_ids, (0, 0, state.cur_index + 1))

  # Update LIVE (in-progress) sequences:
  # Did any of these sequences reach an end marker?
  # --> [batch, 2*beams]
  newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
  # To prevent these newly finished sequences from being added to the LIVE
  # set of active beam search sequences, set their log probs to a very large
  # negative value.
  new_log_probs = topk_log_probs + newly_finished * NEG_INF
  # --> [batch, beams, length], [batch, beams], {[batch, beams, ...], ...}
  top_alive_seq, top_alive_log_probs, top_alive_cache = gather_topk_beams(
      [topk_seq, new_log_probs, new_cache], new_log_probs,
      batch_size, beam_size)

  # Update FINISHED (reached end of sentence) sequences:
  # Calculate new seq scores from log probabilities.
  new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
  # Mask out the still unfinished sequences by adding large negative value.
  # --> [batch, 2*beams]
  new_scores += (~newly_finished) * NEG_INF

  # Combine sequences, scores, and flags along the beam dimension and compare
  # new finished sequence scores to existing finished scores and select the
  # best from the new set of beams.
  finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
      [state.finished_seqs, topk_seq], axis=1)
  finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
      [state.finished_scores, new_scores], axis=1)
  finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
      [state.finished_flags, newly_finished], axis=1)
  # --> [batch, beams, length], [batch, beams], [batch, beams]
  top_finished_seq, top_finished_scores, top_finished_flags = (
      gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                        finished_scores, batch_size, beam_size))

  return BeamState(cur_index=state.cur_index + 1,
                   live_logprobs=top_alive_log_probs,
                   finished_scores=top_finished_scores,
                   live_seqs=top_alive_seq,
                   finished_seqs=top_finished_seq,
                   finished_flags=top_finished_flags,
                   cache=top_alive_cache)
def new_weights(self, input_signature):
  weights = super().new_weights(input_signature)
  if self.init_checkpoint is None:
    return weights

  print('Loading pre-trained weights from', self.init_checkpoint)
  ckpt = tf.train.load_checkpoint(self.init_checkpoint)

  def reshape_qkv(name):
    x = ckpt.get_tensor(name)
    return x.reshape((x.shape[0], -1, 64)).swapaxes(0, 1)

  def reshape_o(name):
    x = ckpt.get_tensor(name)
    return x.reshape((-1, 64, x.shape[-1]))

  def reshape_bias(name):
    x = ckpt.get_tensor(name)
    return x.reshape((-1, 64))

  new_w = [
      ckpt.get_tensor('bert/embeddings/word_embeddings'),
      ckpt.get_tensor('bert/embeddings/token_type_embeddings'),
      ckpt.get_tensor('bert/embeddings/position_embeddings')[None, ...],
      ckpt.get_tensor('bert/embeddings/LayerNorm/gamma'),
      ckpt.get_tensor('bert/embeddings/LayerNorm/beta'),
  ]
  for i in range(12):  # 12 layers
    new_w += [
        reshape_qkv(f'bert/encoder/layer_{i}/attention/self/query/kernel'),
        reshape_qkv(f'bert/encoder/layer_{i}/attention/self/key/kernel'),
        reshape_qkv(f'bert/encoder/layer_{i}/attention/self/value/kernel'),
        reshape_o(f'bert/encoder/layer_{i}/attention/output/dense/kernel'),
        reshape_bias(f'bert/encoder/layer_{i}/attention/self/query/bias'),
        reshape_bias(f'bert/encoder/layer_{i}/attention/self/key/bias'),
        reshape_bias(f'bert/encoder/layer_{i}/attention/self/value/bias'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/attention/output/dense/bias'),
        ckpt.get_tensor(
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'),
        ckpt.get_tensor(
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/kernel'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/bias'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/kernel'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/bias'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/gamma'),
        ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/beta'),
    ]
  new_w += [
      ckpt.get_tensor('bert/pooler/dense/kernel'),
      ckpt.get_tensor('bert/pooler/dense/bias'),
  ]

  for a, b in zip(fastmath.tree_leaves(weights), new_w):
    assert a.shape == b.shape, (
        f'Expected shape {a.shape}, got shape {b.shape}')
  weights = jax.tree_unflatten(jax.tree_structure(weights), new_w)
  move_to_device = jax.jit(lambda x: x)
  weights = jax.tree_map(move_to_device, weights)
  return weights
def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
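# Usage sketch added for illustration (not part of the original source). With
# N = jax.local_device_count() devices, `shard` reshapes every array leaf from
# [batch, ...] to [N, batch // N, ...], the layout jax.pmap expects; it assumes
# the batch size is divisible by N.
import jax
import numpy as np

example_batch = {'inputs': np.zeros((8, 16)), 'labels': np.zeros((8,), np.int32)}
sharded_batch = shard(example_batch)
# On a host with 8 local devices: sharded_batch['inputs'].shape == (8, 1, 16).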
def evaluate_sequence_accuracy(p_pred_step,
                               p_init_cache,
                               state,
                               ds,
                               config,
                               split,
                               workdir,
                               num_eval_steps=-1):
  """Evaluate sequence accuracy on the given dataset."""
  prediction_dir = os.path.join(workdir, 'predictions')
  tf.io.gfile.makedirs(prediction_dir)
  logging.info('Starting evaluating sequence accuracy on %s split.', split)
  outputs = []
  test_metrics = collections.defaultdict(list)
  data_dir = config.dataset.data_dir
  input_vocab_file = os.path.join(data_dir, 'training_input_vocab.txt')
  target_vocab_file = os.path.join(data_dir, 'training_target_vocab.txt')
  dataset_file = os.path.join(data_dir, 'dataset.txt')
  eos_idx = config.dataset.eos_idx
  with tf.io.gfile.GFile(input_vocab_file, 'r') as f:
    input_vocab = json.load(f)
  with tf.io.gfile.GFile(target_vocab_file, 'r') as f:
    target_vocab = json.load(f)
  with tf.io.gfile.GFile(dataset_file, 'r') as f:
    annotations = json.load(f)
  for step, batch in enumerate(ds):  # pytype: disable=wrong-arg-types
    batch = jax.tree_map(np.asarray, batch)
    cache = p_init_cache(batch)
    batch['predictions'] = p_pred_step(batch, state, cache, eos_idx)
    batch = remove_pad(tohost(batch))
    target_token = batch['target_token']
    predictions = batch['predictions']
    for i, (prediction, target) in enumerate(zip(predictions, target_token)):
      prediction = remove_special_tokens(prediction.tolist(), eos_idx)
      target = remove_special_tokens(target.tolist(), eos_idx)
      acc = evaluation.sequence_accuracy(prediction, target)
      test_metrics['test_accuracy'].append(acc)
      exact_match = 100 if acc == 100 else 0
      test_metrics['test_exact_match'].append(exact_match)
      input_command = remove_special_tokens(batch['token'][i].tolist(),
                                            eos_idx)
      index = int(batch['index'][i][0])
      example = annotations['examples'][split][index]
      outputs.append({
          'split': split,
          'index': index,
          'input': array_to_sentence(input_command, input_vocab),
          'prediction': array_to_sentence(prediction, target_vocab),
          'target': array_to_sentence(target, target_vocab),
          'derivation': [example['derivation']],
          'situation': [example['situation']],
          'accuracy': acc,
          'exact_match': True if acc == 100 else False,
          'attention_weights_input': [],
          'attention_weights_situation': [],
      })
    if num_eval_steps > 0 and step + 1 == num_eval_steps:
      break
  test_metrics = {k: sum(v) / len(v) for k, v in test_metrics.items()}
  step = flax_utils.unreplicate(state).step
  out_path = os.path.join(prediction_dir, f'{split}_predict_{step}.json')
  with tf.io.gfile.GFile(out_path, 'w') as f:
    json.dump(outputs, f, indent=2)
  return test_metrics
def initial_state(self, batch_size):
  return jax.tree_map(jnp.zeros_like, self._state)
def prepare_tf_data_unbatched(xs):
  """Prepare TF dataset into unbatched numpy arrays."""
  # Use _numpy() for zero-copy conversion between TF and NumPy.
  # pylint: disable=protected-access
  return jax.tree_map(lambda x: x._numpy(), xs)
def post_pmap(xs):
  return jax.tree_map(lambda x: x[0], xs)
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = json.loads(FLAGS.xm_parameters)
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in hparam_str_dict.keys()])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)
  indices_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.max_program_length)  # pylint: disable=unused-variable

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + '|'
    spec_id_token_table = {i + 3: token for i, token in enumerate(spec_vocab)}
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {
        token: id for id, token in spec_id_token_table.items()}
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, program_token_id_table = (
        dsl_tokens.build_token_tables())
    program_vocab_size = len(program_id_token_table) + 1  # pylint: disable=unused-variable
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    if FLAGS.dataset_type == 'robust_fill':
      def decode_str(s):
        """Decode string tokens."""
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs
    elif FLAGS.dataset_type == 'scan':
      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_target(target):
    """Decode program tokens."""
    target = target[:np.argmax(target == eos_id)].astype(np.int32)
    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join(
          [spec_id_token_table[t_id] for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)
    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except:  # pylint: disable=bare-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string.
      program = program[jnp.logical_and(program != bos_id,
                                        program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()
      except:  # pylint: disable=bare-except
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'spec_parts': target_shape[1:],
      'start_index': indices_shape[1:],
      'end_index': indices_shape[1:],
      # TODO(kshi): include programs.
      # 'program': program_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    create_dataset_fn = input_pipeline.create_robust_fill_dataset_from_tf_record
  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(
      FLAGS.dataset_filepattern, program_token_id_table, spec_token_id_table,
      max_target_length=FLAGS.max_target_length)
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=padded_shapes, drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(
      FLAGS.test_dataset_filepattern, program_token_id_table,
      spec_token_id_table, max_target_length=FLAGS.max_target_length)
  test_dataset = test_dataset.padded_batch(
      batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=spec_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_distance,
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      dataset_type=FLAGS.dataset_type)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False,
          deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)),
      dataset_type=FLAGS.dataset_type)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)
  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info('Checking that start_step (%s) == finetune_start_step (%s)',
                   start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr, factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_id,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)
        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:
        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0

          ios, targets_list, predictions, top_of_beams = [], [], [], []
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])
              ground_truth = decode_target(targets[i])
              beams_target = [decode_target(beam) for beam in beams]
              predicted_target = beams_target[0]
              for beam_target in beams_target:
                if beam_target == ground_truth:
                  predicted_target = beam_target
                  total_successes += 1
                  break
              total_denominator += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(predicted_target)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('predicted_target: %s', predicted_target)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warn('ground_truth is empty!')

              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_target(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size, all_total_successes,
                all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)
            summary_writer.text('{}-samples-beam-{}'.format(predict_or_test,
                                                            beam_size),
                                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)
def pre_pmap(xs):
  return jax.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
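# Illustrative sketch (added, not from the original source): `pre_pmap` adds a
# leading axis of size 1 so a tree of arrays can be fed to a single-device
# pmapped function, and `post_pmap` (defined earlier) strips that axis again.
import jax.numpy as jnp

tree = {'w': jnp.ones((3, 4)), 'b': jnp.zeros((4,))}
replicated = pre_pmap(tree)       # w: (1, 3, 4), b: (1, 4)
restored = post_pmap(replicated)  # back to w: (3, 4), b: (4,)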
def standard_train_step(
    state,
    batch,
    rng_key,
    dynamic_state,
    *,
    static_state,
    loss_fn,
    learning_rate_fn,
    model_cls,
    grad_clip=None,
    use_bfloat16=False,
    parallel=True,
    vmap_batch=False,
    ema_decay_rate=0.9,
    ema_burn_in=1000,
    threshold=0.0,
):
  """Perform a single standard training step.

  Args:
    state: a TrainState object containing the optimizer and EMA params.
    batch: dictionary or tuple
    rng_key: Jax RNG for model Dropout and additional RNG.
    dynamic_state: a dict of dynamic objects that should be passed to the
      model.
    static_state: any additional state to be passed to the model. The model
      will be recompiled when this changes.
    loss_fn: loss function which takes a function and batch and returns a
      loss.
    learning_rate_fn: function that returns the learning rate for a given
      iteration.
    model_cls: an nn.Module type to use for training. Must have a train attr.
    grad_clip: if not None, a float which determines the grad clipping norm.
    use_bfloat16: if True, round gradients to bfloat16 during training.
    parallel: if True, pmean reduces across device dimension.
    vmap_batch: if True, apply vmap over the batch axis.
    ema_decay_rate: the rate at which ema stats decay.
    ema_burn_in: the number of steps to skip before rejecting outliers.
    threshold: the probability below which any loss sample will be rejected.
      Set to 0 to disable EMA outlier rejection. Note that this currently
      doesn't work because we don't update the loss across devices properly.
      So please do not enable this.

  Returns:
    the updated optimizer, a metrics dict, and the new Jax RNG key.
  """
  logging.info('Recompiling train_step.')  # only called when recompiling
  optimizer = state.optimizer
  # We handle PRNG splitting inside the top pmap to improve efficiency.
  step = state.step
  lr = learning_rate_fn(step)
  model = model_cls(train=True)
  apply_key, rng_key = jrandom.split(rng_key)
  loss_key, rng_key = jrandom.split(rng_key)
  model_apply = utils.make_model_apply(model, apply_key)

  loss_fn = build_vmapped_loss(
      loss_fn,
      batch,
      loss_key,
      dynamic_state,
      is_eval=False,
      model_apply=model_apply,
      static_state=static_state,
      vmap_batch=vmap_batch)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (metrics, _)), grad = grad_fn(optimizer.target)
  if use_bfloat16:
    grad = jax.tree_map(lambda x: x.astype(jnp.bfloat16), grad)
  if parallel:
    grad = jax.lax.pmean(grad, 'device')
  if grad_clip is not None:
    # Clip gradients after pmean aggregation
    unclipped_grad = grad
    grad = jax.experimental.optimizers.clip_grads(grad, grad_clip)
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  metrics['nn/learning_rate'] = lr

  # Gradient norms
  grad_l2_tree = l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['nn/l2_grad_sum'] = grad_l2_sum
  metrics['nn/l2_grad_max'] = grad_l2_max
  if grad_clip is not None:
    # Unclipped gradient norms (if applicable).
    grad_l2_tree = l2_norm(unclipped_grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['nn/l2_noclip_grad_sum'] = grad_l2_sum
    metrics['nn/l2_noclip_grad_max'] = grad_l2_max

  if threshold > 0:
    normal_pdf = jax.scipy.stats.norm.pdf(
        loss, loc=state.ema_loss, scale=jnp.sqrt(state.ema_variance))
    metrics['nn/normal_pdf'] = normal_pdf
    should_replace = (normal_pdf > threshold) | (state.step < ema_burn_in)
  else:
    should_replace = True

  grads_ok = jnp.all(
      jnp.asarray([
          jnp.all(jnp.isfinite(p)) for p in jax.tree_leaves(new_optimizer)
      ]))
  loss_ok = jnp.all(jnp.isfinite(loss))
  should_replace = should_replace & grads_ok & loss_ok
  metrics['nn/step_skipped'] = 1 - should_replace
  metrics['nn/ema_loss'] = state.ema_loss
  metrics['nn/ema_variance'] = state.ema_variance
  metrics['nn/step'] = state.step
  metrics['nn/grads_ok'] = grads_ok
  metrics['nn/loss_ok'] = loss_ok

  delta = (loss - state.ema_loss)
  new_state = TrainState(
      optimizer=new_optimizer,
      step=state.step + 1,
      ema_loss=state.ema_loss * ema_decay_rate + (1 - ema_decay_rate) * loss,
      ema_variance=state.ema_variance * ema_decay_rate +
      (1 - ema_decay_rate) * delta**2,
  )
  new_state = jax.tree_multimap(
      lambda a, b: jnp.where(should_replace, a, b),
      new_state,
      state,
  )
  return new_state, metrics, rng_key
def state_dict(self, target, state):
  state_dict = self.optimizer_def.state_dict(target, state)
  # only the first copy of the parameters and optimizer state are stored.
  state_dict = jax.tree_map(lambda x: x[0], state_dict)
  return state_dict
def unreplicate(tree):
  """Returns a single instance of a replicated array."""
  return jax.tree_map(lambda x: x[0], tree)
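# Usage sketch added for illustration (not part of the original source): after
# flax's jax_utils.replicate, every leaf carries a leading device axis, and
# `unreplicate` keeps only the first copy, e.g. before checkpointing.
from flax import jax_utils
import jax.numpy as jnp

params = {'kernel': jnp.ones((2, 2))}
replicated = jax_utils.replicate(params)  # kernel: (num_devices, 2, 2)
single = unreplicate(replicated)          # kernel: (2, 2)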
def same_tree_with_value(self, tree, value):
  return jax.tree_map(lambda x: jnp.ones_like(x) * value, tree)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)
  batch_size = config.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError(
        'PixelCNN++ example should not be run on more than 1 host'
        ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
  # Load dataset
  data_source = input_pipeline.DataSource(config)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds
  steps_per_epoch = data_source.ds_info.splits[
      'train'].num_examples // config.batch_size
  # Create dataset batch iterators
  train_iter = iter(train_ds)
  num_train_steps = train_ds.cardinality().numpy()
  steps_per_checkpoint = 1000

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert config.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size]

  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng, dropout_rng = jax.random.split(rng, 3)

  initial_variables = model(config).init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
  ema = initial_variables
  step_offset = int(optimizer.state.step)

  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      functools.partial(train_step, config, learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=config), axis_name='batch')

  # Gather metrics
  train_metrics = []

  for step, batch in zip(range(step_offset, num_train_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = jax.random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                           sharded_rngs)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for eval_batch in eval_ds:
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps:
      save_checkpoint(workdir, optimizer, ema, step)
    attn = Attn(attn_module=self.attn_module,
                qkv_features=qkv_features // self.num_heads,
                out_features=out_features)

    # evaluate multi-headed-attention.
    y = attn(inputs_q, inputs_kv, bias)
    return y.mean(axis=-2)


# run it.
if __name__ == '__main__':
  inputs = jnp.ones((8, 97, 256))
  rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
  model = MultiHeadDotProductAttention(
      broadcast_dropout=False,
      qkv_features=256,
      out_features=256,
      attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
      num_heads=8,
      batch_axes=(0,),
  )
  y, params = model.init_with_output(rngs, inputs, inputs)

  print('input shape: ', inputs.shape)
  print('parameter shapes:')
  pprint(jax.tree_map(jnp.shape, unfreeze(params)))
  print('output shape: ', y.shape)
def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.tree_map(jnp.array, xs)
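# Usage sketch added for illustration (not part of the original source):
# `to_device` converts every NumPy leaf into a jnp array, which places it on
# the default accelerator backend (GPU/TPU when available, otherwise CPU).
import numpy as np

host_batch = {'image': np.zeros((4, 28, 28)), 'label': np.arange(4)}
device_batch = to_device(host_batch)
print(type(device_batch['image']))  # a jax device-backed array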
    return x


class gMLPModel(nn.Module):
  ffn_dim: int
  model_dim: int
  num_blocks: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(name='embedding', features=self.model_dim)(x)
    for i in range(self.num_blocks):
      x = gMLPBlock(ffn_dim=self.ffn_dim, model_dim=self.model_dim)(x)
    return x


tiny_settings = {'ffn_dim': 768, 'model_dim': 128, 'num_blocks': 30}

if __name__ == "__main__":
  key = random.PRNGKey(2)
  x = random.normal(key, shape=[8, 12, 18, 100])
  model = gMLPModel(**tiny_settings)
  model_state = model.init(key, x)
  y = model.apply(model_state, x)
  print(
      json.dumps(jax.tree_map(np.shape,
                              flax.core.unfreeze(model_state['params'])),
                 indent=2))
  num_params = functools.reduce(
      operator.add, map(np.size, jax.tree_leaves(model_state['params'])))
def remove_pad(x):
  """Remove padding examples."""
  if 'mask' in x:
    ind = jnp.where(jnp.array(x.pop('mask'), dtype=jnp.int32) > 0)
    x = jax.tree_map(lambda v: v[ind], x)  # pylint: disable=cell-var-from-loop
  return x
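# Illustrative sketch (added, not from the original source): padded examples
# are marked with 0 in the 'mask' entry; `remove_pad` keeps only rows whose
# mask is positive and drops the mask itself.
import jax.numpy as jnp

padded_batch = {'token': jnp.array([[1, 2], [3, 4], [0, 0]]),
                'mask': jnp.array([1, 1, 0])}
filtered = remove_pad(padded_batch)  # {'token': array of shape (2, 2)}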
def _wrapped(*args):
  return jax.tree_map(pfn, *args)
def __init__(self, model, weights, max_decode_len, beam_size=1,
             temperature=0, alpha=0.0, eos_id=None):
  """Construct an inference wrapper for an autoregressive model.

  The default behavior is to do greedy decoding:
    s = Search(model, weights, max_decode_len, eos_id=eos_id)

  Passing a temperature parameter will switch to sampling:
    s = Search(model, weights, max_decode_len, temperature=1, eos_id=eos_id)

  Passing a beam_size parameter will switch to beam search. For machine
  translation with Transformer models, Vaswani et al. (2017) recommend a beam
  size of 4 and length normalization with alpha=0.6.
    s = Search(model, weights, max_decode_len,
               beam_size=4, alpha=0.6, eos_id=eos_id)

  After constructing the class, see Search.decode for how to decode a batch
  of examples.

  Args:
    model: function to construct a model (e.g. trax.models.Reformer)
    weights: model weights
    max_decode_len: maximum length to decode
    beam_size: beam size, for beam search
    temperature: temperature parameter for sampling; set to nonzero to switch
      from greedy/beam-search behavior to sampling.
    alpha: length penalty alpha coefficient for beam search.
    eos_id: end-of-sentence token for target vocabulary.
  """
  # TODO(kitaev): k and p parameters for top-k and nucleus sampling.
  self.model = model
  self.model_infer = model(mode='predict')
  # Weights are stored on device, but not replicated.
  self.model_weights = jax.tree_map(jax.jit(lambda x: x), weights)

  self.sample = (temperature != 0)
  self.temperature = temperature
  if self.sample and beam_size > 1:
    # TODO(kitaev): perform stochastic beam search in this case
    # (https://arxiv.org/abs/1903.06059)
    raise ValueError('beam_size parameter is not supported when sampling')

  is_cache = [isinstance(l, tl.Cache) for l in self.model_infer.sublayers]
  if any(is_cache):
    assert sum([int(x) for x in is_cache]) == 1, (
        'At most one usage of tl.Cache currently supported')
    self.encoder_idx = is_cache.index(True) + 1
  else:
    self.encoder_idx = None

  beam_search_partial = functools.partial(
      self._unreplicated_beam_search,
      beam_size=beam_size,
      alpha=alpha,
      eos_token=eos_id if eos_id is not None else -1,
      max_decode_len=max_decode_len + 1)  # Add 1 to account for start token.
  if trax.math.device_count() == 1:
    self._jit_beam_search = jax.jit(beam_search_partial, static_argnums=(2,))
  else:
    self._jit_beam_search = jax.pmap(beam_search_partial,
                                     axis_name='batch',
                                     static_broadcasted_argnums=(2,))

  # Work around a jax error
  # Ref: https://github.com/google/jax/issues/1919#issuecomment-569985681
  jax_partial_eval._thread_local_state.remat = True  # pylint: disable=protected-access
def init_fn(key, sx=sx):
  return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)
def batch_loss(params, rng):
  stuff = jax.vmap(
      lambda rng: self.loss_and_metrics_one_pair(params, rng))(
          jax.random.split(rng, self.batch_size))
  return jax.tree_map(jnp.mean, stuff)
def apply_fn(params, x):
  return jax.tree_map(lambda p, v: p + v, params, x)
def init_fn(params):
  aggregate_grads = jax.tree_map(jnp.zeros_like, params)
  return TestOptimizerState(aggregate_grads, is_reset=True)
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #     out_axes = (0, 0)
    # else:
    #     out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #     partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )
    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))

    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is
    # conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need
    # floatXX and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape(
            (n_samples_node,) + tuple(1 for i in range(par_dims))
        )

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for the standard gradient of the energy.
    # this avoids errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def _masked_sgd_on_updates(m, upd):
  return jax.tree_map(lambda x: -x, upd) if m else upd
def init_state(self, params):
  param_states = jax.tree_map(self.init_param_state, params)
  state = OptimizerState(0, param_states)
  return state