Example #1
def main(args):
    if args.mode == 'prepare':  # python3 run.py  --mode prepare --pointer-gen
        prepare(args)
    elif args.mode == 'train':  # python3 run.py  --mode train -b 100 -o output --gpu 0  --restore
        train(args)
    elif args.mode == 'eval':
        # python3 run.py --mode eval --eval-model
        evaluate(args)
    elif args.mode == 'decode':
        # python3 run.py --mode decode --beam-size 10 --decode-model output_big_data/model/model-250000 --decode-dir output_big_data/result --gpu 1
        args.batch_size = args.beam_size
        vocab_encoder = Vocab(args, "encoder_vocab")
        vocab_decoder = Vocab(args, "decoder_vocab")
        vocab_user = User_Vocab(args, name="user_vocab")
        test_file = "./test.data"
        #test_file = os.path.join(args.data, 'chat_data/tmp.data')
        # test_file = os.path.join(args.data, 'news_train_span_50.data')
        batcher = TestBatcher(args, vocab_encoder, vocab_decoder, vocab_user,
                              test_file).batcher()
        if args.cpu:
            with tf.device('/cpu:0'):
                model = CommentModel(args, vocab_decoder)
        else:
            model = CommentModel(args, vocab_decoder)

        decoder = BeamSearchDecoder(args, model, batcher, vocab_decoder)
        decoder.decode()
    elif args.mode == 'debug':
        debug(args)
    else:
        raise RuntimeError(f'mode {args.mode} is invalid.')
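The inline comments above reference CLI flags such as --mode, --beam-size, --decode-model and --gpu. Below is a minimal argparse sketch of how such a parser might be wired up; the flag names are taken from those comments, but the defaults, types, and the (non-exhaustive) flag set are assumptions, not the repo's actual definitions.

import argparse

def build_parser():
    # Hypothetical parser for the flags referenced in the mode comments above;
    # defaults and types are assumptions for illustration only.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['prepare', 'train', 'eval', 'decode', 'debug'], required=True)
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('-o', '--output', default='output')
    parser.add_argument('--gpu', default=None)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--beam-size', type=int, default=10)
    parser.add_argument('--decode-model', default=None)
    parser.add_argument('--decode-dir', default=None)
    parser.add_argument('--pointer-gen', action='store_true')
    parser.add_argument('--restore', action='store_true')
    return parser

if __name__ == '__main__':
    main(build_parser().parse_args())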
Example #2
    def __init__(self, filename, vocab_file=None,
                 vocab_dump=None, label_vocab_dump=None,
                 n_prev_turns=0, indices=None):
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            self.data = [row for row in reader]

        if indices is not None:
            self.data = [self.data[i] for i in indices]

        if "id" in self.data[0]:
            self.id2idx = {row["id"]: i for i, row in enumerate(self.data)}

        self.n_prev_turns = n_prev_turns

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
                
        if label_vocab_dump is None:
            labels = [row["label"] for row in self.data]
            self.label_vocab = LabelVocab(labels)
        else:
            with open(label_vocab_dump, 'rb') as fp:
                self.label_vocab = pickle.load(fp)
Example #3
    def __init__(self,
                 data_dict,
                 train=True,
                 vocabulary=None,
                 support=False,
                 device=None):
        """
        'datas': all_datas
        'maxlen_story': maxlen_story
        'maxlen_query': maxlen_query
        'maxlen_sent': maxlen_sent
        """

        self.examples = data_dict['datas']
        self.maxlen_story = data_dict['maxlen_story']
        self.maxlen_query = data_dict['maxlen_query']
        self.maxlen_sent = data_dict['maxlen_sent']
        self.support = support
        self.device = device
        self.flatten = lambda x: [tkn for sublists in x for tkn in sublists]

        stories, questions, answers, supports = list(zip(*self.examples))
        if train:
            self.vocab = Vocab()
            self._build_vocab(stories, questions, answers)
        else:
            self.vocab = vocabulary
        # numerical & add_pad
        stories, questions, answers = self._preprocess(stories, questions,
                                                       answers)

        if self.support:
            self.data = list(zip(stories, questions, answers, supports))
        else:
            self.data = list(zip(stories, questions, answers))
Example #4
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path, "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled", "news.en-*")
            train_paths = glob(train_path_pattern)

            # the vocab will load from file when build_vocab() is called
            # for train_path in sorted(train_paths):
            #   self.vocab.count_file(train_path, verbose=True)

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(
                path, "train.txt"),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, "valid.txt"),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                               ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(
                path, "train.txt"),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, "valid.txt"),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == "lm1b":
            self.train = train_paths
            valid_path = os.path.join(path, "valid.txt")
            test_path = valid_path
            self.valid = self.vocab.encode_file(valid_path,
                                                ordered=True,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(test_path,
                                               ordered=True,
                                               add_double_eos=True)

        if self.dataset == "wt103":
            self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]
        elif self.dataset == "lm1b":
            self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)]
        else:
            self.cutoffs = []
Example #5
    def __init__(self,
                 filename,
                 vocab_file=None,
                 vocab_dump=None,
                 label_vocab_dump=None,
                 n_prev_turns=0,
                 text_input=False):
        self.text_input = text_input
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            self.data = [row for row in reader]
            lattice_reader = LatticeReader(text_input=text_input)
            for i, row in enumerate(tqdm(self.data)):
                row["lattice"] = lattice_reader.read_sent(row["text"], i)
                row["rev_lattice"] = row["lattice"].reversed()

        self.id2idx = {row["id"]: i for i, row in enumerate(self.data)}
        self.n_prev_turns = n_prev_turns
        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
        if label_vocab_dump is None:
            labels = [row["label"] for row in self.data]
            self.label_vocab = LabelVocab(labels)
        else:
            with open(label_vocab_dump, 'rb') as fp:
                self.label_vocab = pickle.load(fp)
Example #6
    def add_tokens(self, tokens):
        Vocab.add_tokens(self, tokens)

        if self.counts is not None:
            self.counts.update(tokens)

        if self.doc_counts is not None:
            token_set = set(tokens)
            self.doc_counts.update(token_set)
Example #7
    def add_tokens(self, tokens):
        Vocab.add_tokens(self, tokens)

        if self.counts is not None:
            self.counts.update(tokens)

        if self.doc_counts is not None:
            token_set = set(tokens)
            self.doc_counts.update(token_set)
Example #8
def test_save_load_vocab():
    freq_dict = {"a": 2, "c": 10, "int": 100, "A": 1}
    vocab = Vocab.create_from_freq_dict(freq_dict)
    with open("dump_vocab.v.c2v", 'wb') as file:
        vocab.save_to_file(file)
    with open("dump_vocab.v.c2v", 'rb') as file:
        new_vocab = Vocab.load_from_file(file)
    assert vocab.word_to_index == new_vocab.word_to_index
    assert vocab.index_to_word == new_vocab.index_to_word
Example #9
    def __init__(self, prefix, add_oov=True, read_from_filename=None, tokens_to_add=None):
        Vocab.__init__(self, prefix, add_oov=add_oov)
        self.counts = Counter()
        self.doc_counts = Counter()

        if read_from_filename is not None:
            self.read_from_file(read_from_filename)

        if tokens_to_add is not None:
            self.add_tokens(tokens_to_add)
Example #10
def prepare(args):
    if not os.path.exists(args.records_dir):
        os.makedirs(args.records_dir)

    train_file = os.path.join(args.data, 'chat_data/tmp.data')
    dev_file = os.path.join(args.data, 'chat_data/tmp.data')
    vocab_encoder = Vocab(args, name="encoder_vocab")
    vocab_decoder = Vocab(args, name="decoder_vocab")
    vocab_user = User_Vocab(args, name="user_vocab")
    dataset = Dataset(args, vocab_encoder, vocab_decoder, vocab_user,
                      train_file, dev_file)
    dataset.save_datasets(['train', 'dev'])
Example #11
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, "train.txt"))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
        self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)

        vocab_len = len(self.vocab)
        self.cutoffs = [0, int(vocab_len * 0.1), int(vocab_len * 0.2), int(vocab_len * 0.4)] + [vocab_len]
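The cutoff list above keeps 0 and the vocabulary size as sentinel boundaries, Transformer-XL style. If such a corpus were paired with PyTorch's built-in adaptive softmax (an assumption for illustration; these examples otherwise build their own pipelines), only the interior cutoffs would be passed. A minimal sketch, with hidden_size and vocab_len as made-up values:

import torch
import torch.nn as nn

# Hypothetical bridge from a Corpus-style cutoff list to nn.AdaptiveLogSoftmaxWithLoss.
hidden_size, vocab_len = 512, 50000
cutoffs = [0, int(vocab_len * 0.1), int(vocab_len * 0.2), int(vocab_len * 0.4)] + [vocab_len]

criterion = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=hidden_size,
    n_classes=vocab_len,        # equals cutoffs[-1]
    cutoffs=cutoffs[1:-1],      # drop the 0 / vocab_len sentinels
    div_value=4.0,
)

hidden = torch.randn(8, hidden_size)           # dummy decoder states
targets = torch.randint(0, vocab_len, (8,))    # dummy next-token ids
output, loss = criterion(hidden, targets)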
Example #12
    def __init__(self,
                 prefix,
                 add_oov=True,
                 read_from_filename=None,
                 tokens_to_add=None):
        Vocab.__init__(self, prefix, add_oov=add_oov)
        self.counts = Counter()
        self.doc_counts = Counter()

        if read_from_filename is not None:
            self.read_from_file(read_from_filename)

        if tokens_to_add is not None:
            self.add_tokens(tokens_to_add)
Example #13
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, "train.txt"))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                            ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            ordered=True)
        self.test = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                           ordered=True)

        self.cutoffs = []
Example #14
    def __init__(self, path, dataset, *args, **kwargs):

        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(
            path, "train.txt"))  # update the Counter inside the vocab object (counts how often each distinct word appears)
        self.vocab.count_file(os.path.join(path, "valid.txt"))  # same as above, updating from the validation set

        self.vocab.build_vocab()  # builds idx2sym and sym2idx: map words to indices and indices back to words

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                            ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            ordered=True)
Example #15
    def __init__(self,
                 filename,
                 vocab_file=None,
                 vocab_dump=None,
                 stop_word_file=None):
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            data = [row for row in reader]

        self.stop_words = set()
        if stop_word_file is not None:
            for line in open(stop_word_file):
                self.stop_words.add(line.strip())

        datas = []
        count, total = 0, 0
        for row in data:
            ref = row["transcription"]
            hyp = row["hypothesis"]
            score = float(row["score"])
            confs = row["confusion"].split()
            confs = [(confs[i * 3], confs[i * 3 + 1])
                     for i in range(len(confs) // 3 + 1)]
            conf_ids = []
            ref_id = hyp_id = 0
            for ref_w, hyp_w in confs:
                ref_eps = (ref_w == "<eps>")
                hyp_eps = (hyp_w == "<eps>")
                if not ref_eps and not hyp_eps and ref_w != hyp_w:
                    total += 1
                    if ref_w not in self.stop_words and hyp_w not in self.stop_words:
                        conf_ids.append((ref_id, hyp_id))
                    else:
                        count += 1

                if not ref_eps:
                    ref_id += 1
                if not hyp_eps:
                    hyp_id += 1
            datas.append((ref, hyp, conf_ids, score))
        print(count, total)
        self.data = datas

        if vocab_file is not None:
            self.vocab = Vocab(vocab_file)
        elif vocab_dump is not None:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
Example #16
    def __init__(self, dataset, batch_size=None, vocab_created=False, vocab=None, target_col=None, word2index=None,
                 sos_token='<SOS>', eos_token='<EOS>', unk_token='<UNK>', pad_token='<PAD>', min_word_count=5,
                 max_vocab_size=None, max_seq_len=0.8, use_pretrained_vectors=False, glove_path='Glove/',
                 glove_name='glove.6B.100d.txt', weights_file_name='Glove/weights.npy'):

        if not vocab_created:
            self.vocab = Vocab(dataset, target_col=target_col, word2index=word2index, sos_token=sos_token, eos_token=eos_token,
                               unk_token=unk_token, pad_token=pad_token, min_word_count=min_word_count,
                               max_vocab_size=max_vocab_size, max_seq_len=max_seq_len,
                               use_pretrained_vectors=use_pretrained_vectors, glove_path=glove_path,
                               glove_name=glove_name, weights_file_name=weights_file_name)

            self.dataset = self.vocab.dataset

        else:
            self.dataset = dataset
            self.vocab = vocab

        self.target_col = target_col

        self.word2index = self.vocab.word2index

        if batch_size:
            self.batch_size = batch_size
        else:
            self.batch_size = len(self.dataset)

        self.x_lengths = np.array(self.vocab.x_lengths)

        if self.target_col:
            self.y_lengths = np.array(self.vocab.y_lengths)

        self.pad_token = self.vocab.word2index[pad_token]

        self.sort_and_batch()
Example #17
    def word(self, update, context):
        try:
            chat_message = update.message.text
            # normalize the query word before the vocabulary lookup
            chat_message = chat_message.lower().capitalize()
            x = Vocab(chat_message).mean()
            context.bot.send_message(chat_id=update.effective_chat.id, text=x)
        except KeyError:
            context.bot.send_message(chat_id=update.effective_chat.id, text="İnvalid Syntax :(")
Example #18
    def Vocabulary(self, update, context):
        try:
            chat_message = update.message.text
            x = Vocab(chat_message).mean()
            context.bot.send_message(chat_id=update.effective_chat.id, text=x)
        except KeyError:
            context.bot.send_message(chat_id=update.effective_chat.id, text="İnvalid Syntax :(")
Example #19
    def __init__(self,
                 filename,
                 vocab_file=None,
                 vocab_dump=None,
                 text_input=False):
        self.text_input = text_input
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            self.data = [row for row in reader]
            lattice_reader = LatticeReader(text_input=text_input)
            for i, row in enumerate(tqdm(self.data)):
                row["lattice"] = lattice_reader.read_sent(row["text"], i)
                row["rev_lattice"] = row["lattice"].reversed()

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
Example #20
    def __init__(self,
                 filename,
                 vocab_file=None,
                 vocab_dump=None,
                 label_vocab_dump=None):
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            self.data = [row for row in reader]

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
        if label_vocab_dump is None:
            labels = [row["label"] for row in self.data]
            self.label_vocab = LabelVocab(labels)
        else:
            with open(label_vocab_dump, 'rb') as fp:
                self.label_vocab = pickle.load(fp)
Example #21
def test_create_from_freq_dict():
    freq_dict = {"a": 2, "c": 10, "int": 100, "A": 1}
    vocab = Vocab.create_from_freq_dict(freq_dict)
    assert {
        word: i
        for i, word in enumerate(['NOTHING', 'A', 'a', 'c', 'int'])
    } == vocab.word_to_index
    assert {
        i: word
        for i, word in enumerate(['NOTHING', 'A', 'a', 'c', 'int'])
    } == vocab.index_to_word
Example #22
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        train_path = os.path.join(path, "train.txt")
        valid_path = os.path.join(path, "valid.txt")
        # test_path = os.path.join(path, "test.txt")

        # self.vocab.count_file(train_path)
        # self.vocab.count_file(valid_path)
        # self.vocab.count_file(test_path)
        self.vocab.build_vocab(add_bytes=True)

        self.train = train_path
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            ordered=True,
                                            add_eos=False)
        # self.test  = self.vocab.encode_file(
        #     os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        self.cutoffs = []
Example #23
    def __init__(
            self, checkpoint_path='/home/mnakhodnov/sirius-stt/models/8_recovered_v3/epoch_17.pt',
            device=torch.device('cpu'), rescore=True, decoder_kwargs=None
    ):
        if not os.path.exists(checkpoint_path):
            raise ValueError(f'There is no checkpoint in {checkpoint_path}')

        self.device = device
        self.rescore = rescore
        self.decoder_kwargs = decoder_kwargs
        self.checkpoint_path = checkpoint_path

        self._vocab = Vocab(self._alphabet)

        self._num_tokens = get_num_tokens(self._vocab)
        self._blank_index = get_blank_index(self._vocab)

        self._sample_rate = 8000
        self._model_config = {
            'num_mel_bins': 64,
            'hidden_size': 512,
            'num_layers': 4,
            'num_tokens': len(self._vocab.tokens2indices()) - 1,
        }

        self.model = Model(**self._model_config)
        load_from_ckpt(self.model, self.checkpoint_path)
        self.model = self.model.to(device=self.device).eval()

        self.decoder = fast_beam_search_decode
        self._kenlm_binary_path = '/data/mnakhodnov/language_data/cc100/xaa.processed.3.binary'
        if self.decoder_kwargs is None:
            self.decoder_kwargs = {
                'beam_size': 200, 'cutoff_top_n': 33, 'cutoff_prob': 1.0,
                'ext_scoring_func': self._kenlm_binary_path, 'alpha': 1.0, 'beta': 0.3, 'num_processes': 32
            }

        if self.rescore:
            self.rescorer_model = torch.hub.load(
                'pytorch/fairseq', 'transformer_lm.wmt19.ru', tokenizer='moses', bpe='fastbpe', force_reload=False
            ).to(device=device)
Example #24
class PairDataset(Dataset):
    label_idx = 3

    def __init__(self,
                 filename,
                 vocab_file=None,
                 vocab_dump=None,
                 label_vocab_dump=None):
        with open(filename) as csvfile:
            reader = csv.DictReader(csvfile)
            self.data = [row for row in reader]

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
        if label_vocab_dump is None:
            labels = [row["label"] for row in self.data]
            self.label_vocab = LabelVocab(labels)
        else:
            with open(label_vocab_dump, 'rb') as fp:
                self.label_vocab = pickle.load(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def _process_text(self, text):
        for punct in [',', '.', '?', '!']:
            if text.endswith(f" {punct}"):
                text = text[:-2]
        text = re.sub(r" ([a-z])\. ", r" \1 ", text)
        return text

    def collate_fn(self, batch):
        inputs, words, positions, labels = [], [], [], []
        for utt in batch:
            text = self._process_text(utt["text"] + " " + utt["text2"])
            text = " ".join(text)
            label = utt["label"]
            word_ids = [self.vocab.w2i(word) for word in text.split()]
            words.append(text.split())
            inputs.append(word_ids)
            positions.append([0, len(word_ids)])
            labels.append(self.label_vocab.l2i(label))

        max_length = max(map(len, inputs))
        inputs = pad_sequences(inputs, max_length)
        labels = np.array(labels)
        return inputs, words, positions, labels
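The collate_fn above relies on a pad_sequences helper defined elsewhere in the repo. A minimal sketch of what it is assumed to do follows; the default padding side and padding value here are guesses, not the repo's actual behaviour.

import numpy as np

def pad_sequences(sequences, max_length, padding='pre', value=0):
    # Hypothetical stand-in: pad (or truncate) each id sequence to max_length,
    # either at the front ('pre') or at the end ('post').
    padded = np.full((len(sequences), max_length), value, dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = list(seq)[:max_length]
        if padding == 'post':
            padded[i, :len(seq)] = seq
        else:
            padded[i, max_length - len(seq):] = seq
    return padded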
Example #25
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, "train.txt"))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
        self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)

        vocab_len = len(self.vocab)
        self.cutoffs = [0, int(vocab_len * 0.1), int(vocab_len * 0.2), int(vocab_len * 0.4)] + [vocab_len]
        # self.cutoffs = []

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, num_core_per_host, **kwargs):
        file_names = []

        record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)

        data = getattr(self, split)

        file_name, num_batch = create_ordered_tfrecords(save_dir, split, data, bsz, tgt_len)
        file_names.append(file_name)

        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)
Example #26
    def __init__(self, text_path, vocab_file=None, vocab_dump=None):
        self.data = []

        print_time_info("Reading text from {}".format(text_path))

        with open(text_path) as csvfile:
            reader = csv.DictReader(csvfile)
            for i, row in enumerate(reader):
                words = row["text"].split()
                if "id" in row:
                    self.data.append((row["id"], words))
                else:
                    self.data.append((i, words))
        # for line in tqdm(open(text_path)):
        #     uid, *words = line.strip().split()
        #     self.data.append((uid, words))

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)
Example #27
def test_create_lookup_table():
    freq_dict = {"a": 2, "c": 10, "int": 100, "A": 1}
    vocab = Vocab.create_from_freq_dict(freq_dict, 0)
    w_t_i_lookup_table = vocab.get_word_to_index_lookup_table()
    i_t_w_lookup_table = vocab.get_index_to_word_lookup_table()
    for index, word in enumerate([
            "NOTHING",
            *sorted(freq_dict.keys(), key=lambda key: freq_dict[key])
    ]):
        assert w_t_i_lookup_table.lookup(tf.constant(
            word, dtype=tf.string)).numpy() == index
        assert i_t_w_lookup_table.lookup(tf.constant(
            index, dtype=tf.int32)) == tf.constant(word, dtype=tf.string)
Example #28
def train(**kwargs):
    args = DefaultConfig()
    args.parse(kwargs)
    vocab = Vocab()
    loss_functions = transformer_celoss
    score_functions = rouge_func
    model = getattr(Models, args.model_name)(vocab, args)
    train_loader = get_loaders('train', args.batch_size, 12)
    dev_loader = get_loaders('val', args.batch_size, 12)
    trainer = ScheduledTrainerTrans(args, model, loss_functions, score_functions, train_loader, dev_loader)
    if args.resume is not None:
        trainer.init_trainner(resume_from=args.resume)
    else:
        trainer.init_trainner()
    trainer.train()
Example #29
class LMDataset(Dataset):
    def __init__(self, text_path, vocab_file=None, vocab_dump=None):
        self.data = []

        print_time_info("Reading text from {}".format(text_path))

        with open(text_path) as csvfile:
            reader = csv.DictReader(csvfile)
            for i, row in enumerate(reader):
                words = row["text"].split()
                if "id" in row:
                    self.data.append((row["id"], words))
                else:
                    self.data.append((i, words))
        # for line in tqdm(open(text_path)):
        #     uid, *words = line.strip().split()
        #     self.data.append((uid, words))

        if vocab_dump is None:
            self.vocab = Vocab(vocab_file)
        else:
            with open(vocab_dump, 'rb') as fp:
                self.vocab = pickle.load(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        uid, sentence = self.data[index]
        word_ids = [self.vocab.w2i(word) for word in sentence]
        return uid, sentence, word_ids

    def collate_fn(self, batch):
        uids, inputs, outputs, outputs_rev = [], [], [], []
        for uid, words, word_ids in batch:
            uids.append(uid)
            inputs.append(words)
            outputs.append(word_ids[1:] + [PAD])
            outputs_rev.append([PAD] + word_ids[:-1])

        max_length = max([len(sent) for sent in outputs])
        # (batch_size, seq_length)
        outputs = pad_sequences(outputs, max_length, 'post')
        outputs_rev = pad_sequences(outputs_rev, max_length, 'post')

        return inputs, outputs, outputs_rev, uids
Example #30
def load_model(exp_name):
    exp_root = os.path.join(ckpt_root, exp_name)
    best_model_folder = get_best_k_model_path(os.path.join(exp_root, 'saved_models'))[0]
    best_model_folder = os.path.join(exp_root, 'saved_models', best_model_folder)
    model_state = t.load(os.path.join(best_model_folder, 'model'), map_location='cpu')
    try:
        for i in model_state:
            model_state[i] = model_state[i].cpu()
    except:
        pass

    trainner_state = t.load(os.path.join(best_model_folder, 'trainner_state'))
    args = trainner_state['args']

    vocab = Vocab()
    model = getattr(Models, args.model_name)(vocab, args)
    model.load_state_dict(model_state)
    model.eval()
    return model
Example #31
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):

        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(
            path, "train.txt"))  # update the Counter inside the vocab object (counts how often each distinct word appears)
        self.vocab.count_file(os.path.join(path, "valid.txt"))  # same as above, updating from the validation set

        self.vocab.build_vocab()  # builds idx2sym and sym2idx: map words to indices and indices back to words

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                            ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            ordered=True)

        # self.cutoffs = []  # completely redundant: I suspected cutoffs were unnecessary from the first day of reading this code, and after being burned by them for a whole day I can now confirm that, without a TPU, every piece of code touching cutoffs is redundant

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, **kwargs):
        file_names = []

        record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
            split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)
        bin_sizes = None

        file_name, num_batch = create_ordered_tfrecords(
            save_dir, split, getattr(self, split), bsz, tgt_len)

        file_names.append(file_name)

        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "bin_sizes": bin_sizes,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)
Example #32
def train_re(**kwargs):
    args = DefaultConfig()
    args.parse(kwargs)
    vocab = Vocab()
    loss_functions = transformer_celoss
    score_functions = rouge_func
    model = getattr(Models, args.model_name)(vocab, args)
    train_loader = get_loaders('train', args.batch_size, 12)
    dev_loader = get_loaders('val', args.batch_size, 12)
    trainer = ScheduledTrainerTrans(args, model, loss_functions, score_functions, train_loader, dev_loader)
    trainer.init_trainner(resume_from=args.resume)
    # try:
    #     trainer.model.vgg_feature.requires_grad = True
    #     trainer.model.vgg_input.requires_grad = True
    #
    # except:
    #     trainer.model.module.vgg_feature.requires_grad = True
    #     trainer.model.module.vgg_input.requires_grad = True
    # trainer.optim.param_groups[0]['lr'] = 3e-5
    trainer.train()
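Both train() and train_re() accept **kwargs and hand them to DefaultConfig.parse, which suggests a thin CLI wrapper. A hypothetical entry point using python-fire is sketched below; this is an assumption for illustration, not the repo's actual glue code.

import fire

if __name__ == '__main__':
    # Hypothetical CLI wiring; the real repo may use a different entry point.
    fire.Fire({'train': train, 'train_re': train_re})
    # e.g.  python run.py train --batch_size=32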