コード例 #1
0
def _setup_fold(input_dir, train_dir, val_fold, val_fold_template):
    """Load the training experiment proto and resolve inputs for one fold.

    The training pbtxt is read from `input_dir`; a sibling file carrying the
    '.wstats' suffix is used instead when it exists. The validation pbtxt
    for this fold is copied into `train_dir`.

    Args:
        input_dir: string giving the path to input data.
        train_dir: string path to the directory the validation pbtxt is
            copied into.
        val_fold: integer ID of the fold held out for validation; also used
            to index the configured per-fold paths.
        val_fold_template: string used to build the validation fold's file
            names.

    Returns:
        A tuple `(experiment_proto, train_input_paths, val_input_paths)`:
        the parsed selection_pb2.Experiment proto, the expanded training
        data paths (every configured path except the one at index
        `val_fold`), and the expanded validation data path(s).
    """
    base_pbtxt = config.wetlab_experiment_train_pbtxt_path[val_fold]
    stats_pbtxt_path = os.path.join(
        input_dir, six.ensure_str(base_pbtxt) + '.wstats')

    # Prefer the variant with statistics when it is available.
    if gfile.Exists(stats_pbtxt_path):
        pbtxt_path = stats_pbtxt_path
        logger.info('Load pbtxt file with statistics: %s', pbtxt_path)
    else:
        pbtxt_path = os.path.join(input_dir, base_pbtxt)
        logger.info('Load pbtxt file without statistics: %s', pbtxt_path)

    with gfile.GFile(pbtxt_path) as pbtxt:
        experiment_proto = text_format.Parse(pbtxt.read(),
                                             selection_pb2.Experiment())

    # Place the validation experiment description next to the training
    # outputs under its configured name.
    val_pbtxt_name = config.get_wetlab_experiment_val_pbtxt_path(
        val_fold, val_fold_template)
    val_pbtxt_source = os.path.join(input_dir, val_pbtxt_name)
    val_pbtxt_target = os.path.join(train_dir,
                                    config.wetlab_experiment_val_name)
    gfile.Copy(val_pbtxt_source, val_pbtxt_target, overwrite=True)

    # Every configured path except the held-out fold's is training data.
    train_candidates = []
    for fold_index, relative_path in enumerate(config.example_sstable_paths):
        if fold_index == val_fold:
            continue
        train_candidates.append(os.path.join(input_dir, relative_path))
    train_input_paths = _expand_sharded_paths(train_candidates)

    val_relative_path = config.get_example_sstable_path(
        val_fold, val_fold_template)
    val_input_paths = _expand_sharded_paths(
        [os.path.join(input_dir, val_relative_path)])

    return experiment_proto, train_input_paths, val_input_paths
コード例 #2
0
      def compare_with_best_model(checkpoint_path, summary_df, cur_epoch):
        """Promote the current model to "best" if its validation loss is lower.

        Args:
          checkpoint_path: forwarded unchanged to `update_best_model`.
          summary_df: summary of the current evaluation; the current loss is
            read from `summary_df['loss'].loc['mean']`.
          cur_epoch: forwarded unchanged to `update_best_model`.
        """
        logger.info('Comparing current val loss with the best model')

        # The comparison below reads `best_valid_report`, so that file must
        # exist too; checking only `best_train_report` would let a missing
        # validation report fail on open. If either report is absent, treat
        # it as "no best model yet".
        has_best_model = (gfile.Exists(best_train_report) and
                          gfile.Exists(best_valid_report))
        if not has_best_model:
          logger.info('No best model saved. Adding current model...')
          update_best_model(checkpoint_path, cur_epoch)
        else:
          with gfile.GFile(best_valid_report) as f:
            with xarray.open_dataset(f) as best_ds:
              # Pull the data into memory before the file handle is closed.
              best_ds.load()
          cur_loss = summary_df['loss'].loc['mean']
          # Best loss is the stored loss averaged over the 'output' dimension.
          best_loss = best_ds['loss'].mean('output')
          logger.info('Current val loss:%f', cur_loss)
          logger.info('The best val loss:%f', best_loss)
          if cur_loss < best_loss:
            logger.info(
                'Current model has lower loss. Updating the best model.')
            update_best_model(checkpoint_path, cur_epoch)
          else:
            logger.info('The best model has lower loss.')
コード例 #3
0
ファイル: utils.py プロジェクト: mcarbin/code
def trial_datum_of_trial(experiment_dir, trial):
    """Build the TrialDatum for one trial, using a pickled cache if present.

    Args:
        experiment_dir: path of the directory containing the trial
            directories.
        trial: name of the trial directory inside `experiment_dir`.

    Returns:
        The unpickled contents of `<experiment_dir>/<trial>/plot_cache.pkl`
        when that file exists; otherwise a TrialDatum whose `iter_data`
        holds the non-None results of `iter_datum_of_iter_dir` for each
        iteration directory, ordered by iteration number.
    """
    trial_dir = os.path.join(experiment_dir, trial)
    plot_cache = os.path.join(trial_dir, 'plot_cache.pkl')
    if gfile.Exists(plot_cache):
        # NOTE(review): unpickling executes arbitrary code; this is only safe
        # if the experiment directory is trusted.
        with gfile.Open(plot_cache, 'rb') as f:
            return pickle.loads(f.read())

    iter_dirs = sorted(
        iter_dirs_of_trial_dir(trial_dir),
        key=lambda x: int(iter_re.match(os.path.basename(x)).group('iter')))

    # Every iteration except the last one is allowed to write its cache.
    last_index = len(iter_dirs) - 1
    work_items = [(iter_dir, index < last_index)
                  for index, iter_dir in enumerate(iter_dirs)]

    pool = mp.Pool(5)
    try:
        iter_data = pool.map(iter_datum_of_iter_dir, work_items)
    finally:
        # Release the worker processes even when a worker raises.
        pool.close()
        pool.join()

    return TrialDatum(
        trial=trial,
        iter_data=[datum for datum in iter_data if datum is not None],
    )
コード例 #4
0
ファイル: utils.py プロジェクト: mcarbin/code
def iter_datum_of_iter_dir(iter_dir_and_can_write_cache,
                           verbose=True,
                           ignore_cache=False):
    """Collect the density ratio and test accuracy for one iteration dir.

    Args:
        iter_dir_and_can_write_cache: `(iter_dir, can_write_cache)` tuple,
            packed into a single argument so the function can be used with
            `Pool.map`. `can_write_cache` controls whether freshly computed
            results are persisted via `write_iter`.
        verbose: when true, print progress information to stdout.
        ignore_cache: when true, skip both cached sources (`read_iter` and
            the legacy `plot_cache.pkl`) and recompute from the event files.

    Returns:
        An IterDatum, or None when the matching `execution_data/.../eval`
        directory does not exist. `test_acc` is None if no accuracy summary
        was found in the event files.
    """
    iter_dir, can_write_cache = iter_dir_and_can_write_cache

    if not ignore_cache:
        res = read_iter(iter_dir)
        if res:
            return IterDatum(
                iter=os.path.basename(res[0]),
                density_ratio=res[1],
                test_acc=res[2],
            )
        plot_cache = os.path.join(iter_dir, 'plot_cache.pkl')
        if gfile.Exists(plot_cache):
            if verbose:
                print('PLOT CACHE EXISTS: {}'.format(plot_cache))
            # NOTE(review): unpickling executes arbitrary code; this is only
            # safe if the results directory is trusted.
            with gfile.Open(plot_cache, 'rb') as f:
                try:
                    it = pickle.loads(f.read())
                    # Migrate the pickled entry to the `write_iter` format.
                    write_iter(iter_dir, it.density_ratio, it.test_acc)
                    return it
                except Exception:
                    # Best effort: an unreadable or unmigratable cache is
                    # discarded below and the datum is recomputed.
                    # KeyboardInterrupt/SystemExit are deliberately not
                    # caught here.
                    pass
            # Only reached when loading failed; remove the bad cache after
            # its file handle has been closed.
            gfile.Remove(plot_cache)

    # NOTE(review): replaces every occurrence of 'results' in the path --
    # confirm no other path component contains that substring.
    execution_data_iter_dir = os.path.join(
        iter_dir.replace('results', 'execution_data'), 'eval')
    if not gfile.IsDirectory(execution_data_iter_dir):
        return None

    # Keep the accuracy reported at the highest step across all event files.
    test_acc = None
    test_iter = None
    for events_file in gfile.ListDirectory(execution_data_iter_dir):
        if not events_file.startswith('events.out'):
            continue
        for e in tf.train.summary_iterator(
                os.path.join(execution_data_iter_dir, events_file)):
            for v in e.summary.value:
                if v.tag in ('accuracy', 'top_1_accuracy'):
                    if test_iter is None or e.step > test_iter:
                        test_iter = e.step
                        test_acc = v.simple_value

    if verbose:
        print(test_acc)

    try:
        with gfile.Open(os.path.join(iter_dir, 'density_ratio')) as f:
            density_ratio = float(f.read())
    except Exception:
        # A missing or unparsable density file falls back to 1.0.
        density_ratio = 1.0

    res = IterDatum(
        iter=os.path.basename(iter_dir),
        density_ratio=density_ratio,
        test_acc=test_acc,
    )

    # Only persist complete results, and only when the caller allows it.
    if can_write_cache and test_acc is not None:
        write_iter(iter_dir, density_ratio, test_acc)
    return res