def load_ood_dataset(n_ood, fix_len, vocab, vocab_size):
  """Load LM1B dataset for OOD test."""
  # Import the OOD dataset from TFDS.
  data = tfds.load('lm1b')
  _, test_lm1b_tfds = data['train'], data['test']
  i = 0
  test_lm1b_list = []
  for example in tfds.as_numpy(test_lm1b_tfds):
    x = example['text'].decode('utf-8')  # for PY3
    x = re.sub(r'\W+', ' ', x).strip()  # remove non-word characters, e.g. "," and "."
    test_lm1b_list.append(x)
    i += 1
    if i % n_ood == 0:
      break

  test_lm1b_x = data_utilsh.text_to_rank(
      test_lm1b_list, vocab, desired_vocab_size=vocab_size)
  # Pad text so all examples have the same length.
  test_lm1b_x_pad = data_utilsh.pad_sequences(test_lm1b_x, maxlen=fix_len)
  test_lm1b_y = -1 * np.ones(len(test_lm1b_x))
  return test_lm1b_x_pad, test_lm1b_y
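# A minimal usage sketch (not part of the original pipeline), assuming a
# vocab dict built by data_utilsh.get_vocab on the 20 newsgroups training
# text and hypothetical values for fix_len, vocab_size, and n_ood. It shows
# how the LM1B OOD split could be stacked with an in-distribution test split
# for OOD evaluation; the -1 labels returned above mark the OOD examples.
def build_ood_eval_arrays(test_in_x, vocab, fix_len=250, vocab_size=30000,
                          n_ood=1000):
  """Sketch: concatenate in-distribution and LM1B OOD test examples."""
  ood_x, ood_y = load_ood_dataset(n_ood, fix_len, vocab, vocab_size)
  eval_x = np.concatenate([np.asarray(test_in_x), np.asarray(ood_x)], axis=0)
  # is_ood is 1 for LM1B examples (labeled -1 above) and 0 for in-distribution.
  is_ood = np.concatenate([np.zeros(len(test_in_x), dtype=np.int32),
                           np.ones(len(ood_y), dtype=np.int32)], axis=0)
  return eval_x, is_ood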
def make_dataset(params):
  """Make np arrays for 20 news groups."""
  # Even-numbered class labels (0, 2, 4, ...) are in-distribution;
  # odd-numbered class labels are OOD.
  if params['filter_label'] == -1:
    to_include = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
  else:
    to_include = [params['filter_label']]
  to_exclude = list(set(range(params['n_class'])).difference(to_include))
  logging.info('exclude classes=%s', to_exclude)

  logging.info('Loading raw data')
  x_train, y_train = data_utilsh.load_data(params['tr_data_file'])
  x_test, y_test = data_utilsh.load_data(params['test_data_file'])

  logging.info('Get vocab and encode words to ints')
  # vocab is a dict ordered by word frequency.
  vocab = data_utilsh.get_vocab(x_train)
  # Words with the top vocab_size-1 frequencies are encoded as 1 to
  # vocab_size-1; less frequent words are encoded as vocab_size (the unknown
  # token). Sentences longer than max_len are truncated from the beginning,
  # and sentences shorter than max_len are padded from the beginning with 0,
  # so the total vocab = vocab_size-1 (specific words) + 1 (unk) + 1 (padding).
  x_train = data_utilsh.text_to_rank(x_train, vocab, params['vocab_size'])
  x_test = data_utilsh.text_to_rank(x_test, vocab, params['vocab_size'])

  # Shuffle.
  np.random.seed(params['random_seed'])
  indices = np.arange(len(x_train))
  np.random.shuffle(indices)
  x_train = [x_train[i] for i in indices]
  y_train = [y_train[i] for i in indices]

  indices = np.arange(len(x_test))
  np.random.shuffle(indices)
  x_test = [x_test[i] for i in indices]
  y_test = [y_test[i] for i in indices]

  # Split into train/dev.
  n_dev = int(len(x_train) * (1 - params['tr_frac']))
  x_dev = x_train[-n_dev:]
  y_dev = y_train[-n_dev:]
  x_train = x_train[:-n_dev]
  y_train = y_train[:-n_dev]

  # Optionally fragment text into short pieces of length fix_len.
  if params['shortfrag']:
    logging.info('sampling sub-text.')
    x_train, y_train = data_utils.fragment_into_short_sentence(
        x_train, y_train, params['fix_len'], params['sample_rate'])
    logging.info('x_train_frag=%s', x_train[0])

    x_dev, y_dev = data_utils.fragment_into_short_sentence(
        x_dev, y_dev, params['fix_len'], params['sample_rate'])
    logging.info('x_dev_frag=%s', x_dev[0])

    x_test, y_test = data_utils.fragment_into_short_sentence(
        x_test, y_test, params['fix_len'], params['sample_rate'])
    logging.info('x_test_frag=%s', x_test[0])
  else:
    logging.info('pad original text with 0s.')
    # Pad text so all examples have the same length.
    x_train = data_utilsh.pad_sequences(x_train, maxlen=params['fix_len'])
    x_dev = data_utilsh.pad_sequences(x_dev, maxlen=params['fix_len'])
    x_test = data_utilsh.pad_sequences(x_test, maxlen=params['fix_len'])
    y_train = np.array(y_train)
    y_dev = np.array(y_dev)
    y_test = np.array(y_test)

  # Partition data into in-distribution and OOD by their labels.
  (in_sample_examples, in_sample_labels, oos_examples,
   oos_labels) = data_utilsh.partion_data_in_two(
       x_train, np.array(y_train), to_include, to_exclude)
  (dev_in_sample_examples, dev_in_sample_labels, dev_oos_examples,
   dev_oos_labels) = data_utilsh.partion_data_in_two(
       x_dev, np.array(y_dev), to_include, to_exclude)
  (test_in_sample_examples, test_in_sample_labels, test_oos_examples,
   test_oos_labels) = data_utilsh.partion_data_in_two(
       x_test, np.array(y_test), to_include, to_exclude)

  class_freq = np.bincount(in_sample_labels)
  logging.info('in_sample_labels_freq=%s', class_freq)
  class_freq = np.bincount(dev_in_sample_labels)
  logging.info('dev_in_sample_labels_freq=%s', class_freq)
  class_freq = np.bincount(dev_oos_labels)
  logging.info('dev_oos_labels_freq=%s', class_freq)
  # Relabel in-distribution classes from 0,2,4,... to 0,1,2,... so the labels
  # can be used directly for training the classifier. This safely assumes
  # there is an example of each in-distribution class in both the training
  # and dev splits.
  in_sample_labels = data_utilsh.relabel_in_sample_labels(in_sample_labels)
  dev_in_sample_labels = data_utilsh.relabel_in_sample_labels(
      dev_in_sample_labels)
  test_in_sample_labels = data_utilsh.relabel_in_sample_labels(
      test_in_sample_labels)

  logging.info('# word id>15000=%s', np.sum(in_sample_examples > 15000))
  logging.info(
      'n_tr_in=%s, n_val_in=%s, n_val_ood=%s, n_test_in=%s, n_test_ood=%s',
      in_sample_labels.shape[0], dev_in_sample_labels.shape[0],
      dev_oos_labels.shape[0], test_in_sample_labels.shape[0],
      test_oos_labels.shape[0])
  logging.info('example in_sample_examples1=%s, \n in_sample_examples2=%s',
               in_sample_examples[0], in_sample_examples[1])

  # Save to disk.
  if params['shortfrag']:
    # If text is fragmented into fixed-length short pieces, the pieces are
    # subsampled with sample_rate, so the output file name carries this
    # parameter.
    out_file_name = '20news_encode_maxlen{}_vs{}_rate{}_in{}_trfrac{}.pkl'.format(
        params['fix_len'], params['vocab_size'], params['sample_rate'],
        '-'.join([str(x) for x in to_include]), params['tr_frac'])
  else:
    # If text is not fragmented, all text examples are used. Given a fixed
    # length, text shorter than that length is padded and text longer than it
    # is truncated.
    out_file_name = '20news_encode_maxlen{}_vs{}_in{}_trfrac{}.pkl'.format(
        params['fix_len'], params['vocab_size'],
        '-'.join([str(x) for x in to_include]), params['tr_frac'])

  with tf.gfile.Open(os.path.join(params['out_dir'], out_file_name),
                     'wb') as f:
    pickle.dump(in_sample_examples, f)
    pickle.dump(in_sample_labels, f)
    pickle.dump(oos_examples, f)
    pickle.dump(oos_labels, f)

    pickle.dump(dev_in_sample_examples, f)
    pickle.dump(dev_in_sample_labels, f)
    pickle.dump(dev_oos_examples, f)
    pickle.dump(dev_oos_labels, f)

    pickle.dump(test_in_sample_examples, f)
    pickle.dump(test_in_sample_labels, f)
    pickle.dump(test_oos_examples, f)
    pickle.dump(test_oos_labels, f)

    pickle.dump(vocab, f)
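# Minimal read-back sketch (not part of the original code): the splits above
# are written with repeated pickle.dump calls into a single file, so they must
# be restored with pickle.load calls in the same order. The helper name and
# its argument below are hypothetical.
def load_dataset_pickle(pkl_path):
  """Sketch: restore the 12 data splits plus the vocab dict."""
  with tf.gfile.Open(pkl_path, 'rb') as f:
    # One pickle.load per pickle.dump, in dump order.
    fields = [pickle.load(f) for _ in range(13)]
  keys = ['in_sample_examples', 'in_sample_labels',
          'oos_examples', 'oos_labels',
          'dev_in_sample_examples', 'dev_in_sample_labels',
          'dev_oos_examples', 'dev_oos_labels',
          'test_in_sample_examples', 'test_in_sample_labels',
          'test_oos_examples', 'test_oos_labels',
          'vocab']
  return dict(zip(keys, fields))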