import argparse
import os

import numpy as np

from tensor2tensor.data_generators import lm1b
# Import path for tfrecord_iterator_for_problem assumed; in tensor2tensor it
# lives in data_generators/generator_utils.py.
from tensor2tensor.data_generators.generator_utils import tfrecord_iterator_for_problem

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="~/t2t_data", type=str)
    parser.add_argument("--tmp_dir", default="/tmp/t2t_data", type=str)
    args = parser.parse_args()

    # Generate (or reuse) the LM1B data and iterate over its TFRecords.
    gen = lm1b.LanguagemodelLm1b32k()
    data_dir = os.path.expanduser(args.data_dir)
    gen.generate_data(data_dir, args.tmp_dir)

    examples = []
    batch_size = 1024
    batches = 0
    for ex in tfrecord_iterator_for_problem(gen, data_dir):
        examples.append(ex["targets"].values)
        if len(examples) == batch_size:
            # Pad the batch to the longest example with -1 and dump it to disk.
            max_len = max(map(len, examples))
            batch = np.full((len(examples), max_len), -1, dtype=np.int32)
            for idx, seq in enumerate(examples):
                batch[idx, :len(seq)] = seq
            np.save(os.path.join(SCRIPT_DIR, "batch{}.npy".format(batches)), batch)
            batches += 1
            examples.clear()
            if batches % 100 == 0:
                print("batches {}".format(batches))


if __name__ == "__main__":
    main()
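# Sketch (an addition, not part of the original script): reading one of the
# exported batches back. The int32 dtype and the -1 padding value match what
# main() writes above; "batch0.npy" is only an illustrative filename.
def load_batch(path):
    batch = np.load(path)                  # shape (batch_size, max_len), int32
    lengths = (batch != -1).sum(axis=1)    # per-example lengths; -1 marks padding
    return batch, lengths


# Example:
#   batch, lengths = load_batch(os.path.join(SCRIPT_DIR, "batch0.npy"))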
# vocab_filename is a property on tensor2tensor text problems; overriding it
# makes this problem reuse the 32k LM1B subword vocabulary.
@property
def vocab_filename(self):
    return lm1b.LanguagemodelLm1b32k().vocab_filename
# use_vocab_from_other_problem is likewise a property in tensor2tensor's
# Text2TextProblem; returning the LM1B problem shares its vocabulary.
@property
def use_vocab_from_other_problem(self):
    return lm1b.LanguagemodelLm1b32k()
def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelLm1bMultiNLISubwords, self).__init__(was_reversed, was_copy)
    # Multi-task over the LM1B language-model task and the MultiNLI task,
    # both using the shared subword vocabulary.
    self.task_list.append(lm1b.LanguagemodelLm1b32k())
    self.task_list.append(multinli.MultiNLISharedVocab())
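# Sketch (an assumption about the surrounding context, not taken from the
# snippet above): how this __init__ would sit inside a registered tensor2tensor
# MultiProblem subclass. The base class and the registry decorator are assumed.
from tensor2tensor.data_generators import multi_problem
from tensor2tensor.utils import registry


@registry.register_problem
class LanguagemodelLm1bMultiNLISubwords(multi_problem.MultiProblem):
    """LM1B + MultiNLI multi-task problem sharing the LM1B 32k subword vocab."""
    # __init__ as above: append the LM1B and MultiNLI tasks to self.task_list.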