def fit(self, xs, epochs=1, batch_size=None, max_steps=10**6):
  """Fits to sequences given as [N x length] token array."""
  if batch_size is None:
    batch_size = self._batch_size

  if hasattr(xs, 'as_numpy_iterator'):
    # TF Dataset.
    ds = xs.repeat(epochs)
    num_train_steps = max_steps
  elif hasattr(xs, 'element_spec'):
    # Dataset iterator.
    if epochs != 1:
      raise ValueError('Epochs must == 1 when using iterator input.')
    ds = xs
    num_train_steps = max_steps
  else:
    # Raw sequences which we turn into a dataset.
    ds = data.dataset_from_tensors(xs)
    ds = ds.shuffle(buffer_size=1024).repeat().batch(batch_size)
    num_train_steps = math.ceil((len(xs) * epochs) / float(batch_size))
    if max_steps:
      num_train_steps = min(num_train_steps, max_steps)
  if not num_train_steps:
    raise ValueError('Must set max_steps to nonzero value.')

  metrics = []
  start = time.time()
  max_steps = max_steps or 10**6
  for _, batch in zip(range(num_train_steps), ds):
    metrics.append(self.fit_batch(batch))
  finish = time.time()

  average = evaluation.combine_metrics(metrics)
  average['runtime'] = finish - start
  average['rate'] = len(metrics) / (finish - start)
  if self._store_metrics:
    average = tree.map_structure(onp.array, average)
    self._epoch_train.append(average)
  return dict(last=evaluation.combine_metrics([metrics[-1]]), average=average)
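
# Illustrative usage sketch (not part of the original class): the three input
# forms `fit` accepts, dispatched on by the hasattr checks above. `model` and
# `tokens` are hypothetical stand-ins for a model instance from this module and
# an [N x length] integer token array.
def _example_fit_usage(model, tokens):
  """Shows fit() on a raw array, a tf.data.Dataset, and a dataset iterator."""
  # 1) Raw token array: fit() shuffles, repeats, and batches it internally.
  history = model.fit(tokens, epochs=2, batch_size=32)
  # 2) tf.data.Dataset of batches: repeated `epochs` times, bounded by max_steps.
  ds = data.dataset_from_tensors(tokens).batch(32)
  history = model.fit(ds, epochs=2, max_steps=100)
  # 3) Dataset iterator: epochs must be 1; max_steps bounds the number of batches.
  history = model.fit(iter(ds), epochs=1, max_steps=100)
  return history['average']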
def test_tag_attention(self):
  # TODO(ddohan): Consider making decorator which tracks tensor distributions.
  def tagging_dot_product_attention(query, key, value, **kwargs):
    modules.Tag(jnp.mean(query), name='mean_query')
    modules.Tag(jnp.mean(key), name='mean_key')
    modules.Tag(jnp.mean(value), name='mean_value')
    return modules.nn.attention.dot_product_attention(query, key, value,
                                                      **kwargs)

  domain = _test_domain()
  lm = models.FlaxModel(
      domain=domain, attention_fn=tagging_dot_product_attention, **lm_cfg)
  xs = domain.sample_uniformly(4)

  metrics = []
  for _ in range(2):
    step_metrics = lm.fit_batch((xs, xs))
    metrics.append(step_metrics)
  combined_metrics = evaluation.combine_metrics(metrics)

  # Confirm metrics are included.
  self.assertIn('nn/1/1/mean_query', combined_metrics)
  self.assertIn('nn/1/1/mean_key', combined_metrics)
  self.assertIn('nn/1/1/mean_value', combined_metrics)
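
# Sketch of the decorator suggested by the TODO above (an assumption, not an
# existing helper in this repo): it wraps any attention function so the mean of
# each input tensor is tagged automatically, instead of tagging by hand as in
# the test.
def tag_input_means(attention_fn):
  """Wraps attention_fn to Tag the mean of query/key/value."""
  def wrapped(query, key, value, **kwargs):
    for tag_name, tensor in (('mean_query', query),
                             ('mean_key', key),
                             ('mean_value', value)):
      modules.Tag(jnp.mean(tensor), name=tag_name)
    return attention_fn(query, key, value, **kwargs)
  return wrapped
# e.g. attention_fn=tag_input_means(modules.nn.attention.dot_product_attention)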
def run_experiment(
    model_dir,
    data_dir=None,
    xid=None,
    batch_size_per_device=128,
    eval_frequency=500,
    checkpoint_frequency=10000,
    save_checkpoints=True,
    restore_checkpoint=True,
    num_eval_steps=None,
    epochs=None,
    max_train_steps=1000000,  # 1 million
    max_train_length=512,
    train_summary_frequency=100,
    max_eval_length=None,
    model_cls=models.FlaxLM):
  """Run experiment.

  Args:
    model_dir: Directory to save checkpoints and metrics to.
    data_dir: Directory to load data.
    xid: Optional experiment id.
    batch_size_per_device: Batch size per device.
    eval_frequency: Steps per eval.
    checkpoint_frequency: How often to checkpoint. If None, only checkpoint
      once at end of run.
    save_checkpoints: If True, checkpoints model according to
      checkpoint_frequency.
    restore_checkpoint: If True, will restore checkpoint from directory.
      Useful for robustness to preemption.
    num_eval_steps: Number of eval steps to take on eval dataset.
    epochs: Number of train epochs.
    max_train_steps: Stop training after N steps.
    max_train_length: Crop training sequences to this length.
    train_summary_frequency: Frequency to write train metrics.
    max_eval_length: Maximum eval length. Defaults to max_train_length.
    model_cls: Model class to use.

  Returns:
    FlaxLM resulting from running training.
  """
  if xid is not None:
    model_dir = os.path.join(model_dir,
                             '%s_l%s' % (str(xid), max_train_length))
  tf.enable_v2_behavior()

  if jax.host_id() == 0:
    summary_writer = tf_summary.create_file_writer(
        os.path.join(model_dir, 'metrics'), max_queue=1, flush_millis=1000)
    train_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='train/', enable_tf=True, verbose=0)
    eval_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='eval/', enable_tf=True, verbose=0)

  batch_size = batch_size_per_device * jax.local_device_count()
  max_eval_length = max_eval_length or max_train_length
  train_files, test_files = data.get_train_valid_files(directory=data_dir)
  train_ds, eval_ds = data.load_dataset(
      train_files=train_files,
      test_files=test_files,
      batch_size=batch_size,
      max_train_length=max_train_length,
      max_eval_length=max_eval_length,
      shuffle_buffer=16384)

  with contextlib.ExitStack() as stack:  # pylint: disable=using-constant-test
    if jax.host_id() == 0:
      # Only need metric writer context manager on host 0.
      stack.enter_context(summary_writer.as_default())

    model = model_cls(domain=data.protein_domain, batch_size=batch_size)

    if restore_checkpoint:
      try:
        model.load_checkpoint(model_dir)
      except ValueError:
        # No checkpoint to load -> raises ValueError.
        pass
    start_step = model.train_step

    train_ds = train_ds.repeat(epochs)
    train_iter = iter(train_ds)
    train_metrics = []
    tick = time.time()

    if jax.host_id() == 0:
      _write_gin_configs(os.path.join(model_dir, 'config.gin'))

    num_evals = 0
    for step, batch in zip(range(start_step, max_train_steps), train_iter):
      batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
      metrics = model.fit_batch(batch)
      train_metrics.append(metrics)

      if jax.host_id() == 0 and (
          (save_checkpoints and checkpoint_frequency and
           step % checkpoint_frequency == 0 and step > 0) or
          step == max_train_steps - 1):
        model.save_checkpoint(model_dir)

      if (step + 1) % train_summary_frequency == 0:
        summary = evaluation.combine_metrics(train_metrics)
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if jax.host_id() == 0:
          tock = time.time()
          steps_per_sec = train_summary_frequency / (tock - tick)
          tick = tock
          train_summary_writer('steps per second', steps_per_sec, step)
          for key, val in summary.items():
            if jnp.isnan(val):
              raise ValueError(f'NaN in {key} at step {step}.')
            train_summary_writer(key, val, step)
        # Reset metric accumulation for the next summary cycle.
        train_metrics = []

      if eval_frequency and (step + 1) % eval_frequency == 0:
        eval_summary = evaluation.evaluate(
            model=model, eval_ds=eval_ds, num_eval_steps=num_eval_steps)
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        if jax.host_id() == 0:
          for key, val in eval_summary.items():
            eval_summary_writer(key, val, step)
          tf_summary.flush()
          summary_writer.flush()
          if num_evals == 0:
            # Write out config on first eval.
            _write_gin_configs(
                os.path.join(model_dir, 'config_after_eval.gin'))
          num_evals += 1

  if jax.host_id() == 0:
    tf_summary.flush()
    summary_writer.close()
    _write_gin_configs(os.path.join(model_dir, 'config_end.gin'))
  return model
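
# Minimal invocation sketch (added for illustration; the real entry point wires
# these arguments through gin/flags elsewhere in the repo). The path and sizes
# are placeholders for a quick local smoke test.
def _example_run_experiment():
  return run_experiment(
      model_dir='/tmp/protein_lm_run',   # Checkpoints and metrics land here.
      data_dir=None,                     # Fall back to the default data files.
      batch_size_per_device=8,           # Small batch for a quick local run.
      max_train_steps=1000,
      eval_frequency=100,
      checkpoint_frequency=500)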