Example #1
def load_word_vectors(path):
    if os.path.isfile(path + '.pth') and os.path.isfile(path + '.vocab'):
        print('==> File found, loading to memory')
        vectors = torch.load(path + '.pth')
        vocab = Vocab(filename=path + '.vocab')
        return vocab, vectors
    # saved file not found, read from txt file
    # and create tensors for word vectors
    print('==> File not found, preparing, be patient')
    count = sum(
        1
        for line in open(path + '.txt', 'r', encoding='utf8', errors='ignore'))
    with open(path + '.txt', 'r') as f:
        contents = f.readline().rstrip('\n').split(' ')
        dim = len(contents[1:])
    words = [None] * (count)
    vectors = torch.zeros(count, dim, dtype=torch.float, device='cpu')
    with open(path + '.txt', 'r', encoding='utf8', errors='ignore') as f:
        idx = 0
        for line in f:
            contents = line.rstrip('\n').split(' ')
            words[idx] = contents[0]
            values = list(map(float, contents[1:]))
            vectors[idx] = torch.tensor(values,
                                        dtype=torch.float,
                                        device='cpu')
            idx += 1
    with open(path + '.vocab', 'w', encoding='utf8', errors='ignore') as f:
        for word in words:
            f.write(word + '\n')
    vocab = Vocab(filename=path + '.vocab')
    torch.save(vectors, path + '.pth')
    return vocab, vectors
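
A minimal usage sketch for the loader above; the GloVe path is hypothetical, and vocab.size()/getIndex() follow their usage in Example #8:

# Illustrative usage; 'data/glove/glove.840B.300d' is a hypothetical path. On the
# first run the loader reads the .txt file and writes .pth/.vocab caches next to it.
vocab, vectors = load_word_vectors('data/glove/glove.840B.300d')
print(vocab.size(), vectors.size())   # vocabulary size and (num_words, dim)
idx = vocab.getIndex('hello')         # getIndex is used the same way in Example #8
if idx:
    print(vectors[idx][:5])           # first few dimensions of that word vector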
Example #2
File: predict.py  Project: royyoung388/srl
    def __init__(self,
                 model_path,
                 word_vocab,
                 label_vocab,
                 word,
                 label,
                 device='cpu'):
        # load vocab
        self.word_vocab = Vocab(word_vocab)
        self.label_vocab = Vocab(label_vocab)

        self.word_vocab.unk_id = self.word_vocab.toID(UNK)
        # todo just for test
        self.label_vocab.unk_id = 1
        config.WORD_PAD_ID = self.word_vocab.toID(PAD)
        config.WORD_UNK_ID = self.word_vocab.toID(UNK)
        config.LABEL_PAD_ID = self.label_vocab.toID(PAD)
        pred_id = [self.label_vocab.toID('B-v')]

        self.device = torch.device(device)

        # load data
        dataset = DataReader(word, label, self.word_vocab, self.label_vocab)
        self.dataLoader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            # pin_memory=True,
            shuffle=False,
            collate_fn=Collate(pred_id, WORD_PAD_ID, LABEL_PAD_ID, False))

        self.model = DeepAttn(self.word_vocab.size(), self.label_vocab.size(),
                              feature_dim, model_dim, filter_dim)
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.to(self.device)
Example #3
def read_word_data(path, word_vocab: Vocab):
    sentence = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            words = line.strip().split()
            sentence.append(torch.LongTensor(word_vocab.toID(words)))
    return sentence
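
A short usage sketch for read_word_data, assuming a whitespace-tokenized text file and a word vocabulary file; the paths mirror the layout used in Example #17 and are illustrative:

# Illustrative paths; Vocab(path) and toID follow Examples #2 and #17.
word_vocab = Vocab('data/train/word_vocab.txt')
sentences = read_word_data('data/train/word.txt', word_vocab)
print(len(sentences))    # number of sentences read
print(sentences[0])      # LongTensor of word ids for the first sentence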
Example #4
    def __init__(self, vocab: Vocab, sequence, spans: Optional[List[List[MatchedSpan]]] = None,
                 **other_seqs: Tuple[List[List[T]], T]):
        """Instantiate a batch of sequence."""
        self.raw_sequence = sequence
        self.spans = spans
        self.unk_probs = None
        self.batch_size = len(self.raw_sequence)

        indices = [vocab.numericalize(s) for s in self.raw_sequence]  # TODO: Improve performance (47.582s)
        sequence = pad_simple(indices, pad_symbol=1)
        self.sequence = torch.from_numpy(sequence[:, :-1])
        self.target = torch.from_numpy(sequence[:, 1:])

        self.seqs: Dict[str, LongTensor] = {}
        self.seqs_pad_symbols: Dict[str, T] = {}
        for name, (seq, pad_symbol) in other_seqs.items():
            self.seqs[name] = torch.from_numpy(pad_simple(seq, pad_symbol=pad_symbol))
            self.seqs_pad_symbols[name] = pad_symbol

        # don't make length -1 for empty sequences
        self.lengths = torch.LongTensor([max(0, len(s) - 1) for s in self.raw_sequence])

        # For loss calculation
        self.ntokens = torch.sum(self.lengths).item()

        self.has_article_end = vocab.w2i['</s>'] in [s[-1] for s in indices if len(s) > 0]
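
pad_simple is not shown in this example; the following is only a minimal sketch of what such a helper could look like, assuming it right-pads a list of index lists into one (batch, max_len) int64 array (the project's actual implementation may differ):

import numpy as np

def pad_simple(seqs, pad_symbol=1):
    # Right-pad variable-length index lists into a single (batch, max_len) int64 array,
    # so that torch.from_numpy(...) yields a LongTensor as used above.
    max_len = max((len(s) for s in seqs), default=0)
    out = np.full((len(seqs), max_len), pad_symbol, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out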
Example #5
    def readVocs(self, datafile, corpus_name):
        print("Reading lines...")
        # Read the file and split into lines
        lines = open(datafile, encoding='utf-8').read().strip().split('\n')
        # Split every line into pairs and normalize
        pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
        voc = Vocab(corpus_name)
        return voc, pairs
Example #6
    def __init__(self, weight_path, have_att=False):
        ENC_EMB_DIM = 256
        DEC_EMB_DIM = 256
        ENC_HID_DIM = 512
        DEC_HID_DIM = 512
        ENC_DROPOUT = 0.5
        DEC_DROPOUT = 0.5
        MAX_LEN = 46
        self.maxlen = MAX_LEN
        self.vocab = Vocab(alphabets)

        INPUT_DIM = len(self.vocab)
        OUTPUT_DIM = len(self.vocab)

        if have_att:
            self.model = Seq2Seq(input_dim=INPUT_DIM,
                                 output_dim=OUTPUT_DIM,
                                 encoder_embbeded=ENC_EMB_DIM,
                                 decoder_embedded=DEC_EMB_DIM,
                                 encoder_hidden=ENC_HID_DIM,
                                 decoder_hidden=DEC_HID_DIM,
                                 encoder_dropout=ENC_DROPOUT,
                                 decoder_dropout=DEC_DROPOUT)
        else:
            self.model = Seq2Seq_WithoutAtt(input_dim=INPUT_DIM,
                                            output_dim=OUTPUT_DIM,
                                            encoder_embbeded=ENC_EMB_DIM,
                                            decoder_embedded=DEC_EMB_DIM,
                                            encoder_hidden=ENC_HID_DIM,
                                            decoder_hidden=DEC_HID_DIM,
                                            encoder_dropout=ENC_DROPOUT,
                                            decoder_dropout=DEC_DROPOUT)

        self.load_weights(weight_path)
        if torch.cuda.is_available():
            self.device = "cuda"
            self.model.to('cuda')
        else:
            self.device = "cpu"

        print("Device: ", self.device)
        print("Loaded model")
Example #7
File: predict.py  Project: royyoung388/srl
def convert_to_string(labels, label_vocab: Vocab, lengths: torch.Tensor):
    """

    :param labels: (batch * seq)
    :param label_vocab:
    :param lengths: IntTensor. batch size
    :return:
    """
    result = []
    for label, length in zip(labels, lengths):
        result.append(label_vocab.toToken(label)[:length])
    return result
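
A tiny usage sketch for convert_to_string; the label ids and lengths below are made up, and toToken is assumed to map a list of ids back to a list of label strings, as its use above implies:

label_vocab = Vocab('data/train/label_vocab.txt')   # illustrative path, as in Example #17
batch_labels = [[3, 4, 0, 0], [5, 0, 0, 0]]         # padded (batch, seq) label ids
lengths = torch.IntTensor([2, 1])                   # true length of each sentence
print(convert_to_string(batch_labels, label_vocab, lengths))
# e.g. [['B-ARG0', 'I-ARG0'], ['B-v']], depending on the label vocabulary file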
Example #8
def main():
    global args
    args = parse_args()
    vocab_file = os.path.join(args.dtree, 'snli_vocab_cased.txt')
    vocab = Vocab(filename=vocab_file)

    args.cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    l_train_file = os.path.join(args.dtree, args.premise_train)
    r_train_file = os.path.join(args.dtree, args.hypothesis_train)
    label_train_file = os.path.join(args.dtree, args.label_train)

    l_dev_file = os.path.join(args.dtree, args.premise_dev)
    r_dev_file = os.path.join(args.dtree, args.hypothesis_dev)
    label_dev_file = os.path.join(args.dtree, args.label_dev)

    l_test_file = os.path.join(args.dtree, args.premise_test)
    r_test_file = os.path.join(args.dtree, args.hypothesis_test)
    label_test_file = os.path.join(args.dtree, args.label_test)

    l_train_squence_file = os.path.join(args.ctree, args.premise_train)
    r_train_squence_file = os.path.join(args.ctree, args.hypothesis_train)

    l_dev_squence_file = os.path.join(args.ctree, args.premise_dev)
    r_dev_squence_file = os.path.join(args.ctree, args.hypothesis_dev)

    l_test_squence_file = os.path.join(args.ctree, args.premise_test)
    r_test_squence_file = os.path.join(args.ctree, args.hypothesis_test)

    print(l_train_file, l_dev_file, l_test_file)
    print(r_train_file, r_dev_file, r_test_file)
    print(label_train_file, label_dev_file, label_test_file)

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = NLIdataset(premise_tree=l_train_file,
                                   hypothesis_tree=r_train_file,
                                   premise_seq=l_train_squence_file,
                                   hypothesis_seq=r_train_squence_file,
                                   label=label_train_file,
                                   vocab=vocab,
                                   num_classes=3,
                                   args=args)
        torch.save(train_dataset, train_file)
    if args.savedev == 1:
        dev_file = os.path.join(args.data, 'dev.pth')
        if os.path.isfile(dev_file):
            dev_dataset = torch.load(dev_file)
        else:
            dev_dataset = NLIdataset(premise_tree=l_dev_file,
                                     hypothesis_tree=r_dev_file,
                                     premise_seq=l_dev_squence_file,
                                     hypothesis_seq=r_dev_squence_file,
                                     label=label_dev_file,
                                     vocab=vocab,
                                     num_classes=3,
                                     args=args)
            torch.save(dev_dataset, dev_file)

        test_file = os.path.join(args.data, 'test.pth')
        if os.path.isfile(test_file):
            test_dataset = torch.load(test_file)
        else:
            test_dataset = NLIdataset(premise_tree=l_test_file,
                                      hypothesis_tree=r_test_file,
                                      premise_seq=l_test_squence_file,
                                      hypothesis_seq=r_test_squence_file,
                                      label=label_test_file,
                                      vocab=vocab,
                                      num_classes=3,
                                      args=args)
            torch.save(test_dataset, test_file)
    else:
        dev_dataset = NLIdataset(premise_tree=l_dev_file,
                                 hypothesis_tree=r_dev_file,
                                 premise_seq=l_dev_squence_file,
                                 hypothesis_seq=r_dev_squence_file,
                                 label=label_dev_file,
                                 vocab=vocab,
                                 num_classes=3,
                                 args=args)
        test_dataset = NLIdataset(premise_tree=l_test_file,
                                  hypothesis_tree=r_test_file,
                                  premise_seq=l_test_squence_file,
                                  hypothesis_seq=r_test_squence_file,
                                  label=label_test_file,
                                  vocab=vocab,
                                  num_classes=3,
                                  args=args)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=args.batchsize,
                                   shuffle=False)
    dev_data_loader = DataLoader(dev_dataset,
                                 batch_size=args.batchsize,
                                 shuffle=False)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=args.batchsize,
                                  shuffle=False)

    # for data in train_data_loader:
    #     lsent, lgraph, rsent, rgraph, label = data
    #     print(label)
    #     break

    # # initialize model, criterion/loss_function, optimizer
    # model = TreeLSTMforNLI(
    #     vocab.size(),
    #     args.input_dim,
    #     args.mem_dim,
    #     args.hidden_dim,
    #     args.num_classes,
    #     args.sparse,
    #     args.freeze_embed)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'snli_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = utils.load_word_vectors(
            os.path.join(args.glove, 'glove.840B.300d'))
        emb = torch.zeros(vocab.size(),
                          glove_emb.size(1),
                          dtype=torch.float,
                          device=device)
        emb.normal_(0, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate(['_PAD_', '_UNK_', '_BOS_', '_EOS_']):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(
                    word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    # model.emb.weight.data.copy_(emb)
    model = ESIM(vocab.size(),
                 args.input_dim,
                 args.mem_dim,
                 embeddings=emb,
                 dropout=0.5,
                 num_classes=args.num_classes,
                 device=device,
                 freeze=args.freeze_embed).to(device)
    criterion = nn.CrossEntropyLoss()
    model.to(device), criterion.to(device)
    if args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=args.lr,
                                  weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.lr,
                              weight_decay=args.wd)

    trainer = Trainer(args, model, criterion, optimizer, device)

    best = -999.0
    best_loop = 0
    for epoch in range(args.epochs):
        train_loss = trainer.train(train_data_loader)
        train_loss, train_acc = trainer.test(train_data_loader)
        dev_loss, dev_acc = trainer.test(dev_data_loader)
        test_loss, test_acc = trainer.test(test_data_loader)

        print('==> Epoch {}, Train \tLoss: {}\tAcc: {}'.format(
            epoch, train_loss, train_acc))
        print('==> Epoch {}, Dev \tLoss: {}\tAcc: {}'.format(
            epoch, dev_loss, dev_acc))
        print('==> Epoch {}, Test \tLoss: {}\tAcc: {}'.format(
            epoch, test_loss, test_acc))

        if best < test_acc:
            best = test_acc
            best_loop = 0
            print('Got improvement, saving model; best performance so far is %f' %
                  (best))
            checkpoint = {
                'model': trainer.model.state_dict(),
                'optim': trainer.optimizer,
                'acc': test_acc,
                'args': args,
                'epoch': epoch
            }
            print('==> New optimum found, checkpointing everything now...')
            torch.save(checkpoint,
                       '%s.pt' % os.path.join(args.save, args.expname))
        else:
            best_loop += 1
            if best_loop > args.patience:
                print('Early stop, best acc: %f' % (best))
                break
Example #9
File: predict.py  Project: royyoung388/srl
class Predictor(object):
    def __init__(self,
                 model_path,
                 word_vocab,
                 label_vocab,
                 word,
                 label,
                 device='cpu'):
        # load vocab
        self.word_vocab = Vocab(word_vocab)
        self.label_vocab = Vocab(label_vocab)

        self.word_vocab.unk_id = self.word_vocab.toID(UNK)
        # todo just for test
        self.label_vocab.unk_id = 1
        config.WORD_PAD_ID = self.word_vocab.toID(PAD)
        config.WORD_UNK_ID = self.word_vocab.toID(UNK)
        config.LABEL_PAD_ID = self.label_vocab.toID(PAD)
        pred_id = [self.label_vocab.toID('B-v')]

        self.device = torch.device(device)

        # load data
        dataset = DataReader(word, label, self.word_vocab, self.label_vocab)
        self.dataLoader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            # pin_memory=True,
            shuffle=False,
            collate_fn=Collate(pred_id, WORD_PAD_ID, LABEL_PAD_ID, False))

        self.model = DeepAttn(self.word_vocab.size(), self.label_vocab.size(),
                              feature_dim, model_dim, filter_dim)
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.to(self.device)

    def save(self, path, labels):
        with open(path, 'a', encoding='utf-8') as f:
            for label in labels:
                f.write(' '.join(label))
                f.write('\n')

    def predict(self, save_path=None):
        start = time.time()
        # print('start predict')

        self.model.eval()
        with torch.no_grad():
            y_pred = []
            y_true = []
            for step, (xs, preds, ys, lengths) in enumerate(self.dataLoader):
                xs, preds, ys, lengths = xs.to(self.device), preds.to(
                    self.device), ys.to(self.device), lengths.to(self.device)

                y_true.extend(
                    convert_to_string(ys.squeeze().tolist(), self.label_vocab,
                                      lengths))

                labels = self.model.argmax_decode(xs, preds)
                labels[preds.ne(0)] = self.label_vocab.toID('B-v')
                labels = convert_to_string(labels.squeeze().tolist(),
                                           self.label_vocab, lengths)
                y_pred.extend(labels)

                # print('predict step: %d, time: %.2f M' % (step, (time.time() - start) / 60))
                # print('current F1 score: %.2f' % f1_score(y_true, y_pred))

            if save_path:
                self.save(save_path, y_pred)

            score = f1_score(y_true, y_pred)

            print('finish predict: %.2f M, F1 score: %.2f' %
                  ((time.time() - start) / 60, score))
            return score
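
A hedged usage sketch for this Predictor; every path below is illustrative, and the file formats are whatever the project's DataReader and Vocab expect:

predictor = Predictor(model_path='checkpoints/deepattn.pt',
                      word_vocab='data/word_vocab.txt',
                      label_vocab='data/label_vocab.txt',
                      word='data/test/word.txt',
                      label='data/test/label.txt',
                      device='cpu')
f1 = predictor.predict(save_path='output/pred_labels.txt')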
Example #10
    def _extra_init(self, loaded_batches: bool):
        self.rel_vocab = Vocab.from_dict(self._path / 'rel_names.pkl',
                                         mode='i2w')
        self.vocab: Dict[str, Vocab] = {
            "word": self.word_vocab,
            "rel": self.rel_vocab
        }

        self.max_unkrel = max(
            (-rel_typ - 3 for rel_typ in self.rel_vocab.i2w if rel_typ < -3),
            default=0)

        if self._use_fasttext:

            def _alias_path(name):
                path = Path(self._fasttext_model_path)
                return path.parent / (path.name + f'.{name}')

            # gather all entity aliases and compute fastText embeddings
            alias_dict_path = _alias_path('alias_dict.pkl')
            if alias_dict_path.exists():
                alias_dict: Dict[str, int] = loadpkl(alias_dict_path)
                loaded = True
            else:
                alias_dict = defaultdict(lambda: len(alias_dict))
                loaded = False
            if not loaded_batches:
                for dataset in self.data.values():
                    for example in dataset:
                        for idx, rel in enumerate(
                                example.relations):  # type: ignore
                            example.relations[
                                idx] = rel._replace(  # type: ignore
                                    obj_alias=[
                                        alias_dict[s] for s in rel.obj_alias
                                    ])
            if not alias_dict_path.exists():
                alias_dict = dict(alias_dict)
                savepkl(alias_dict, alias_dict_path)

            alias_vectors_path = _alias_path('alias_vectors.pt')
            if not alias_vectors_path.exists() or not loaded:
                import fastText
                ft_model = fastText.load_model(self._fasttext_model_path)
                alias_vectors = []
                alias_list = utils.reverse_map(alias_dict)
                for alias in utils.progress(alias_list,
                                            desc="Building fastText vectors",
                                            ascii=True,
                                            ncols=80):
                    vectors = [
                        ft_model.get_word_vector(w) for w in alias.split()
                    ]
                    vectors = np.sum(vectors, axis=0).tolist()
                    alias_vectors.append(vectors)
                alias_vectors = torch.tensor(alias_vectors)
                torch.save(alias_vectors, alias_vectors_path)

        if not loaded_batches and (self._exclude_entity_disamb
                                   or self._exclude_alias_disamb):
            # no need to do this if batches are loaded
            if self._exclude_entity_disamb:
                # gather training set stats
                self.entity_count_per_type = self.gather_entity_stats(
                    self.data['train'])

            for dataset in self.data.values():
                for idx in range(len(dataset)):
                    dataset[idx] = self.remove_ambiguity(
                        dataset[idx], self._exclude_entity_disamb,
                        self._exclude_alias_disamb)
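
The line "alias_dict = defaultdict(lambda: len(alias_dict))" above is an auto-incrementing id assigner: the first lookup of an alias inserts it with the current dictionary size as its id. A standalone sketch of the pattern:

from collections import defaultdict

alias_dict = defaultdict(lambda: len(alias_dict))
print(alias_dict['barack obama'])   # 0 (first lookup inserts id 0)
print(alias_dict['obama'])          # 1
print(alias_dict['barack obama'])   # 0 (already present, so the id is stable)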
Example #11
    def __init__(self,
                 path: str,
                 batch_size: int,
                 vocab_dir: str,
                 bptt_size: int,
                 vocab_size: Optional[int],
                 min_freq: Optional[int] = None,
                 include_train: bool = True,
                 create_batches: bool = True,
                 use_only_first_section: bool = False,
                 unk_probs_path: Optional[str] = None,
                 use_upp: bool = False,
                 cache_batches: bool = True,
                 **_kwargs):
        self.batch_size = batch_size
        self.unk_probs = None
        self.bptt_size = bptt_size
        self._unk_probs_path = unk_probs_path
        self._include_train = include_train
        self._use_only_first_section = use_only_first_section
        self.vocab_size = vocab_size
        self.min_freq = min_freq

        self._path = path = Path(path)
        vocab_dir = Path(vocab_dir)

        if not vocab_dir.exists():
            vocab_dir.mkdir(parents=True)
        vocab_file_name = 'vocab'
        if vocab_size is not None:
            vocab_file_name += f'.{vocab_size}'
        if min_freq is not None:
            vocab_file_name += f'.freq{min_freq}'
        vocab_file_name += '.pt'
        self.vocab_path = vocab_path = vocab_dir / vocab_file_name

        loaded_batches = False
        if cache_batches and create_batches and vocab_path.exists(
        ) and include_train:
            # Try to load a cached version of the batches if possible; and do not load data
            loaded_batches = self._try_load_cache(path)
        if not loaded_batches:
            # Failed to load cached batches
            self.data = self.read_data(path)

        if not include_train or vocab_path.exists():
            self.word_vocab = torch.load(vocab_path)
            LOGGER.info(
                f"Word Vocabulary of size {len(self.word_vocab)} loaded from {vocab_path}."
            )
        else:
            train_corpus = [
                w for ex in self.data['train'] for w in ex.sentence
            ]  # type: ignore
            self.word_vocab = Vocab.from_corpus(train_corpus,
                                                min_freq=min_freq,
                                                max_vocab=vocab_size,
                                                lowercase=False)
            torch.save(self.word_vocab, vocab_path)
            LOGGER.info(
                f"Vocabulary of size {len(self.word_vocab)} constructed, saved to {vocab_path}."
            )
        self.vocab = self.word_vocab

        if unk_probs_path is not None:
            unk_probs, total_w2i = prepare_unkprob(unk_probs_path,
                                                   self.word_vocab,
                                                   uniform_unk=use_upp)
            unk_probs = torch.tensor(unk_probs,
                                     dtype=torch.float,
                                     requires_grad=False)
            self.unk_probs = unk_probs
            unk_probs[1] = 0  # <pad>
            self.total_w2i = total_w2i

        self._extra_init(loaded_batches)

        if not loaded_batches and create_batches:
            self.create_batches(batch_size, bptt_size)
            if cache_batches and include_train:
                self._save_cache(path)
Example #12
    def __init__(self, alphabets_, list_ngram):

        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        INPUT_DIM = len(self.vocab)
        OUTPUT_DIM = len(self.vocab)

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)
Example #13
class Trainer():
    def __init__(self, alphabets_, list_ngram):

        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        INPUT_DIM = len(self.vocab)
        OUTPUT_DIM = len(self.vocab)

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)

        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)

        return gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch

        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab

        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)

        tgt_output = tgt.transpose(0, 1).reshape(
            -1)  # flatten()   # tgt: tgt_len xB , need convert to B x tgt_len

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def train(self):
        print("Begin training from iter: ", self.iter)
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        self.model.eval()

        total_loss = []
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                preds, actuals, inp_sents, probs = self.predict(5)

                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        inp_sents = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):

        pred_sents, actual_sents, _, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')

        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):

        pred_sents, actual_sents, img_files, probs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatched shape: required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):

        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)

        batch = {'src': src, 'tgt': tgt}

        return batch
Example #14
class Predictor:
    def __init__(self, weight_path, have_att=False):
        ENC_EMB_DIM = 256
        DEC_EMB_DIM = 256
        ENC_HID_DIM = 512
        DEC_HID_DIM = 512
        ENC_DROPOUT = 0.5
        DEC_DROPOUT = 0.5
        MAX_LEN = 46
        self.maxlen = MAX_LEN
        self.vocab = Vocab(alphabets)

        INPUT_DIM = len(self.vocab)
        OUTPUT_DIM = len(self.vocab)

        if have_att:
            self.model = Seq2Seq(input_dim=INPUT_DIM,
                                 output_dim=OUTPUT_DIM,
                                 encoder_embbeded=ENC_EMB_DIM,
                                 decoder_embedded=DEC_EMB_DIM,
                                 encoder_hidden=ENC_HID_DIM,
                                 decoder_hidden=DEC_HID_DIM,
                                 encoder_dropout=ENC_DROPOUT,
                                 decoder_dropout=DEC_DROPOUT)
        else:
            self.model = Seq2Seq_WithoutAtt(input_dim=INPUT_DIM,
                                            output_dim=OUTPUT_DIM,
                                            encoder_embbeded=ENC_EMB_DIM,
                                            decoder_embedded=DEC_EMB_DIM,
                                            encoder_hidden=ENC_HID_DIM,
                                            decoder_hidden=DEC_HID_DIM,
                                            encoder_dropout=ENC_DROPOUT,
                                            decoder_dropout=DEC_DROPOUT)

        self.load_weights(weight_path)
        if torch.cuda.is_available():
            self.device = "cuda"
            self.model.to('cuda')
        else:
            self.device = "cpu"

        print("Device: ", self.device)
        print("Loaded model")

    def predict_ngram(self, ngram, beamsearch=False):
        '''
          Denoise for ngram
          ngram: text
        '''
        src = self.preprocessing(ngram)
        src = src.unsqueeze(0)
        src = src.to(self.device)

        if beamsearch:
            translated_sentence = batch_translate_beam_search(src, self.model)
            prob = None
        else:
            translated_sentence, prob = translate(src, self.model)
        # print(translated_sentence)
        pred_sent = self.vocab.decode(translated_sentence.tolist()[0])

        return pred_sent

    def spelling_correct(self, sentence):
        # Remove characters that out of vocab
        sentence = re.sub(
            r'[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&'
            '()*+,-./:;<=>?@[\]^_`{|}~ ]', "", sentence)

        # Extract pharses
        phrases, phrases_all, index_sent_dict = self.extract_phrases(sentence)

        correct_phrases = []
        for phrase in phrases:
            ngrams = list(self.gen_ngrams(phrase, n=NGRAM))
            correct_ngram_str_array = []
            for ngram_list in ngrams:
                ngram_str = " ".join(ngram_list)

                correct_ngram_str = self.predict_ngram(ngram_str)
                correct_ngram_str_array.append(correct_ngram_str)
            correct_phrase = self.reconstruct_from_ngrams(
                correct_ngram_str_array)
            correct_phrases.append(correct_phrase)
        correct_sentence = self.decode_phrases(correct_phrases, phrases_all,
                                               index_sent_dict)
        return correct_sentence

    def reconstruct_from_ngrams(self, predicted_ngrams):
        '''
        predicted_ngrams: list of ngram_str
        '''

        candidates = [
            Counter() for _ in range(len(predicted_ngrams) + NGRAM - 1)
        ]
        for nid, ngram in (enumerate(predicted_ngrams)):
            tokens = re.split(r' +', ngram)
            for wid, word in enumerate(tokens):
                candidates[nid + wid].update([word])
        # print(candidates)
        output = ' '.join(
            c.most_common(1)[0][0] for c in candidates if len(c) != 0)
        return output

    def extract_phrases(self, text):
        pattern = r'\w[\w ]*|\s\W+|\W+'

        phrases_all = re.findall(pattern, text)

        index_sent_dict = {}
        phrases_str = []
        for ind, phrase in enumerate(phrases_all):
            if not re.match(r'[!"#$%&'
                            '()*+,-./:;<=>?@[\]^_`{|}~]', phrase.strip()):
                phrases_str.append(phrase.strip())
                index_sent_dict[ind] = phrase

        return phrases_str, phrases_all, index_sent_dict

    def decode_phrases(self, correct_phrases, phrases, index_sent_dict):
        # correct_phrases = ['lê văn', 'Hoàng', 'Hehe', 'g']
        sentence_correct = phrases.copy()
        for i, idx_sent in enumerate(index_sent_dict.keys()):
            sentence_correct[idx_sent] = correct_phrases[i]

        # print(sentence_correct)
        return "".join(sentence_correct)

    def preprocessing(self, sentence):

        # Encode characters
        noise_sent_idxs = self.vocab.encode(sentence)

        # Padding to MAXLEN
        src_len = len(noise_sent_idxs)
        if self.maxlen - src_len < 0:
            noise_sent_idxs = noise_sent_idxs[:self.maxlen]
            src_len = len(noise_sent_idxs)
            print("Over length in src")
        src = np.concatenate(
            (noise_sent_idxs, np.zeros(self.maxlen - src_len, dtype=np.int32)))

        return torch.LongTensor(src)

    def gen_ngrams(self, sent, n=5):
        tokens = sent.split()

        if len(tokens) < n:
            return [tokens]

        return nltk.ngrams(sent.split(), n)

    def load_weights(self, filename):
        state_dict = torch.load(filename, map_location=torch.device('cpu'))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatched shape: required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)
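
The reconstruct_from_ngrams method above merges the overlapping corrected n-grams by per-position majority vote: the n-gram at index nid covers word positions nid .. nid+NGRAM-1, and each position keeps its most frequent candidate word. A standalone sketch of the same logic, with n fixed to 3 and a made-up phrase:

import re
from collections import Counter

def reconstruct_from_ngrams_demo(predicted_ngrams, n=3):
    # Same voting scheme as the method above, with NGRAM fixed to 3 for illustration.
    candidates = [Counter() for _ in range(len(predicted_ngrams) + n - 1)]
    for nid, ngram in enumerate(predicted_ngrams):
        for wid, word in enumerate(re.split(r' +', ngram)):
            candidates[nid + wid].update([word])
    return ' '.join(c.most_common(1)[0][0] for c in candidates if len(c) != 0)

print(reconstruct_from_ngrams_demo(['tôi là sinh', 'là sinh viên', 'sinh viên mới']))
# -> 'tôi là sinh viên mới'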
Example #15
model = TreeLSTM(2, 2)

# criterion = nn.CrossEntropyLoss()
input = torch.from_numpy(input)
graph = torch.from_numpy(graph)
# print(input.size())
# print(graph.size())
input = torch.unsqueeze(input, 0).repeat(1, 1, 1)
graph = torch.unsqueeze(graph, 0).repeat(1, 1, 1)
input, graph = Variable(input), Variable(graph)

global args
args = parse_args()
vocab_file = os.path.join(args.dtree, 'snli_vocab_cased.txt')
vocab = Vocab(filename=vocab_file)

# args.cuda = args.cuda and torch.cuda.is_available()
# device = torch.device("cuda:0" if args.cuda else "cpu")

l_dev_file = os.path.join(args.dtree, args.premise_dev)
r_dev_file = os.path.join(args.dtree, args.hypothesis_dev)
label_dev_file = os.path.join(args.dtree, args.label_dev)

l_dev_squence_file = os.path.join(args.ctree, args.premise_dev)
r_dev_squence_file = os.path.join(args.ctree, args.hypothesis_dev)

l_test_file = os.path.join(args.dtree, args.premise_test)
r_test_file = os.path.join(args.dtree, args.hypothesis_test)
label_test_file = os.path.join(args.dtree, args.label_test)
Example #16
                        action='store_const',
                        const='cuda',
                        default='cpu',
                        help=msg)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # result save path
    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    # load vocab
    word_vocab = Vocab(args.word_vocab)
    word_vocab.unk_id = word_vocab.toID(UNK)
    label_vocab = Vocab(args.label_vocab)
    config.WORD_PAD_ID = word_vocab.toID(PAD)
    config.WORD_UNK_ID = word_vocab.toID(UNK)
    config.LABEL_PAD_ID = label_vocab.toID(PAD)
    pred_id = [label_vocab.toID('B-v')]

    device = torch.device(args.cuda)

    # load data
    dataset = DataReader(args.word, args.label, word_vocab, label_vocab)
    dataLoader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            pin_memory=True,
Example #17
        ys = [v[1] for v in batch]
        # get the sequence length of each sample
        lengths = [len(v) for v in xs]
        seq_lengths = torch.IntTensor(lengths)
        max_len = max(lengths)
        # pad every sample to the max length of the current batch
        xs = torch.LongTensor([pad_tensor(v, max_len, self.word_pad) for v in xs])
        ys = torch.LongTensor([pad_tensor(v, max_len, self.label_pad) for v in ys])
        # sort xs and ys by sequence length in descending order
        if self.sort:
            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            xs = xs[perm_idx]
            ys = ys[perm_idx]
        # predicate mask. set to 1 when the label is 'B-v', else 0
        preds = torch.LongTensor([[1 if id in self.pred_id else 0 for id in v] for v in ys])
        return xs, preds, ys, seq_lengths

    def __call__(self, batch):
        return self._collate(batch)


if __name__ == '__main__':
    word_vocab = Vocab('../../data/train/word_vocab.txt', 1)
    label_vocab = Vocab('../../data/train/label_vocab.txt')
    pred_id = [label_vocab.toID('B-v'), label_vocab.toID('I-v')]

    dataset = DataReader('../../data/train/word.txt', '../../data/train/label.txt', word_vocab, label_vocab)
    dataLoader = DataLoader(dataset=dataset, batch_size=32, num_workers=4, collate_fn=Collate(pred_id, 0, 0))
    for xs, preds, ys, lengths in dataLoader:
        print(ys)
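
pad_tensor is referenced in _collate above but not defined in this snippet; a minimal sketch of such a helper, assuming it right-pads a sequence of ids to max_len with the given pad id (the project's actual helper may differ):

def pad_tensor(ids, max_len, pad_id):
    # Right-pad a sequence of ids (list or 1-D tensor) to max_len with pad_id.
    ids = [int(i) for i in ids]
    return ids + [pad_id] * (max_len - len(ids))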