def run_eval(run_configuration):
  """Evaluates on checkpoints as they become available."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  last_checkpoint_path = None
  last_checkpoint_time = time.time()
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  success_path = checkpoint_utils.build_success_path(run_dir)
  error_count = 0
  while True:
    success = tf.io.gfile.exists(success_path)
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is not None and checkpoint_path != last_checkpoint_path:
      logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
      try:
        eval_once(run_configuration, checkpoint_path, optimizer)
      except:  # pylint: disable=bare-except
        logging.info('Could not evaluate %s', checkpoint_path)
        error_count += 1
        if error_count >= 10 or config.debug:
          raise
      last_checkpoint_path = checkpoint_path
      last_checkpoint_time = time.time()
    else:
      if success:
        logging.info('SUCCESS file found. Stopping.')
        break
      if time.time() - last_checkpoint_time > config.eval_timeout:
        logging.info('Timed out waiting for checkpoint. Stopping.')
        break
      logging.info('Waiting for checkpoint.')
      time.sleep(15)


def run_train_single_device(run_configuration):
  """Runs the training workflow without pmap or jit."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  checkpoint_path = run_configuration.original_checkpoint_path
  dataset = run_configuration.dataset_info.dataset

  random_seed = 0
  rng = jax.random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  dropout_rng, init_rng = jax.random.split(rng)

  # Set up optimizer.
  optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

  # Set up train step.
  train_step = adapter.make_train_step(single_device=True)

  # Set up checkpointing.
  # TODO(dbieber): Set up phoenix.
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  if checkpoint_path is None:
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
  optimizer = checkpoint_utils.handle_restart_behavior(
      checkpoint_path, optimizer, config)

  start_step = int(optimizer.state.step)
  num_train_steps = config.train.total_steps

  # Begin training loop.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw, single_device=True)

  for step, example in zip(range(start_step, num_train_steps), dataset_iter):
    print(f'Step #{step}')
    train_inputs = adapter.get_train_inputs(example)
    optimizer, metrics, dropout_rng, logits, state = train_step(
        optimizer, train_inputs, dropout_rng)
    del metrics, logits, state  # Unused.

    # Save a checkpoint.
    if ((step % config.logging.save_freq == 0 and step > 0)
        or step == num_train_steps - 1):
      if jax.host_id() == 0 and config.logging.save_freq:
        # Save unreplicated optimizer + model state.
        checkpoint_utils.save_checkpoint(checkpoint_dir, optimizer, step)


def run_train(run_configuration):
  """Runs the training workflow."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  log_dir = os.path.join(run_dir, 'train')
  checkpoint_path = run_configuration.original_checkpoint_path
  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  random_seed = 0
  rng = jax.random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)
  dropout_rngs = jax.random.split(rng, jax.local_device_count())

  # Set up optimizer.
  optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

  # Set up train step.
  train_step = adapter.make_train_step()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Set up checkpointing.
  # TODO(dbieber): Set up phoenix.
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  if checkpoint_path is None:
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
  optimizer = checkpoint_utils.handle_restart_behavior(
      checkpoint_path, optimizer, config)

  start_step = int(optimizer.state.step)
  num_train_steps = config.train.total_steps

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)

  # Begin training loop.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)

  summary_freq = config.logging.summary_freq
  metrics_all = []
  tick = time.time()
  for step, example in zip(range(start_step, num_train_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    optimizer, metrics, dropout_rngs, logits, state = train_step(
        optimizer, train_inputs, dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint.
    if ((step % config.logging.save_freq == 0 and step > 0)
        or step == num_train_steps - 1):
      if jax.host_id() == 0 and config.logging.save_freq:
        # Save unreplicated optimizer + model state.
        checkpoint_utils.save_checkpoint(
            checkpoint_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if summary_freq and step % summary_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = summary_freq / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        tick = tock
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(key, val, step)
        adapter.write_summaries(
            example, logits, summary_writer, info, step, state)
        summary_writer.flush()

      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []


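# Illustrative sketch (not part of the original module): the periodic metric
# handling in run_train reduces the accumulated per-step metrics to
# per-example averages. The hypothetical helper below only restates that
# reduction in isolation; run_train additionally stacks metrics across devices
# with common_utils.get_metrics and pops the learning rate first.
def example_summarize_metrics(metrics_all):
  """Hypothetical helper mirroring run_train's metric reduction."""
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
  # Clipped perplexity, computed from the averaged log-perplexity (loss).
  summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  return summary

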
def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
  """Sets up the Learned Interpreter code with the specified configuration."""
  seed()

  # Apply any overrides set at the command line or in the launcher.
  if config.overrides != config.default_overrides:
    logging.info('Applying overrides set at command line: %s',
                 config.overrides)
    overrides_lib.apply_overrides(
        config, override_names=config.overrides.split(','))
  config.update_from_flattened_dict(override_values)

  # If a checkpoint is specified, it determines the "original run."
  # Otherwise the run_dir, if already present, determines the "original run."
  config_filepath = os.path.join(run_dir, 'config.json')
  if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
    original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
        run_dir, config.checkpoint)
    original_run_dir = checkpoint_utils.get_run_dir(original_checkpoint_path)
    original_config_filepath = os.path.join(original_run_dir, 'config.json')
  else:
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    original_checkpoint_path = checkpoint_utils.latest_checkpoint(
        checkpoint_dir)
    original_config_filepath = config_filepath
  original_config_exists = gfile.exists(original_config_filepath)

  # Handle any existing configs.
  if original_config_exists:
    original_config = config_utils.load_config(original_config_filepath)

    # Handle the model config.
    if config.runner.model_config == 'load':
      logging.info('Loading the model config from %s',
                   original_config_filepath)
      config.model.update(original_config.model)
      config.dataset.representation = original_config.dataset.representation
    elif config.runner.model_config == 'assert':
      same_config = config_utils.equals(config.model, original_config.model)
      # Resolution:
      # Either use a new run_dir, or set model_config to 'load' or 'keep'.
      assert same_config, 'Model config has changed.'
    else:
      assert config.runner.model_config == 'keep'

    # Handle the dataset config.
    if config.runner.dataset_config == 'load':
      logging.info('Loading the data config from %s',
                   original_config_filepath)
      config.dataset.update(original_config.dataset)
    elif config.runner.dataset_config == 'assert':
      same_config = config_utils.equals(config.dataset,
                                        original_config.dataset)
      assert same_config, 'Dataset config has changed.'
    else:
      assert config.runner.dataset_config == 'keep'
  elif (config.runner.model_config == 'load'
        or config.runner.dataset_config == 'load'):
    raise ValueError('Original model config not found.')

  # In interactive mode, force batch size 1.
  if config.runner.mode == 'interact':
    config.dataset.batch_size = 1

  config_exists = gfile.exists(config_filepath)
  if not config_exists and config.runner.mode == 'train':
    gfile.makedirs(run_dir)
    config_utils.save_config(config, config_filepath)

  # Load dataset.
  if config.setup.setup_dataset:
    dataset_info = dataset_utils.get_dataset(data_dir, config)
    info = dataset_info.info
  else:
    dataset_info = None
    info = None

  # Create model.
  if config.setup.setup_model:
    model = models_lib.get_model(info, config)
  else:
    model = None

  adapter = adapters_lib.get_default_adapter(info, config)

  return RunConfiguration(
      mode=config.runner.mode,
      method=config.runner.method,
      run_dir=run_dir,
      data_dir=data_dir,
      original_checkpoint_path=original_checkpoint_path,
      model=model,
      info=info,
      config=config,
      adapter=adapter,
      dataset_info=dataset_info)


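# Illustrative sketch (not part of the original module): how the
# RunConfiguration returned by configure might feed the workflows above. The
# mode names checked here are assumptions; the real dispatcher lives outside
# this module.
def example_entry_point(data_dir, run_dir, config):
  """Hypothetical wiring of configure to a workflow."""
  run_configuration = configure(data_dir, run_dir, config, override_values={})
  if run_configuration.mode == 'train':
    run_train(run_configuration)
  else:
    run_eval(run_configuration)

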
def run_eval_multi_dataset(run_configuration):
  """Evaluates on checkpoints as they become available."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  dataset_infos = [dataset_utils.get_dataset(data_dir, config, name)
                   for name in eval_dataset_names]
  all_dataset_ids = set(range(len(dataset_infos)))
  dataset_ids_evaluated = set()
  last_dataset_id = -1
  last_checkpoint_path = None
  last_checkpoint_time = time.time()
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  success_path = checkpoint_utils.build_success_path(run_dir)
  error_count = 0
  while True:
    success = gfile.exists(success_path)
    # Always evaluate at the latest checkpoint.
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)

    # Choose the dataset to evaluate.
    if checkpoint_path is not None and checkpoint_path != last_checkpoint_path:
      # The dataset ids evaluated at the latest checkpoint.
      # Our goal is to evaluate all the datasets at the final checkpoint as
      # soon as possible, while providing best-effort progress updates along
      # the way.
      dataset_ids_evaluated = set()
    dataset_id = (last_dataset_id + 1) % len(dataset_infos)
    if (dataset_id in dataset_ids_evaluated
        and dataset_ids_evaluated != all_dataset_ids):
      dataset_id = next(iter(all_dataset_ids - dataset_ids_evaluated))

    if ((checkpoint_path, dataset_id)
        != (last_checkpoint_path, last_dataset_id)
        and dataset_ids_evaluated != all_dataset_ids
        and checkpoint_path is not None):
      logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
      logging.info('Evaluating on dataset id: %d', dataset_id)
      run_configuration.dataset_info = dataset_infos[dataset_id]
      try:
        workflows.eval_once(run_configuration, checkpoint_path, optimizer)
      except:  # pylint: disable=bare-except
        logging.info('Could not evaluate %s on dataset %d',
                     checkpoint_path, dataset_id)
        error_count += 1
        if error_count >= 10 or config.debug:
          raise
      last_dataset_id = dataset_id
      dataset_ids_evaluated.add(dataset_id)
      last_checkpoint_path = checkpoint_path
      last_checkpoint_time = time.time()
    else:
      if success:
        logging.info('SUCCESS file found. Stopping.')
        break
      if time.time() - last_checkpoint_time > config.eval_timeout:
        logging.info('Timed out waiting for checkpoint. Stopping.')
        break
      logging.info('Waiting for checkpoint.')
      time.sleep(15)
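

# Illustrative sketch (not part of the original module): launching the
# multi-dataset eval workflow. The dataset names are placeholders, not real
# dataset identifiers; config.launcher.eval_dataset_names is a comma-separated
# string that run_eval_multi_dataset splits into one dataset_info per name.
def example_multi_dataset_eval(data_dir, run_dir, config):
  """Hypothetical example of running eval over several datasets."""
  config.launcher.eval_dataset_names = 'dataset_a,dataset_b'
  run_configuration = configure(data_dir, run_dir, config, override_values={})
  run_eval_multi_dataset(run_configuration)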