def load_datasets_and_labels(params):
  """Load class labels, in_tr_data, in_val_data, ood_val_data (called test).

  Args:
    params: hyperparameter object; reads in_tr_data_dir, in_tr_file_pattern,
      in_val_data_dir, in_val_file_pattern, ood_val_data_dir,
      ood_val_file_pattern, mutation_rate, seq_len.

  Returns:
    A tuple (label_sample_size, in_tr_dataset, in_val_dataset,
    ood_val_dataset): label_sample_size maps class name (str) to sample
    count (int); the remaining three are tf.data datasets of parsed
    tf.Examples.
  """

  def _eval_dataset(data_dir, file_pattern, log_tag):
    """Builds a dataset from the matching .tfrecord files under data_dir."""
    file_list = [
        os.path.join(data_dir, x)
        for x in tf.gfile.ListDirectory(data_dir)
        if file_pattern in x and '.tfrecord' in x
    ]
    # '%s=%s' with log_tag renders exactly the same text the duplicated
    # per-dataset logging calls used to emit.
    tf.logging.info('%s=%s', log_tag, file_list)
    return tf.data.TFRecordDataset(file_list).map(
        lambda v: utils.parse_single_tfexample(v, v))

  in_tr_file_list = [
      os.path.join(params.in_tr_data_dir, x)
      for x in tf.gfile.ListDirectory(params.in_tr_data_dir)
      if params.in_tr_file_pattern in x
  ]
  tf.logging.info('in_tr_file_list=%s', in_tr_file_list)

  # The per-class sample-size dictionary is the one file whose name contains
  # both 'nsample' and '.json'. Raises IndexError if it is missing.
  in_tr_label_file = [
      x for x in in_tr_file_list if 'nsample' in x and '.json' in x
  ][0]
  tf.logging.info('nsample_dict_file=%s', in_tr_label_file)
  # yaml.safe_load parses the JSON content (JSON is a subset of YAML).
  # Keys: class names (strings); values: sample size of classes (ints).
  with tf.gfile.GFile(in_tr_label_file, 'rb') as f_label_code:
    label_sample_size = yaml.safe_load(f_label_code)
  tf.logging.info('# of label_dict=%s', len(label_sample_size))

  # Load in-distribution training sequences.
  in_tr_data_file_list = [x for x in in_tr_file_list if '.tfrecord' in x]
  tf.logging.info('in_tr_data_file_list=%s', in_tr_data_file_list)

  def parse_single_tfexample_addmutations_short(unused_key, v):
    return utils.parse_single_tfexample_addmutations(
        unused_key, v, params.mutation_rate, params.seq_len)

  # For training, we optionally mutate input sequences to overcome
  # over-fitting.
  if params.mutation_rate == 0:
    in_tr_dataset = tf.data.TFRecordDataset(in_tr_data_file_list).map(
        lambda v: utils.parse_single_tfexample(v, v))
  else:
    in_tr_dataset = tf.data.TFRecordDataset(in_tr_data_file_list).map(
        lambda v: parse_single_tfexample_addmutations_short(v, v))

  # In-distribution and OOD validation share the same listing/parsing logic.
  in_val_dataset = _eval_dataset(
      params.in_val_data_dir, params.in_val_file_pattern,
      'in_val_data_file_list')
  ood_val_dataset = _eval_dataset(
      params.ood_val_data_dir, params.ood_val_file_pattern,
      'ood_val_data_file_list')

  return label_sample_size, in_tr_dataset, in_val_dataset, ood_val_dataset
def load_datasets(params, mode_eval=False):
  """load class labels, in_tr_data, in_val_data, ood_val_data."""
  if mode_eval:
    # Evaluation needs no training data at all.
    train_ds = None
  else:
    candidates = [
        os.path.join(params.in_tr_data_dir, fname)
        for fname in tf.gfile.ListDirectory(params.in_tr_data_dir)
        if params.in_tr_file_pattern in fname
    ]
    # Keep only the tfrecord shards for the in-distribution training set.
    train_files = [f for f in candidates if '.tfrecord' in f]
    tf.logging.info('in_tr_data_file_list=%s', train_files)

    def _parse_with_mutations(unused_key, v):
      return utils.parse_single_tfexample_addmutations(
          unused_key, v, params.mutation_rate, params.seq_len)

    # For training a background model, we mutate input sequences.
    raw_train = tf.data.TFRecordDataset(train_files)
    if params.mutation_rate == 0:
      train_ds = raw_train.map(lambda v: utils.parse_single_tfexample(v, v))
    else:
      train_ds = raw_train.map(lambda v: _parse_with_mutations(v, v))

    # Optionally restrict training examples to a single class label.
    if params.filter_label != -1:
      train_ds = train_ds.filter(
          lambda v: filter_for_label(v, params.filter_label))

  # In-distribution validation.
  val_files = [
      os.path.join(params.in_val_data_dir, fname)
      for fname in tf.gfile.ListDirectory(params.in_val_data_dir)
      if params.in_val_file_pattern in fname and '.tfrecord' in fname
  ]
  tf.logging.info('in_val_data_file_list=%s', val_files)
  val_ds = tf.data.TFRecordDataset(val_files).map(
      lambda v: utils.parse_single_tfexample(v, v))

  # OOD validation.
  ood_files = [
      os.path.join(params.ood_val_data_dir, fname)
      for fname in tf.gfile.ListDirectory(params.ood_val_data_dir)
      if params.ood_val_file_pattern in fname and '.tfrecord' in fname
  ]
  tf.logging.info('ood_val_data_file_list=%s', ood_files)
  ood_ds = tf.data.TFRecordDataset(ood_files).map(
      lambda v: utils.parse_single_tfexample(v, v))

  return train_ds, val_ds, ood_ds