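# NOTE: the loaders below rely on file-level imports that are not shown in
# this snippet. A minimal sketch of what they would look like; the local
# `utils` module and the TF1-compatible alias are assumptions inferred from
# the calls made below (tf.gfile, tf.logging, tf.data).
import os

import tensorflow.compat.v1 as tf
import yaml

import utils  # local helpers: parse_single_tfexample*

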
def load_datasets_and_labels(params):
  """load class labels, in_tr_data, in_val_data, ood_val_data (called test)."""
  in_tr_file_list = [
      os.path.join(params.in_tr_data_dir, x)
      for x in tf.gfile.ListDirectory(params.in_tr_data_dir)
      if params.in_tr_file_pattern in x
  ]
  tf.logging.info('in_tr_file_list=%s', in_tr_file_list)

  in_tr_label_file = [
      x for x in in_tr_file_list if 'nsample' in x and '.json' in x
  ][0]
  tf.logging.info('nsample_dict_file=%s', in_tr_label_file)
  with tf.gfile.GFile(in_tr_label_file, 'rb') as f_label_code:
    # label_sample_size: keys are class names (str), values are the per-class
    # sample sizes (int). The label file is JSON, which yaml.safe_load can
    # parse since JSON is a subset of YAML.
    label_sample_size = yaml.safe_load(f_label_code)
    tf.logging.info('# of labels=%d', len(label_sample_size))

  # load in-distribution training sequence, add mutations to the input seqs
  in_tr_data_file_list = [x for x in in_tr_file_list if '.tfrecord' in x]
  tf.logging.info('in_tr_data_file_list=%s', in_tr_data_file_list)

  def parse_single_tfexample_addmutations_short(unused_key, v):
    return utils.parse_single_tfexample_addmutations(unused_key, v,
                                                     params.mutation_rate,
                                                     params.seq_len)

  # for training, we optionally mutate input sequences to reduce overfitting.
  if params.mutation_rate == 0:
    in_tr_dataset = tf.data.TFRecordDataset(
        in_tr_data_file_list).map(lambda v: utils.parse_single_tfexample(v, v))
  else:
    in_tr_dataset = tf.data.TFRecordDataset(in_tr_data_file_list).map(
        lambda v: parse_single_tfexample_addmutations_short(v, v))

  # in-distribution validation
  in_val_data_file_list = [
      os.path.join(params.in_val_data_dir, x)
      for x in tf.gfile.ListDirectory(params.in_val_data_dir)
      if params.in_val_file_pattern in x and '.tfrecord' in x
  ]
  tf.logging.info('in_val_data_file_list=%s', in_val_data_file_list)
  in_val_dataset = tf.data.TFRecordDataset(
      in_val_data_file_list).map(lambda v: utils.parse_single_tfexample(v, v))

  # OOD validation
  ood_val_data_file_list = [
      os.path.join(params.ood_val_data_dir, x)
      for x in tf.gfile.ListDirectory(params.ood_val_data_dir)
      if params.ood_val_file_pattern in x and '.tfrecord' in x
  ]
  tf.logging.info('ood_val_data_file_list=%s', ood_val_data_file_list)
  ood_val_dataset = tf.data.TFRecordDataset(
      ood_val_data_file_list).map(lambda v: utils.parse_single_tfexample(v, v))

  return label_sample_size, in_tr_dataset, in_val_dataset, ood_val_dataset
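

# A hedged usage sketch for load_datasets_and_labels, kept as a comment so the
# module still imports cleanly. The directory paths, file patterns, and the
# use of argparse.Namespace for `params` are placeholders/assumptions; any
# object exposing these attributes works.
#
#   import argparse
#   params = argparse.Namespace(
#       in_tr_data_dir='/path/to/in_tr', in_tr_file_pattern='in_tr',
#       in_val_data_dir='/path/to/in_val', in_val_file_pattern='in_val',
#       ood_val_data_dir='/path/to/ood_val', ood_val_file_pattern='ood_val',
#       mutation_rate=0.0, seq_len=250)
#   labels, in_tr_ds, in_val_ds, ood_val_ds = load_datasets_and_labels(params)
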
def load_datasets(params, mode_eval=False):
  """load class labels, in_tr_data, in_val_data, ood_val_data."""
  if mode_eval:  # For evaluation, no need to prepare training data
    in_tr_dataset = None
  else:
    in_tr_file_list = [
        os.path.join(params.in_tr_data_dir, x)
        for x in tf.gfile.ListDirectory(params.in_tr_data_dir)
        if params.in_tr_file_pattern in x
    ]

    # load in-distribution training sequence
    in_tr_data_file_list = [x for x in in_tr_file_list if '.tfrecord' in x]
    tf.logging.info('in_tr_data_file_list=%s', in_tr_data_file_list)

    def parse_single_tfexample_addmutations_short(unused_key, v):
      return utils.parse_single_tfexample_addmutations(unused_key, v,
                                                       params.mutation_rate,
                                                       params.seq_len)

    # for training a background model, we mutate input sequences
    if params.mutation_rate == 0:
      in_tr_dataset = tf.data.TFRecordDataset(in_tr_data_file_list).map(
          lambda v: utils.parse_single_tfexample(v, v))
    else:
      in_tr_dataset = tf.data.TFRecordDataset(in_tr_data_file_list).map(
          lambda v: parse_single_tfexample_addmutations_short(v, v))

    # optionally restrict the training set to a single class
    # (filter_for_label is defined elsewhere in this file; a hedged sketch
    # follows after this function)
    if params.filter_label != -1:

      def filter_fn(v):
        return filter_for_label(v, params.filter_label)

      in_tr_dataset = in_tr_dataset.filter(filter_fn)

  # in-distribution validation
  in_val_data_file_list = [
      os.path.join(params.in_val_data_dir, x)
      for x in tf.gfile.ListDirectory(params.in_val_data_dir)
      if params.in_val_file_pattern in x and '.tfrecord' in x
  ]
  tf.logging.info('in_val_data_file_list=%s', in_val_data_file_list)
  in_val_dataset = tf.data.TFRecordDataset(
      in_val_data_file_list).map(lambda v: utils.parse_single_tfexample(v, v))

  # OOD validation
  ood_val_data_file_list = [
      os.path.join(params.ood_val_data_dir, x)
      for x in tf.gfile.ListDirectory(params.ood_val_data_dir)
      if params.ood_val_file_pattern in x and '.tfrecord' in x
  ]
  tf.logging.info('ood_val_data_file_list=%s', ood_val_data_file_list)
  ood_val_dataset = tf.data.TFRecordDataset(
      ood_val_data_file_list).map(lambda v: utils.parse_single_tfexample(v, v))

  return in_tr_dataset, in_val_dataset, ood_val_dataset
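

# filter_for_label is referenced in load_datasets above but defined elsewhere
# in this file. A minimal sketch, kept as a comment, under the assumption that
# each parsed example is a dict of tensors with a scalar integer class label
# under the key 'y' (the key name and dtype are assumptions, not part of this
# snippet):
#
#   def filter_for_label(features, label):
#     """Keeps only examples whose class label equals `label`."""
#     return tf.equal(tf.cast(features['y'], tf.int32), label)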