def setup_output_dir(base_dir):
  """Create a new directory for saving script output.

  Generation info is automatically saved to 'geninfo.pbtxt'.

  Args:
    base_dir: string giving the base directory in which to place this
      directory.

  Returns:
    String giving the path of the new directory. This is a sub-directory of
    base_dir with a name given by a UTC timestamp.
  """
  save_dir = os.path.join(base_dir, datetime.datetime.utcnow().isoformat())
  # tensorflow doesn't like : in filenames
  save_dir = save_dir.replace(':', '-')
  logging.info('creating output directory %s', save_dir)
  gfile.MakeDirs(save_dir)
  generationinfo.to_file(os.path.join(save_dir, 'geninfo.pbtxt'))
  return save_dir
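
# Example usage (a minimal sketch; the base directory below is hypothetical):
#
#   save_dir = setup_output_dir('/tmp/ff_experiments')
#   # e.g. '/tmp/ff_experiments/2020-01-01T12-34-56.789012', with
#   # geninfo.pbtxt already written inside it.
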
def run_training(hps,
                 experiment_proto,
                 train_dir,
                 train_input_paths,
                 val_input_paths,
                 tuner=None,
                 master='',
                 metrics_targets=None,
                 metrics_measures=None):
  """Main training function.

  Trains the model given a directory to write checkpoints and logs to.

  Args:
    hps: tf.HParams with training parameters.
    experiment_proto: selection_pb2.Experiment proto for training.
    train_dir: str path to train directory.
    train_input_paths: List[str] giving paths to input sstables for training.
    val_input_paths: List[str] giving paths to input sstable(s) for validation.
    tuner: optional hp_tuner.HPTuner.
    master: optional string to pass to a tf.Supervisor.
    metrics_targets: String list of network targets to report metrics for.
    metrics_measures: Measurements about the performance of the network to
      report, e.g. 'auc/top_1p'.

  Returns:
    None.

  Raises:
    Error: if the hyperparameter combination in hps is infeasible and there
      is no tuner. (If the hyperparameter combination is infeasible and there
      is a tuner, then the params are reported back to the tuner as
      infeasible.)
  """
  hps_infeasible, infeasible_reason = hps_is_infeasible(
      hps, experiment_proto.sequence_length)
  if hps_infeasible:
    if tuner:
      tuner.report_done(True, infeasible_reason)
      logger.info('report_done(infeasible=%r)', hps_infeasible)
      return
    else:
      raise Error('Hyperparams are infeasible: %s' % infeasible_reason)

  logger.info('Starting training.')
  if tuner:
    logger.info('Using tuner: loaded HParams from Vizier')
  else:
    logger.info('No tuner: using default HParams')
  logger.info('experiment_proto: %s', experiment_proto)
  logger.info('train_dir: %s', train_dir)
  logger.info('train_input_paths[0]: %s', train_input_paths[0])
  logger.info('val_input_paths[0]: %s', val_input_paths[0])
  logger.info('%r', list(hps.values()))

  generationinfo.to_file(os.path.join(train_dir, 'geninfo.pbtxt'))
  with gfile.Open(os.path.join(train_dir, config.hparams_name), 'w') as f:
    f.write(str(hps.to_proto()))

  eval_size = hps.eval_size or None

  def make_subdir(subdirectory_name):
    path = os.path.join(train_dir, subdirectory_name)
    gfile.MakeDirs(path)
    return path

  logger.info('Computing preprocessing statistics')
  # TODO(shoyer): move this over into preprocessing instead?
  experiment_proto = dataset_stats.compute_experiment_statistics(
      experiment_proto,
      train_input_paths,
      os.path.join(
          hps.input_dir,
          six.ensure_str(
              config.wetlab_experiment_train_pbtxt_path[hps.val_fold]) +
          '.wstats'),
      preprocess_mode=hps.preprocess_mode,
      max_size=eval_size,
      logdir=make_subdir('compute-statistics'),
      save_stats=hps.save_stats)

  logging.info('Saving experiment proto with statistics')
  with gfile.Open(
      os.path.join(train_dir, config.wetlab_experiment_train_name), 'w') as f:
    f.write(str(experiment_proto))

  logger.debug(str(hps.to_proto()))
  logger.debug(hps.run_name)

  tr_entries = len(sstable.MergedSSTable(train_input_paths))
  logger.info('Training sstable size: %d', tr_entries)
  val_entries = len(sstable.MergedSSTable(val_input_paths))
  logger.info('Validation sstable size: %d', val_entries)

  epoch_size = hps.epoch_size or int(tr_entries * (1 + hps.ratio_random_dna))
  num_batches_per_epoch = int(float(epoch_size) / hps.mbsz)

  eval_ff.config_pandas_display(FLAGS.interactive_display)
  tr_evaluator = eval_ff.Evaluator(
      hps,
      experiment_proto,
      train_input_paths,
      make_subdir(config.experiment_training_dir),
      verbose=FLAGS.verbose_eval)
  val_evaluator = eval_ff.Evaluator(
      hps,
      experiment_proto,
      val_input_paths,
      make_subdir(config.experiment_validation_dir),
      verbose=FLAGS.verbose_eval)

  with tf.Graph().as_default():
    # we need to use the registered key 'hparams'
    tf.add_to_collection('hparams', hps)

    # TODO(shoyer): collect these into a Model class:
    dummy_inputs = data.dummy_inputs(
        experiment_proto,
        input_features=hps.input_features,
        kmer_k_max=hps.kmer_k_max,
        additional_output=six.ensure_str(hps.additional_output).split(','))
    output_layer = output_layers.create_output_layer(experiment_proto, hps)
    net = ff.FeedForward(dummy_inputs, output_layer.logit_axis, hps)
    trainer = FeedForwardTrainer(hps, net, output_layer, experiment_proto,
                                 train_input_paths)

    summary_writer = tf.SummaryWriter(make_subdir('training'), flush_secs=30)

    # TODO(shoyer): file a bug to figure out why write_version=2 (now the
    # default) doesn't work.
    saver = tf.Saver(write_version=1)

    # We are always the chief since we do not do distributed training.
    # Every replica with a different task id is completely independent and
    # all must be their own chief.
    sv = tf.Supervisor(
        logdir=train_dir,
        is_chief=True,
        summary_writer=summary_writer,
        save_summaries_secs=10,
        save_model_secs=180,
        saver=saver)

    logger.info('Preparing session')

    train_report_dir = os.path.join(train_dir, config.experiment_training_dir)
    cur_train_report = os.path.join(train_report_dir,
                                    config.experiment_report_name)
    best_train_report = os.path.join(train_report_dir,
                                     config.experiment_best_report_name)

    valid_report_dir = os.path.join(train_dir,
                                    config.experiment_validation_dir)
    cur_valid_report = os.path.join(valid_report_dir,
                                    config.experiment_report_name)
    best_valid_report = os.path.join(valid_report_dir,
                                     config.experiment_best_report_name)

    best_checkpoint = os.path.join(train_dir, 'model.ckpt-lowest_val_loss')
    best_checkpoint_meta = best_checkpoint + '.meta'
    best_epoch_file = os.path.join(train_dir, 'best_epoch.txt')

    with sv.managed_session(master) as sess:
      logger.info('Starting queue runners')
      sv.start_queue_runners(sess)

      def save_and_evaluate():
        """Save and evaluate the current model.

        Returns:
          path: the path string to the checkpoint.
          summary_df: pandas.DataFrame storing the evaluation result on the
            validation dataset, with rows for each output name and columns
            for each metric value.
        """
        logger.info('Saving model checkpoint')
        path = sv.saver.save(
            sess,
            sv.save_path,
            global_step=sv.global_step,
            write_meta_graph=True)
        tr_evaluator.run(path, eval_size)
        summary_df, _ = val_evaluator.run_and_report(
            tuner,
            path,
            eval_size,
            metrics_targets=metrics_targets,
            metrics_measures=metrics_measures)
        return path, summary_df

      def update_best_model(path, cur_epoch):
        """Update the records of the model with the lowest validation error.

        Args:
          path: the path to the checkpoint of the current model.
          cur_epoch: an integer giving the current epoch.
        """
        cur_checkpoint = path
        cur_checkpoint_meta = six.ensure_str(cur_checkpoint) + '.meta'
        gfile.Copy(cur_train_report, best_train_report, overwrite=True)
        gfile.Copy(cur_valid_report, best_valid_report, overwrite=True)
        gfile.Copy(cur_checkpoint, best_checkpoint, overwrite=True)
        gfile.Copy(cur_checkpoint_meta, best_checkpoint_meta, overwrite=True)
        with gfile.Open(best_epoch_file, 'w') as f:
          f.write(str(cur_epoch) + '\n')

      def compare_with_best_model(checkpoint_path, summary_df, cur_epoch):
        """Update the best model if the current one has a lower val loss."""
        logger.info('Comparing current val loss with the best model')
        if not gfile.Exists(best_train_report):
          logger.info('No best model saved. Adding current model...')
          update_best_model(checkpoint_path, cur_epoch)
        else:
          with gfile.GFile(best_valid_report) as f:
            with xarray.open_dataset(f) as best_ds:
              best_ds.load()
          cur_loss = summary_df['loss'].loc['mean']
          best_loss = best_ds['loss'].mean('output')
          logger.info('Current val loss: %f', cur_loss)
          logger.info('The best val loss: %f', best_loss)
          if cur_loss < best_loss:
            logger.info('Current model has lower loss. Updating the best '
                        'model.')
            update_best_model(checkpoint_path, cur_epoch)
          else:
            logger.info('The best model has lower loss.')

      logger.info('Running eval before starting training')
      save_and_evaluate()

      try:
        for cur_epoch in trainer.train(sess, hps.epochs,
                                       num_batches_per_epoch):
          checkpoint_path, val_summary_df = save_and_evaluate()
          if (cur_epoch + 1) % hps.epoch_interval_to_save_best == 0:
            compare_with_best_model(checkpoint_path, val_summary_df,
                                    cur_epoch)
          if tuner and tuner.should_trial_stop():
            break
      except eval_ff.TrainingDivergedException as error:
        logger.error('Training diverged: %s', str(error))
        infeasible = True
      else:
        infeasible = False

      logger.info('Saving final checkpoint')
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

  if tuner:
    # Should be at the very end of execution, to avoid possible race
    # conditions.
    tuner.report_done(infeasible=infeasible)
    logger.info('report_done(infeasible=%r)', infeasible)

  logger.info('Done.')
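
# Example invocation (a sketch; hps, experiment_proto, the sstable paths and
# the base output directory are assumed to have been constructed elsewhere,
# and the target name below is hypothetical):
#
#   run_training(
#       hps,
#       experiment_proto,
#       train_dir=setup_output_dir(base_output_dir),
#       train_input_paths=[train_sstable_path],
#       val_input_paths=[val_sstable_path],
#       metrics_targets=['output_2'],
#       metrics_measures=['auc/top_1p'])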