예제 #1
0
    def _get_iterator(self, path):
        """
        Build a batch iterator over the contents of a text file.

        Args:
            path(str): location of the text file to read
        Returns:
            inputters.OrderedIterator: iterator over the dataset built
                from ``path``
        """
        # Source-only text dataset: no target file, filter predicate disabled.
        dataset_kwargs = dict(fields=self.fields,
                              data_type='text',
                              src_path=path,
                              tgt_path=None,
                              src_dir='',
                              use_filter_pred=False)
        text_dataset = inputters.build_dataset(**dataset_kwargs)

        # Non-training iteration: shuffling and global sorting are switched
        # off, only the sort_within_batch flag is enabled.
        return inputters.OrderedIterator(
            dataset=text_dataset,
            device=self.gpu,
            batch_size=self.similar_pairs.batch_size,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)
예제 #2
0
def build_save_dataset(corpus_type, fields, opt):
    """ Build the dataset of one corpus split and serialize it to disk.

    Args:
        corpus_type (str): either 'train' or 'valid'
        fields: field definitions handed to ``inputters.build_dataset``
        opt: options object providing the directories, length limits
            and the ``save_data`` prefix

    Returns:
        str: path of the written ``.pt`` file
    """
    assert corpus_type in ['train', 'valid']

    # Pick the input directory that matches the requested split.
    corpus = opt.train_dir if corpus_type == 'train' else opt.valid_dir

    build_kwargs = dict(
        data_path=corpus,
        data_type=opt.data_type,
        total_token_length=opt.total_token_length,
        src_seq_length=opt.src_seq_length,
        src_sent_length=opt.src_sent_length,
        seq_length_trunc=opt.seq_length_trunc)
    dataset = inputters.build_dataset(fields, **build_kwargs)

    # Fields are stored separately in vocab.pt, so drop them before saving.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return pt_file
예제 #3
0
def build_save_dataset(corpus_type, fields, src_corpus, tgt_corpus, savepath,
                       args):
    """ Build a text dataset, attach a graph to every example and pickle it.

    Args:
        corpus_type (str): one of 'train', 'dev' or 'test'
        fields: field definitions handed to ``inputters.build_dataset``
        src_corpus: source file path
        tgt_corpus: target file path
        savepath (str): directory that receives ``<corpus_type>.pt``
        args: options object providing ``max_src_len`` / ``max_tgt_len``

    Returns:
        list: single-element list holding the path of the pickled dataset
    """
    assert corpus_type in ['train', 'dev', 'test']

    build_kwargs = dict(data_type='text',
                        src_path=src_corpus,
                        tgt_path=tgt_corpus,
                        src_dir='',
                        src_seq_length=args.max_src_len,
                        tgt_seq_length=args.max_tgt_len,
                        src_seq_length_trunc=0,
                        tgt_seq_length_trunc=0,
                        dynamic_dict=True)
    dataset = inputters.build_dataset(fields, **build_kwargs)

    # Fields are stored separately in vocab.pt, so drop them before saving.
    dataset.fields = []

    for idx in range(len(dataset)):
        # Progress indicator every 500 examples.
        if idx % 500 == 0:
            print(idx)
        example = dataset.examples[idx]
        example.graph = myutils.str2graph(example.src)

    pt_file = "{:s}/{:s}.pt".format(savepath, corpus_type)
    # The dataset is written with pickle rather than torch.save.
    with open(pt_file, 'wb') as out_handle:
        pickle.dump(dataset, out_handle)
    return [pt_file]
예제 #4
0
    def generate_vectors(self, list_of_sentences, batch_size=1, cuda=False):
        """
        Encode sentences and return one vector per sentence.

        The sentences are written to a temporary file under
        ``<self.temp_dir>/l2e``, loaded as a text dataset and run through the
        encoder; the last-layer slice of the first final encoder state is
        returned for every example.

        Args:
            list_of_sentences ([str]): sentences to encode
            batch_size (int): number of sentences per encoder batch
            cuda (bool): run the iterator on "cuda" instead of "cpu"

        Returns:
            np.ndarray: the per-batch vectors stacked row-wise.

        NOTE(review): the iterator uses sort_within_batch=True, so for
        batch_size > 1 the rows inside a batch may not follow the input
        order -- confirm, and restore the order if callers rely on it.
        """
        scratch_dir = pjoin(self.temp_dir, "l2e")

        # Wipe the scratch directory once more than 10 files piled up.
        tmp_files = os.listdir(scratch_dir)
        if len(tmp_files) > 10:
            for f_n in tmp_files:
                os.remove(pjoin(scratch_dir, f_n))

        # Build the temp path once and reuse it for writing and reading.
        src_file = pjoin(scratch_dir, '{}.txt'.format(uuid.uuid4()))
        with open(src_file, 'w') as f:
            for s in list_of_sentences:
                f.write(s.strip() + '\n')

        data = inputters.build_dataset(
            self.fields,
            src_path=src_file,
            data_type='text',
            use_filter_pred=False)

        cur_device = "cuda" if cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        all_vecs = []

        for batch in data_iter:
            # translation model just translates, here we generate vectors
            _, enc_states, _, _ = _run_encoder(self.model, batch, 'text')
            # enc_states[0]: (layer_size, batch_size, hid)
            hidden = enc_states[0]
            # Move the batch axis to the front *before* reshaping: a bare
            # reshape would interleave values of different examples whenever
            # the batch holds more than one sentence.  The real batch size is
            # read from the tensor because the last batch may be smaller than
            # the requested batch_size.
            per_example = hidden.transpose(0, 1).reshape(
                hidden.size(1), self.model_opt.enc_layers,
                self.model_opt.enc_rnn_size)
            # per_example: (batch_size, layer_size, hid); keep the last layer
            all_vecs.append(per_example[:, -1, :].data.cpu().numpy())

        all_vecs = np.vstack(all_vecs)

        return all_vecs
예제 #5
0
def build_save_dataset(corpus_type, fields, opt):
    """Build and save one dataset per temporary source/target shard pair.

    Args:
        corpus_type (str): 'train', 'valid' or 'comp'
        fields: field definitions handed to ``inputters.build_dataset``
        opt: options object providing corpus paths, shard size, length
            limits and the ``save_data`` prefix

    Returns:
        list: paths of the saved shard files, in shard order
    """
    assert corpus_type in ['train', 'valid', 'comp']

    if corpus_type == 'train':
        src, tgt = opt.train_src, opt.train_tgt
    elif corpus_type == 'valid':
        src, tgt = opt.valid_src, opt.valid_tgt
    else:
        src, tgt = opt.comp_train_src, opt.comp_train_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))
    src_len = _write_temp_shard_files(src, fields, corpus_type, opt.shard_size)
    tgt_len = _write_temp_shard_files(tgt, fields, corpus_type, opt.shard_size)
    assert src_len == tgt_len, "Source and target should be the same length"

    # Pair up the temporary shard files by their sorted names.
    src_shards = sorted(glob.glob(src + '.*.txt'))
    tgt_shards = sorted(glob.glob(tgt + '.*.txt'))

    dataset_paths = []
    for shard_no, (src_shard, tgt_shard) in enumerate(
            zip(src_shards, tgt_shards)):
        logger.info("Building shard %d." % shard_no)
        build_kwargs = dict(
            src_path=src_shard,
            tgt_path=tgt_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)
        dataset = inputters.build_dataset(fields, opt.data_type,
                                          **build_kwargs)

        data_path = "{:s}.{:s}.{:d}.pt".format(
            opt.save_data, corpus_type, shard_no)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (shard_no, corpus_type, data_path))
        dataset.save(data_path)

        # The temporary shard files are no longer needed once saved.
        os.remove(src_shard)
        os.remove(tgt_shard)

        # Drop the examples and the dataset before the next shard is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
예제 #6
0
def build_save_dataset(corpus_type, fields, opt):
    """ Build the dataset(s) of one corpus split and write them to disk.

    'train' and 'valid' each use a single source/target pair, while
    'monitor' walks every pair listed in ``opt.monitor_src`` and
    ``opt.monitor_tgt``.

    Returns:
        list: paths of every saved ``.pt`` file
    """
    assert corpus_type in ['train', 'valid', 'monitor']

    if corpus_type == 'train':
        src_corpus, tgt_corpus = [opt.train_src], [opt.train_tgt]
    elif corpus_type == 'valid':
        src_corpus, tgt_corpus = [opt.valid_src], [opt.valid_tgt]
    else:
        assert len(opt.monitor_src) == len(opt.monitor_tgt)
        src_corpus, tgt_corpus = opt.monitor_src, opt.monitor_tgt

    pt_files = []
    for src, tgt in zip(src_corpus, tgt_corpus):
        if "monitor" in corpus_type:
            # Derive a per-file corpus name from the source file name:
            # last path component, text before the first dot, "_src" removed.
            separator = "/" if "/" in src else "\\"
            basename = src.split(separator)[-1]
            stem = basename.split(".")[0]
            corpus_type = "monitor_{}".format(stem.replace("_src", ""))

        if opt.shard_size > 0:
            pt_files.extend(build_save_in_shards_using_shards_size(
                src, tgt, fields, corpus_type, opt))
            continue

        # Without sharding a single monolithic dataset is built (the case
        # currently used for data_type 'img' or 'audio').
        build_kwargs = dict(
            src_path=src,
            tgt_path=tgt,
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)
        dataset = inputters.build_dataset(fields, opt.data_type,
                                          **build_kwargs)

        # Fields are stored separately in vocab.pt, so drop them here.
        dataset.fields = []

        pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
        logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
        torch.save(dataset, pt_file)
        pt_files.append(pt_file)

    return pt_files
예제 #7
0
def build_save_dataset(corpus_type, fields, opt):
    """Build and save sharded datasets from source, target and answer files.

    Args:
        corpus_type (str): 'train' or 'valid'
        fields: field definitions handed to ``inputters.build_dataset``
        opt: options object providing corpus paths, shard size, length
            limits and the ``save_data`` prefix

    Returns:
        list: paths of the saved shard files, in shard order
    """
    assert corpus_type in ['train', 'valid']

    is_train = corpus_type == 'train'
    if is_train:
        src, tgt, ans = opt.train_src, opt.train_tgt, opt.train_ans
    else:
        src, tgt, ans = opt.valid_src, opt.valid_tgt, opt.valid_ans

    logger.info("Reading source answer and target files: %s %s %s." %
                (src, ans, tgt))

    shard_triples = zip(split_corpus(src, opt.shard_size),
                        split_corpus(tgt, opt.shard_size),
                        split_corpus(ans, opt.shard_size))

    dataset_paths = []
    for shard_no, (src_shard, tgt_shard, ans_shard) in enumerate(
            shard_triples):
        # All three sides of a shard must hold the same number of entries.
        assert len(src_shard) == len(tgt_shard) == len(ans_shard)
        logger.info("Building shard %d." % shard_no)

        build_kwargs = dict(
            src=src_shard,
            tgt=tgt_shard,
            ans=ans_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            ans_seq_len=opt.ans_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            # Always on for training; for validation only on request.
            use_filter_pred=is_train or opt.filter_valid)
        dataset = inputters.build_dataset(fields, opt.data_type,
                                          **build_kwargs)

        data_path = "{:s}.{:s}.{:d}.pt".format(
            opt.save_data, corpus_type, shard_no)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s." %
                    (shard_no, corpus_type, data_path))
        dataset.save(data_path)

        # Drop the examples and the dataset before the next shard is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
예제 #8
0
def build_save_dataset(corpus_type, fields, opt):
    """Build and save sharded datasets, logging the example counts.

    Args:
        corpus_type (str): 'train' or 'valid'
        fields: field definitions handed to ``inputters.build_dataset``
        opt: options object providing corpus paths, shard size, length
            limits and the ``save_data`` prefix

    Returns:
        list: paths of the saved shard files, in shard order
    """
    assert corpus_type in ['train', 'valid']

    is_train = corpus_type == 'train'
    if is_train:
        src, tgt = opt.train_src, opt.train_tgt
    else:
        src, tgt = opt.valid_src, opt.valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    shard_pairs = zip(split_corpus(src, opt.shard_size),
                      split_corpus(tgt, opt.shard_size))

    dataset_paths = []
    total_valid_ex_num = 0
    for shard_no, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # Both sides of a shard must hold the same number of entries.
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % shard_no)

        build_kwargs = dict(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size,
            # Always on for training; for validation only on request.
            use_filter_pred=is_train or opt.filter_valid,
            src_seq_min_length=opt.src_seq_min_length,
            tgt_seq_min_length=opt.tgt_seq_min_length)
        dataset = inputters.build_dataset(fields, opt.data_type,
                                          **build_kwargs)

        data_path = "{:s}.{:s}.{:d}.pt".format(
            opt.save_data, corpus_type, shard_no)
        dataset_paths.append(data_path)

        shard_examples = len(dataset.examples)
        logger.info(" * saving %sth %s data shard to %s. Example number: %d" %
                    (shard_no, corpus_type, data_path, shard_examples))
        total_valid_ex_num += shard_examples
        dataset.save(data_path)

        # Drop the examples and the dataset before the next shard is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    logger.info(" * Total Example number: %d" % (total_valid_ex_num))
    return dataset_paths
예제 #9
0
def build_save_dataset(corpus_type, fields, opt):
    """ Build and save the dataset of one split, including reference files.

    When ``opt.shard_size`` is positive the work is delegated to
    ``build_save_in_shards_using_shards_size`` and its result is returned;
    otherwise one monolithic dataset is built and saved.

    Returns:
        The delegate's return value, or a single-element list holding the
        path of the saved ``.pt`` file.
    """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src_corpus, tgt_corpus = opt.train_src, opt.train_tgt
        src_ref_corpus, tgt_ref_corpus = opt.train_ref_src, opt.train_ref_tgt
    else:
        src_corpus, tgt_corpus = opt.valid_src, opt.valid_tgt
        src_ref_corpus, tgt_ref_corpus = opt.valid_ref_src, opt.valid_ref_tgt

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, src_ref_corpus, tgt_ref_corpus,
            fields, corpus_type, opt)

    # Without sharding a single monolithic dataset is built (the case
    # currently used for data_type 'img' or 'audio').
    build_kwargs = dict(
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_ref_path=src_ref_corpus,
        tgt_ref_path=tgt_ref_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)
    dataset = inputters.build_dataset(fields, opt.data_type, **build_kwargs)

    # Fields are stored separately in vocab.pt, so drop them before saving.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
예제 #10
0
    def get_encodings(self,
                      src_path=None,
                      src_data_iter=None,
                      tgt_path=None,
                      tgt_data_iter=None,
                      src_dir=None,
                      batch_size=None,
                      attn_debug=False):
        """
        Run ``self.Encode`` over every batch of the given source data.

        At least one of ``src_path`` / ``src_data_iter`` must be provided and
        ``batch_size`` must be set.  Each batch result is printed as it is
        produced.  ``attn_debug`` is accepted but not used here.

        Returns:
            list: one ``self.Encode`` result per batch, in iteration order

        Raises:
            ValueError: if ``batch_size`` is None
        """
        assert not (src_data_iter is None and src_path is None)

        if batch_size is None:
            raise ValueError("batch_size must be set")

        build_kwargs = dict(
            src_path=src_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred)
        data = inputters.build_dataset(self.fields, self.data_type,
                                       **build_kwargs)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        all_encodings = []
        for batch in data_iter:
            encoded = self.Encode(batch, data)
            all_encodings.append(encoded)
            print(encoded)
        return all_encodings
예제 #11
0
def build_save_vectors(src_corpus, tgt_corpus, fields, corpus_type, opt):
    """
    Build one dataset from src_corpus and tgt_corpus and save it as
    shard number 0.

    The whole corpus goes into a single dataset; the file is named
    ``<opt.save_data>.<corpus_type>.0.pt``.

    Returns:
        list: single-element list holding the path of the saved file
    """
    build_kwargs = dict(
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)
    dataset = inputters.build_dataset(fields, opt.data_type, **build_kwargs)

    shard_no = 0
    pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, shard_no)

    # Fields are stored separately in vocab.pt, so drop them before saving.
    dataset.fields = []

    logger.info(" * saving %sth %s data shard to %s." %
                (shard_no, corpus_type, pt_file))
    torch.save(dataset, pt_file)

    # Drop the examples and the dataset once they are on disk.
    del dataset.examples
    gc.collect()
    del dataset
    gc.collect()

    return [pt_file]
예제 #12
0
def run_one(param):
    """Build, post-process and pickle the dataset for one shard.

    Args:
        param (tuple): ``(index, src, opt, fields, tgt_list,
            condition_corpus, corpus_type)`` where ``src`` and
            ``tgt_list[index]`` are the shard's source/target files and
            ``condition_corpus`` is an optional path to a file with one
            integer condition per line (falsy to skip conditions).

    Returns:
        str: path of the pickled dataset.

    Side effects:
        Removes the ``src`` and ``tgt_list[index]`` files after saving.
    """
    index, src, opt, fields, tgt_list, condition_corpus, corpus_type = param
    dataset = inputters.build_dataset(
        fields,
        opt.data_type,
        src_path=src,
        tgt_path=tgt_list[index],
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    if condition_corpus:
        # Load the conditions, one integer per line.
        with open(condition_corpus) as f:
            target_condition = [int(s.rstrip()) for s in f]
    else:
        # No condition file: use placeholders so the filtering loop below
        # still visits every example (previously this path hit a NameError
        # because target_condition was never bound).
        target_condition = [None] * len(dataset.examples)

    # Called for its effect on each example; the return values are unused.
    for e in dataset.examples:
        parrel_func(e, opt.with_3d_confomer)

    # Keep only the examples whose 'graph' attribute is set.
    # NOTE(review): zip stops at the shorter input, so a condition file with
    # fewer lines than examples silently drops the trailing examples --
    # confirm this is intended.
    tmp_example = []
    for cond, result in zip(target_condition, dataset.examples):
        if getattr(result, 'graph') is not None:
            if condition_corpus:
                setattr(result, 'condition_target', cond)
            tmp_example.append(result)

    dataset.examples = tmp_example

    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)

    os.remove(src)
    os.remove(tgt_list[index])
    return pt_file
예제 #13
0
    def encode_seq(self, src, tgt=None, src_dir=None, batch_size=None):
        """
        Encode ``src`` batch by batch and write sentence vectors to
        ``self.outfile``.

        For every batch the encoder memory bank is averaged over its first
        dimension and the result is passed to ``np.savetxt``.  Progress is
        printed every 10 and every 100 batches.  Nothing is returned.

        Raises:
            ValueError: if ``batch_size`` is None
        """
        assert src is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src=src,
                                       tgt=tgt,
                                       src_dir=src_dir)

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        with torch.no_grad():
            for batch_no, batch in enumerate(data_iter, 1):
                # From here on the size of the current batch is used.
                batch_size = batch.batch_size

                # Encoder forward pass; only the memory bank is needed.
                _, _, memory_bank, _ = self._run_encoder(
                    batch, data.data_type)
                # memory_bank (seq_lengths, batch_size, hidden_size)
                averaged = memory_bank.mean(dim=0)
                sent_vec_batch = averaged.cpu().numpy()
                np.savetxt(self.outfile, sent_vec_batch, fmt='%.10e')

                # Progress: a dot every 10 batches, a count every 100.
                if batch_no % 10 == 0:
                    print(".", end="", flush=True)
                if batch_no % 100 == 0:
                    print(batch_no * batch_size, end="", flush=True)
예제 #14
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  src_length=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  search_mode=0,
                  threshold=0,
                  ref_path=None):
        """
        Translate the source data and write the predictions to
        ``self.out_file``.

        ``search_mode`` selects the output strategy:

        * 2 -- write the results of ``self.search`` only and return None;
        * 1 -- translate, then for each entry write the ``self.search``
          result when it is truthy and the translation otherwise;
        * anything else -- write the n-best translations.

        Returns:
            ``(all_scores, all_predictions)`` unless ``search_mode == 2``.

        Raises:
            ValueError: if ``batch_size`` is None
        """
        assert not (src_data_iter is None and src_path is None)
        if batch_size is None:
            raise ValueError("batch_size must be set")

        # One reference file per index: "<ref_path>.<r>".
        if self.refer:
            ref_paths = ['%s.%d' % (ref_path, r) for r in range(self.refer)]
        else:
            ref_paths = None

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src_path=src_path,
            src_data_iter=src_data_iter,
            src_seq_length_trunc=src_length,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            ref_path=ref_paths,
            ref_seq_length_trunc=self.max_sent_length,
            ignore_unk=False)

        cur_device = "cuda" if self.cuda else "cpu"

        # Every reference field shares the source vocabulary.
        if self.refer:
            src_vocab_field = data.fields['src']
            for ref_no in range(self.refer):
                data.fields['ref%d' % ref_no].vocab = src_vocab_field.vocab

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        if search_mode == 2:
            all_predictions = self.search(data_iter,
                                          data,
                                          src_path,
                                          train=False,
                                          threshold=threshold)
            for found in all_predictions:
                self.out_file.write(found)
                self.out_file.flush()
            return

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch,
                                              data,
                                              fast=True,
                                              attn_debug=False)
            for trans in builder.from_batch(batch_data):
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                top_sents = trans.pred_sents[:self.n_best]
                all_predictions.append([" ".join(pred) for pred in top_sents])

        if search_mode == 1:
            # NOTE(review): threshold is passed positionally here, whereas
            # the search_mode == 2 call uses train=False, threshold=... --
            # confirm against the signature of self.search.
            sim_predictions = self.search(data_iter, data, src_path, threshold)
            for idx in range(len(sim_predictions)):
                hit = sim_predictions[idx]
                if hit:
                    self.out_file.write(hit)
                else:
                    self.out_file.write('\n'.join(all_predictions[idx]) + '\n')
                self.out_file.flush()
        else:
            for n_best_preds in all_predictions:
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()
        return all_scores, all_predictions
예제 #15
0
    def index_documents(
        self,
        src_path=None,
        src_data_iter=None,
        tgt_path=None,
        tgt_data_iter=None,
        src_dir=None,
        batch_size=None,
    ):
        """
        Encode the source documents and save their features in shards.

        Each batch is run through ``self.model.encoder``; the memory bank is
        max-reduced over its first dimension and the rows are reordered with
        the sort permutation of ``batch.indices``.  After every 1250 batches
        the collected features are concatenated and written with
        ``torch.save`` to ``<first two '/' parts of src_path>/indexes/
        codev<shard>.pt``; leftover batches go into a final shard.
        """
        def _save_shard(features, shard_no):
            # The target path is derived from src_path at save time.
            torch.save(
                torch.cat(features), '{}/indexes/codev{}.pt'.format(
                    '/'.join(src_path.split('/')[:2]), shard_no))

        build_kwargs = dict(
            src_path=src_path,
            src_data_iter=src_data_iter,
            src_seq_length_trunc=self.max_sent_length,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            ignore_unk=True)
        data = inputters.build_dataset(self.fields, self.data_type,
                                       **build_kwargs)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        doc_feats = []
        shard = 1
        for batch in data_iter:
            # Encoder forward.
            src = inputters.make_features(batch, 'src', data.data_type)
            src_lengths = batch.src[1]
            _, memory_bank, _ = self.model.encoder(src, src_lengths)

            # Max over dimension 0 of the memory bank.
            pooled = torch.max(memory_bank, 0)[0]
            # Reorder rows by ascending batch.indices (presumably restoring
            # the dataset order after in-batch sorting -- confirm).
            recover_indices = torch.sort(batch.indices, descending=False)[1]
            doc_feats.append(pooled[recover_indices])

            if len(doc_feats) % 1250 == 0:
                print('saving shard %d' % shard)
                _save_shard(doc_feats, shard)
                doc_feats = []
                shard += 1

        # Flush whatever is left after the last full shard.
        if doc_feats:
            _save_shard(doc_feats, shard)
            print('done.')
예제 #16
0
    def translate(
        self,
        src,
        tgt=None,
        src_dir=None,
        batch_size=None,
        attn_debug=False,
        data_iter=None
    ):
        """
        Translate the first batch of `data_iter` and return its translations.

        A dataset is built from `src` / `tgt` and used to construct the
        `TranslationBuilder`, but batches are taken from the caller-supplied
        `data_iter` (the code that would build an iterator from the dataset
        is commented out).

        Note: `src` must not be None (checked with an assert)
        Note: the method returns inside the first loop iteration, so the
            per-translation output, scoring and attention-debug code below
            the `return` never runs.

        Args:
            src: source data handed to `inputters.build_dataset`
            tgt: target data or None
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): must not be None; otherwise unused here since
                the iterator construction is commented out
            attn_debug (bool): forwarded to `translate_batch`
            data_iter: iterable of batches; each batch is passed through
                `bbbb` before translation. The default None is not iterable
                and makes the loop raise a TypeError.

        Returns:
            The result of `builder.from_batch` for the first batch. If
            `data_iter` yields no batch, the reporting code at the end runs
            and None is returned implicitly.

        Raises:
            ValueError: if `batch_size` is None
        """
        assert src is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src=src,
            src_reader=self.src_reader,
            tgt=tgt,
            tgt_reader=self.tgt_reader,
            src_dir=src_dir,
            use_filter_pred=self.use_filter_pred, bert=self.opt.bert, morph=self.opt.korean_morphs
        )

        # Only referenced by the commented-out iterator construction below.
        cur_device = "cuda" if self.cuda else "cpu"

        # Disabled: the iterator is now supplied by the caller via data_iter.
        # data_iter = inputters.OrderedIterator(
        #     dataset=data,
        #     device=cur_device,
        #     batch_size=batch_size,
        #     train=False,
        #     sort=False,
        #     sort_within_batch=True,
        #     shuffle=False
        # )
        builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch_ in data_iter:
            # bbbb is defined outside this block; NOTE(review): presumably it
            # adapts the raw batch for translate_batch -- confirm.
            batch = bbbb(batch_)
            # Disabled variant that used the locally built dataset's vocabs:
            # batch_data = self.translate_batch(
            #     batch, data.src_vocabs, attn_debug, fast=self.fast
            # )
            # Source vocabs are taken from the batch's own dataset instead.
            batch_data = self.translate_batch(
                batch, batch.dataset.src_vocabs, attn_debug, fast=self.fast
            )
            # batch_data = self.translate_batch(
            #     batch, data.src_vocabs, attn_debug, fast=self.fast
            # )
            translations = builder.from_batch(batch_data)
            # Early exit: only the first batch is ever translated.
            return translations
            # Everything below in this loop body is unreachable because of
            # the return above.
            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))
                        print(list(trans.attns[0].max(1)[1].cpu().detach().numpy()))
                if attn_debug:
                    # Print an attention table: one row per predicted word,
                    # one column per source token.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Mark the column holding the row maximum with '*'
                        # padding: star the first max_index + 1 cells, then
                        # un-star the first max_index of them.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        # Reset the row format for the next word.
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        # Reached only when data_iter yielded no batch at all.
        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" % (
                total_time / len(all_predictions)))
            self._log("Tokens per second: %f" % (
                pred_words_total / total_time))

        if self.dump_beam:
            import json
            json.dump(self.translator.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))
        return all_scores, all_predictions
예제 #17
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  intervention=None,
                  out_file=None):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`,
        optionally collecting the per-layer outputs returned by
        `translate_batch`.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
        Note: gold statistics are only accumulated when `tgt_path` is set

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            intervention: forwarded unchanged to `self.translate_batch`
            out_file (file-like): destination for the predictions; when None
                the predictions are written to `self.out_file`

        Returns:
            (`list`, `list`), or (`list`, `list`, `list`) when
            `self.dump_layers == -1`

            * all_scores holds, per translated sentence, the `n_best` scores
            * all_predictions holds, per translated sentence, the `n_best`
                predictions joined into strings
            * when `self.dump_layers == -1` the accumulated dumped layers
                are returned as an additional first element

        Raises:
            ValueError: if `batch_size` is None
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Report through the configured logger, falling back to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields,
            self.n_best, self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []
        all_dumped_layers = []

        for batch in data_iter:
            # NOTE(review): any value other than '' (including -1) enables
            # layer dumping here; the final save/return below distinguishes
            # -1 from a path.
            if self.dump_layers != '':
                batch_data, dumped_layers = self.translate_batch(
                    batch, data, intervention=intervention)

                # Get the correct order of sentences so that we can dump in
                # the same order as input occurred.
                _, perm = torch.sort(batch_data['batch'].indices.data)

                # At this point dumped_layers is going to be an array of
                # (num_layers) packed sequences, each of which has
                # (len) x (batch) shape. We would like to transpose this, so
                # that we have an array of "sentences", each of which is
                # an array of "tokens", each of which is an array of "layers",
                # each of which is an array of "neurons".
                # Each unpacked entry is a (tensor, lengths) tuple.
                unpacked = [unpack(layer) for layer in dumped_layers]
                sentences = [
                    [
                        # Array of layers
                        [layer_out[0][t][idx] for layer_out in unpacked]

                        # Array of tokens; unpacked[0][1] is the list of
                        # sentence lengths for the batch, so we can look up
                        # number of tokens here
                        for t in range(unpacked[0][1][idx])
                    ]
                    # Array of sentences
                    for idx in perm
                ]

                # Accumulate all the dumped layers into one big list of
                # sentences.
                all_dumped_layers.extend(sentences)
            else:
                batch_data = self.translate_batch(
                    batch, data, intervention=intervention)

            translations = builder.from_batch(batch_data)

            for trans in translations:
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                all_predictions.append(n_best_preds)

                # An explicit out_file takes precedence over self.out_file.
                sink = self.out_file if out_file is None else out_file
                sink.write('\n'.join(n_best_preds) + '\n')
                sink.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    # Work on a copy so the translation's own token list is
                    # not modified by the debug output.
                    preds = list(trans.pred_sents[0]) + ['</s>']
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    plain_row = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star-pad only the column holding the row maximum:
                        # mark the first max_index + 1 cells, then restore
                        # the first max_index of them.
                        row_format = plain_row.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt_path is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            beam_accum = self.translator.beam_accum
            # Close the dump file deterministically so the JSON is flushed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(beam_accum, beam_file)

        if self.dump_layers and self.dump_layers != -1:
            torch.save(all_dumped_layers, self.dump_layers)

        elif self.dump_layers == -1:
            return all_dumped_layers, all_scores, all_predictions

        return all_scores, all_predictions
예제 #18
0
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into smaller multiples
    src_corpus and tgt_corpus files, then build shards, each
    shard will have opt.shard_size samples except last shard.

    The reason we do this is to avoid taking up too much memory due
    to sucking in a huge corpus file.

    Args:
        src_corpus (str): path of the source corpus text file
        tgt_corpus (str): path of the target corpus text file
        fields: passed through to `inputters.build_dataset`
        corpus_type (str): label used in the saved file names and log lines
        opt: options object; `shard_size`, `save_data` and the dataset
            options forwarded to `inputters.build_dataset` are read from it

    Returns:
        list of str: paths of the saved `.pt` shard files, in shard order
    """

    def _write_shard(prefix, shard_index, lines):
        # Write one text shard next to its corpus and return its path.
        path = prefix + ".{0}.txt".format(shard_index)
        with open(path, "w", encoding="utf-8") as shard_file:
            shard_file.writelines(lines)
        return path

    with open(src_corpus, "r", encoding="utf-8") as src_file:
        src_data = src_file.readlines()
    with open(tgt_corpus, "r", encoding="utf-8") as tgt_file:
        tgt_data = tgt_file.readlines()

    shard_size = opt.shard_size

    # Shard files are named after the full corpus path (extension included),
    # so distinct corpora such as "train.src" / "train.tgt" can never map to
    # the same shard file and dots inside directory names are preserved.
    # Stepping through the source by shard_size also emits the trailing
    # partial shard, so no examples are dropped. The paths are recorded as
    # they are written, which keeps source/target shards paired in numeric
    # order and ignores unrelated files already on disk.
    src_list = []
    tgt_list = []
    for shard_index, start in enumerate(
            range(0, len(src_data), shard_size)):
        stop = start + shard_size
        src_list.append(
            _write_shard(src_corpus, shard_index, src_data[start:stop]))
        tgt_list.append(
            _write_shard(tgt_corpus, shard_index, tgt_data[start:stop]))

    # The raw lines are no longer needed once the shards are on disk.
    del src_data, tgt_data

    ret_list = []

    for index, (src, tgt) in enumerate(zip(src_list, tgt_list)):
        dataset = inputters.build_dataset(
            fields,
            opt.data_type,
            src_path=src,
            tgt_path=tgt,
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)

        pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * saving %sth %s data image shard to %s." %
                    (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)

        ret_list.append(pt_file)

        # Drop the examples and the dataset eagerly to keep peak memory low
        # before the next shard is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
예제 #19
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate `src` and write predictions, scores and an overall
        per-token score / perplexity to `self.out_file` as JSON.

        For every batch the beam search is run `self.beam_iters` times;
        `self.prev_hyps` is reset at the start of each batch.

        Note: batch_size must not be None
        Note: src must not be None
        Note: passing `tgt` currently raises ValueError as soon as a
            translation is produced

        Args:
            src: source data, forwarded to `inputters.build_dataset`
            tgt: target data or None, forwarded to
                `inputters.build_dataset`
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores holds one list of up to `n_best` float scores per
                translation produced (one per example per beam iteration)
            * all_predictions holds the matching lists of up to `n_best`
                predictions joined into strings

        Raises:
            ValueError: if `batch_size` is None, if `tgt` is given, or if
                `self.dump_beam` is set
        """
        assert src is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Report through the configured logger, falling back to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src=src,
            tgt=tgt,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size,
        )

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt)

        # Statistics
        counter = count(1)

        all_scores = []
        all_predictions = []

        results = []

        # TODO(daphne): Figure out why putting import at top of the file fails.
        import json
        # Iterating over batches.
        for num, batch in enumerate(data_iter):
            # Reinitialize previous hypotheses
            self.prev_hyps = []
            inputs = [""] * batch.batch_size
            preds = [[] for _ in range(batch.batch_size)]
            scores = [[] for _ in range(batch.batch_size)]
            # If doing iterative beam search, may run beam search multiple
            # times.
            for _ in range(self.beam_iters):
                batch_data = self.translate_batch(
                    batch,
                    data,
                    attn_debug,
                    builder,
                    fast=self.fast,
                )
                translations = builder.from_batch(batch_data)

                # Iterate over examples in the batch.
                for j, trans in enumerate(translations):
                    pred_scores = [
                        float(s) for s in trans.pred_scores[:self.n_best]]
                    pred_sents = trans.pred_sents[:self.n_best]
                    all_scores.append(pred_scores)

                    if any(len(sent) == 0 for sent in pred_sents):
                        print(
                            'Warning: (batch=%d, translation=%d) generated an empty sequence'
                            % (num, j))

                    if tgt is not None:
                        # TODO(dei): Add back support for this.
                        raise ValueError('tgt not currently supported.')

                    n_best_preds = [" ".join(pred) for pred in pred_sents]
                    all_predictions.append(n_best_preds)

                    # Saves predictions and scores into dictionary
                    # to be added to final results later
                    inputs[j] = trans.src_raw
                    if self.beam_iters == 1:
                        preds[j] = pred_sents
                        scores[j] = pred_scores
                    else:
                        # Checks if top candidate is empty
                        # (TODO: why is this happening?) and keeps the first
                        # non-empty candidate instead.
                        k = 0
                        while not trans.pred_sents[k]:
                            k += 1
                        preds[j] += [trans.pred_sents[k]]
                        scores[j] += [float(trans.pred_scores[k])]

                    if self.verbose:
                        sent_number = next(counter)
                        output = trans.log(sent_number)
                        if self.logger:
                            self.logger.info(output)
                        else:
                            os.write(1, output.encode('utf-8'))

                    if attn_debug:
                        # Use a private copy of the best hypothesis: the
                        # collected `preds` (and the translation itself)
                        # must not be altered by the debug output.
                        attn_words = list(trans.pred_sents[0]) + ['</s>']
                        attns = trans.attns[0].tolist()
                        if self.data_type == 'text':
                            srcs = trans.src_raw
                        else:
                            srcs = [str(item) for item in range(len(attns[0]))]
                        header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                        plain_row = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                        output = header_format.format("", *srcs) + '\n'
                        for word, row in zip(attn_words, attns):
                            max_index = row.index(max(row))
                            # Star-pad only the column holding the row
                            # maximum: mark the first max_index + 1 cells,
                            # then restore the first max_index of them.
                            row_format = plain_row.replace(
                                "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                            row_format = row_format.replace(
                                "{:*>10.7f} ", "{:>10.7f} ", max_index)
                            output += row_format.format(word, *row) + '\n'
                        os.write(1, output.encode('utf-8'))

            assert len(inputs) == len(preds) == len(scores)
            for source, hyps, hyp_scores in zip(inputs, preds, scores):
                results.append({
                    'input': source,
                    'pred': hyps,
                    'scores': hyp_scores
                })

        # Compute overall per-token perplexity.
        pred_score_total = 0
        pred_token_total = 0
        for result in results:
            pred_score_total += sum(result['scores'])
            pred_token_total += sum(len(s) for s in result['pred'])

        try:
            score = pred_score_total / pred_token_total
            ppl = math.exp(-pred_score_total / pred_token_total)
        except ArithmeticError as e:
            # Division by zero (no tokens) or overflow in exp(): report NaN
            # rather than aborting after all the translation work.
            print(e)
            print(
                'WARNING: SCORE AND PPL WERE NOT COMPUTED DUE TO NUMERICAL ERRORS')
            score = float('nan')
            ppl = float('nan')

        # Save the results to json.
        json_dump = {'results': results, 'score': score, 'ppl': ppl}
        json.dump(json_dump, self.out_file)
        self.out_file.flush()

        if self.report_score:
            _emit(self._report_score('PRED', json_dump['score'], 1))
            if tgt is not None:
                _emit(self._report_score('GOLD', json_dump['score'], 1))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt))

        if self.dump_beam:
            # Beam dumping is not supported by this code path.
            raise ValueError('This code path seems broken.')
        return all_scores, all_predictions
예제 #20
0
    def decode_sentences(self, sents, cuda=False):
        """
        Decode a list of raw sentences with the wrapped translator.

        The sentences are stripped and written one per line to a uniquely
        named scratch file under ``<temp_dir>/l2e``, which is then loaded as
        a text dataset and translated one sentence per batch. Before writing,
        the scratch directory is emptied if it holds more than 10 files.

        :param sents: [str]
        :param cuda: run the data iterator on "cuda" when true, else "cpu"
        :return: [(src, tgt, log-likelihood-score)], one tuple per kept
            prediction
        """
        unique_filename = str(uuid.uuid4())

        # Purge accumulated scratch files once the directory grows too large.
        leftovers = os.listdir(pjoin(self.temp_dir, "l2e"))
        if len(leftovers) > 10:
            for leftover in leftovers:
                os.remove(pjoin(self.temp_dir, "l2e", leftover))

        scratch_path = pjoin(self.temp_dir, "l2e",
                             '{}.txt'.format(unique_filename))
        with open(scratch_path, 'w') as scratch:
            scratch.writelines(sent.strip() + '\n' for sent in sents)

        data = inputters.build_dataset(
            self.fields,
            src_path=pjoin(self.temp_dir, "l2e",
                           '{}.txt'.format(unique_filename)),
            data_type='text',
            use_filter_pred=False,
            dynamic_dict=False)

        cur_device = "cuda" if cuda else "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device, batch_size=1,
            train=False, sort=False, sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields, n_best=1, replace_unk=True, has_tgt=False)

        # No statistics are kept; only (src, tgt, score) triples.
        decoded_sents = []

        for batch in data_iter:
            batch_data = self.translator.translate_batch(
                batch, data, fast=False)

            # One translation object per sentence of the batch.
            for trans in builder.from_batch(batch_data):
                kept = trans.pred_sents[:self.translator.n_best]
                joined = [" ".join(tokens) for tokens in kept]
                for rank, hypothesis in enumerate(joined):
                    triple = (' '.join(trans.src_raw), hypothesis,
                              trans.pred_scores[rank].item())
                    decoded_sents.append(triple)

        return decoded_sents
예제 #21
0
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into smaller multiples
    src_corpus and tgt_corpus files, then build shards, each
    shard will have opt.shard_size samples except last shard.

    The reason we do this is to avoid taking up too much memory due
    to sucking in a huge corpus file.

    Args:
        src_corpus (str): path of the source corpus text file
        tgt_corpus (str): path of the target corpus text file
        fields: passed through to `inputters.build_dataset`
        corpus_type (str): label used in the saved file names and log lines
        opt: options object; `shard_size`, `save_data` and the dataset
            options forwarded to `inputters.build_dataset` are read from it

    Returns:
        list of str: paths of the saved `.pt` shard files, in shard order
    """

    def _write_shard(prefix, shard_index, lines):
        # Write one temporary text shard and return its path; the context
        # manager closes the file even if the write fails.
        path = prefix + ".{0}.txt".format(shard_index)
        with codecs.open(path, "w", encoding="utf-8") as shard_file:
            shard_file.writelines(lines)
        return path

    with codecs.open(src_corpus, "r", encoding="utf-8") as fsrc, \
            codecs.open(tgt_corpus, "r", encoding="utf-8") as ftgt:
        logger.info("Reading source and target files: %s %s." %
                    (src_corpus, tgt_corpus))
        src_data = fsrc.readlines()
        tgt_data = ftgt.readlines()

    shard_size = opt.shard_size

    # Record the shard paths as they are written instead of globbing for
    # them afterwards: a lexicographic sort would put shard 10 before shard
    # 2, and a glob could pick up (and later delete) unrelated files that
    # happen to match the pattern.
    src_list = []
    tgt_list = []
    for shard_index, start in enumerate(
            range(0, len(src_data), shard_size)):
        logger.info("Splitting shard %d." % shard_index)
        stop = start + shard_size
        src_list.append(
            _write_shard(src_corpus, shard_index, src_data[start:stop]))
        tgt_list.append(
            _write_shard(tgt_corpus, shard_index, tgt_data[start:stop]))

    # The raw lines are no longer needed once the shards are on disk.
    del src_data, tgt_data

    ret_list = []

    for index, (src, tgt) in enumerate(zip(src_list, tgt_list)):
        logger.info("Building shard %d." % index)
        dataset = inputters.build_dataset(
            fields,
            opt.data_type,
            src_path=src,
            tgt_path=tgt,
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size,
            use_filter_pred=False)

        pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []
        print("!!!!dataset examples " + str(len(dataset.examples)))
        logger.info(" * saving %sth %s data shard to %s." %
                    (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)

        ret_list.append(pt_file)

        # The temporary text shards are only needed to build the dataset.
        os.remove(src)
        os.remove(tgt)

        # Drop the examples and the dataset eagerly to keep peak memory low
        # before the next shard is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
예제 #22
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  ans_path=None,
                  ans_data_iter=None,
                 ):
        """
        Build a dataset from the given source / target / answer inputs, run
        `translate_batch` over every batch and log what it returns.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            ans_path: forwarded to `inputters.build_dataset`
                (presumably a filepath of answer data; confirm with caller)
            ans_data_iter: forwarded to `inputters.build_dataset`
                (presumably an iterator over answer data; confirm with caller)

        Returns:
            None; the per-batch result of `translate_batch` is only logged.

        Raises:
            ValueError: if `batch_size` is None
        """
        assert src_path is not None or src_data_iter is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        fields, data_type = self.fields, self.data_type
        dataset_options = dict(
            src_path=src_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            ans_data_iter=ans_data_iter,
            ans_path=ans_path,
        )
        data = inputters.build_dataset(fields, data_type, **dataset_options)

        print(data)
        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        # Each batch result is logged as soon as it is available.
        for current_batch in data_iter:
            logger.info(self.translate_batch(current_batch, data))
예제 #23
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Interactively generate keyword-constrained responses.

        Repeatedly reads a post and a keyword from standard input until the
        post is the literal string "exit". For every post, `self.bk_model`
        first produces candidate sequences conditioned on the keyword; each
        candidate is reversed, its last token is replaced by the keyword, and
        the result is handed to `self.model` as `front_seq`. The front/back
        combination with the highest forward score is printed.

        Note: batch_size must not be None
        Note: the value passed as `src_path` is ignored; it is replaced by
            the jieba-segmented post read from standard input.

        Args:
            src_path: ignored (overwritten by the interactive input)
            src_data_iter (iterator): forwarded to `inputters.build_dataset`
            tgt_path (str): filepath of target data; when not None, gold
                scores are accumulated and reported
            tgt_data_iter (iterator): forwarded to `inputters.build_dataset`
            src_dir (str): source directory path, forwarded to
                `inputters.build_dataset`
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores holds one list of up to `n_best` scores per
              processed hypothesis
            * all_predictions holds one list of up to `n_best`
              space-joined predictions per processed hypothesis

            Both lists cover the whole interactive session and are empty
            when the user exits before entering a post.

        Raises:
            ValueError: if `batch_size` is None
        """
        if batch_size is None:
            raise ValueError("batch_size must be set")

        resp_vocab = self.fields["tgt"].vocab

        # Accumulators are created once, before the input loop, so they are
        # always bound (even on an immediate "exit") and cover every post.
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0
        all_scores = []
        all_predictions = []

        while True:
            post = input("Type in a post:")
            if post == "exit":
                break
            keyword = input("Type in a keyword:")
            keyword_index = resp_vocab.stoi[keyword]

            # Segment the post with jieba and hand it over as a one-element
            # list of space-separated tokens.
            src_path = [' '.join(jieba.cut(post))]
            data = inputters.build_dataset(
                self.fields,
                self.data_type,
                src_path=src_path,
                src_data_iter=src_data_iter,
                tgt_path=tgt_path,
                tgt_data_iter=tgt_data_iter,
                src_dir=src_dir,
                sample_rate=self.sample_rate,
                window_size=self.window_size,
                window_stride=self.window_stride,
                window=self.window,
                use_filter_pred=self.use_filter_pred)

            cur_device = "cuda" if self.cuda else "cpu"

            data_iter = inputters.OrderedIterator(
                dataset=data, device=cur_device,
                batch_size=batch_size, train=False, sort=False,
                sort_within_batch=True, shuffle=False)

            builder = onmt.translate.TranslationBuilder(
                data, self.fields,
                self.n_best, self.replace_unk, tgt_path)

            for batch in data_iter:
                # Backward model: generate the front sequence in reverse,
                # conditioned on the keyword.
                batch_data = self.translate_batch(
                    self.bk_model, batch, data,
                    keyword=keyword_index, fast=self.fast)

                for trans in builder.from_batch(batch_data):
                    best_forward_score = 0
                    # Reset per hypothesis so that a result from an earlier
                    # post can never leak into this one.
                    best_resp_front, best_resp_back = [], []
                    # Translation used for the statistics below: the most
                    # recent forward translation, or the backward one when
                    # there are no candidates.
                    stats_trans = trans

                    for resp in trans.pred_sents:
                        # Reverse the candidate and replace its last word
                        # with the known keyword.
                        resp_front = resp[::-1][:-1]
                        resp_front.append(keyword)
                        resp_front_indexs = [resp_vocab.stoi[w]
                                             for w in resp_front]

                        # Forward model: generate the back sequence from
                        # the front sequence.
                        batch_data = self.translate_batch(
                            self.model, batch, data,
                            front_seq=resp_front_indexs, fast=self.fast)

                        for fwd_trans in builder.from_batch(batch_data):
                            stats_trans = fwd_trans
                            scores = [
                                float(np.exp(s))
                                for s in fwd_trans.pred_scores[:self.n_best]
                            ]
                            top_score = max(scores)
                            if top_score > best_forward_score:
                                best_forward_score = top_score
                                best_resp_front = resp_front
                                best_index = np.argmax(scores)
                                best_resp_back = fwd_trans.pred_sents[
                                    best_index][len(resp_front):]

                    resp = ''.join(best_resp_front + best_resp_back)
                    print('response: ', resp)

                    all_scores += [stats_trans.pred_scores[:self.n_best]]
                    pred_score_total += stats_trans.pred_scores[0]
                    pred_words_total += len(stats_trans.pred_sents[0])
                    if tgt_path is not None:
                        gold_score_total += stats_trans.gold_score
                        gold_words_total += len(stats_trans.gold_sent) + 1

                    n_best_preds = [
                        " ".join(pred)
                        for pred in stats_trans.pred_sents[:self.n_best]
                    ]
                    all_predictions += [n_best_preds]
                    # Predictions are only printed and returned; nothing is
                    # written to an output file in interactive mode.

                    # Debug attention.
                    if attn_debug:
                        srcs = stats_trans.src_raw
                        preds = stats_trans.pred_sents[0]
                        preds.append('</s>')
                        attns = stats_trans.attns[0].tolist()
                        header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                        output = header_format.format("", *srcs) + '\n'
                        for word, row in zip(preds, attns):
                            # Star-pad the column holding the row maximum.
                            max_index = row.index(max(row))
                            row_format = row_format.replace(
                                "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                            row_format = row_format.replace(
                                "{:*>10.7f} ", "{:>10.7f} ", max_index)
                            output += row_format.format(word, *row) + '\n'
                            row_format = ("{:>10.10} "
                                          + "{:>10.7f} " * len(srcs))
                        os.write(1, output.encode('utf-8'))

        def _emit(msg):
            # Send a summary line to the logger when set, else to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        # Nothing to report when the session ended before any translation.
        if self.report_score and all_predictions:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt_path is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as dump_file:
                json.dump(self.translator.beam_accum, dump_file)
        return all_scores, all_predictions
예제 #24
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate a dataset and write the n-best output to `self.out_file`.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
            (not checked here; both are forwarded to
            `inputters.build_dataset`)

        Args:
            src_path: forwarded to `inputters.build_dataset`
            src_data_iter: forwarded to `inputters.build_dataset`
            tgt_path: forwarded to `inputters.build_dataset`; when not None,
                gold scores are accumulated and reported
            tgt_data_iter: forwarded to `inputters.build_dataset`
            src_dir: forwarded to `inputters.build_dataset`
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): when true, writes an attention table for
                every translation to file descriptor 1

        Returns:
            list: the first prediction score (`pred_scores[0]`) of every
            translation, in iteration order

        Raises:
            ValueError: if `batch_size` is None
        """
        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Send a summary line to the logger when set, else to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data, device=self.gpu,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields,
            self.n_best, self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        for batch in data_iter:
            batch_data = self.translate_batch(batch, data)

            for trans in builder.from_batch(batch_data):
                all_scores += [trans.pred_scores[0]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                # One hypothesis per line; flushed immediately so partial
                # output survives an interruption.
                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Star-pad the column holding the row maximum.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt_path is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as dump_file:
                json.dump(self.translator.beam_accum, dump_file)
        return all_scores
예제 #25
0
    def scoring(self,
                src_data_path=None,
                src_data_iter=None,
                tgt_data_path=None,
                tgt_data_iter=None,
                batch_size=32):
        """
        Score source/target text pairs with `self.model`.

        When `src_data_iter` is given, its length replaces `batch_size`.

        Returns:
            list of dict: one entry per example with the keys 'src', 'tgt'
            (taken from the dataset example) and 'score' (the second model
            output for that example, converted with `.item()`)
        """
        if src_data_iter is not None:
            batch_size = len(src_data_iter)
        assert batch_size != 0

        dataset = inputters.build_dataset(
            self.fields,
            'text',
            src_path=src_data_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_data_path,
            tgt_data_iter=tgt_data_iter,
            use_filter_pred=False,
            dynamic_dict=False)

        device_name = "cuda" if self.cuda else "cpu"

        batches = inputters.OrderedIterator(
            dataset=dataset,
            device=device_name,
            batch_size=batch_size,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        results = []
        for batch in batches:
            # Original inline notes: [src_len, batch_size, num_features]
            src = inputters.make_features(batch, 'src', 'text')
            _, src_lengths = batch.src

            # Original inline notes: [tgt_len, batch_size, num_features]
            tgt = inputters.make_features(batch, 'tgt', 'text')
            _, tgt_lengths = batch.tgt

            _, probs = self.model(src, tgt, src_lengths, tgt_lengths)

            # Sort the example indices and reorder the model output with
            # the same permutation.
            inds, perm = torch.sort(batch.indices.data)
            ordered_probs = probs.index_select(0, perm)

            results.extend(
                {
                    'src': dataset.examples[inds[pos]].src,
                    'tgt': dataset.examples[inds[pos]].tgt,
                    'score': ordered_probs[pos].data.item(),
                }
                for pos in range(batch.batch_size))

        return results
예제 #26
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  translation_pieces_path=None):
        """
        Translate a dataset with translation pieces and time every batch.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
            (not checked here; both are forwarded to
            `inputters.build_dataset`)

        Args:
            src_path: forwarded to `inputters.build_dataset`
            src_data_iter: forwarded to `inputters.build_dataset`
            tgt_path: forwarded to `inputters.build_dataset`; when not None,
                gold scores are accumulated and reported
            tgt_data_iter: forwarded to `inputters.build_dataset`
            src_dir: forwarded to `inputters.build_dataset`
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): when true, writes an attention table for
                every translation to file descriptor 1
            translation_pieces_path (str): pickle file holding the
                translation pieces passed to `translate_batch`; defaults to
                the path that used to be hard-coded

        Returns:
            list: the first prediction score (`pred_scores[0]`) of every
            translation, in iteration order

        Raises:
            ValueError: if `batch_size` is None
        """
        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Send a summary line to the logger when set, else to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=self.gpu,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Load the translation pieces list.
        if translation_pieces_path is None:
            home_path = "/home/ubuntu/OpenNMT-py-fork/"
            translation_pieces_path = (
                home_path
                + "extra_data/translation_pieces_md_10-th0pt5.pickle")
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at files from a trusted source.
        with open(translation_pieces_path, 'rb') as pieces_file:
            translation_pieces = pickle.load(pieces_file)
        tot_time = 0

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        for ix, batch in enumerate(data_iter):
            start_time = time.time()
            batch_data = self.translate_batch(batch, data, translation_pieces)

            for trans in builder.from_batch(batch_data):
                all_scores += [trans.pred_scores[0]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                # One hypothesis per line; flushed immediately so partial
                # output survives an interruption.
                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Star-pad the column holding the row maximum.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

            # Per-batch wall-clock time plus the running total as HH:MM:SS.
            duration = time.time() - start_time
            tot_time += duration
            tot_time_print = time.strftime("%H:%M:%S", time.gmtime(tot_time))
            print("Batch {} - Duration: {:.2f} - Total: {}".format(
                ix, duration, tot_time_print))

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt_path is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as dump_file:
                json.dump(self.translator.beam_accum, dump_file)
        return all_scores
예제 #27
0
    def translate(self,
                  knl,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate `src` together with `knl` and get gold scores if `tgt`
        is set.

        Note: batch_size must not be None
        Note: `src` and `knl` must not be None

        Args:
            knl: knowledge input, forwarded to `inputters.build_dataset`
            src: source input, forwarded to `inputters.build_dataset`
            tgt: target input or None; when not None, gold scores are
                accumulated and reported
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores holds one list of up to `n_best` scores per
              translation
            * all_predictions holds one list of up to `n_best`
              space-joined predictions per translation

        Raises:
            AssertionError: if `src` or `knl` is None
            ValueError: if `batch_size` is None
        """
        assert src is not None
        assert knl is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Send a summary line to the logger when set, else to stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            knl=knl,
            src=src,
            tgt=tgt,
            knl_seq_length_trunc=200,
            src_seq_length_trunc=50,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size,
            dynamic_dict=self.copy_attn)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch,
                                              data,
                                              attn_debug,
                                              fast=self.fast)

            for trans in builder.from_batch(batch_data):
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                # One hypothesis per line; flushed immediately so partial
                # output survives an interruption.
                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    # Non-text inputs have no source tokens to label the
                    # columns with, so column positions are used instead.
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Star-pad the column holding the row maximum.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt))

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as dump_file:
                json.dump(self.translator.beam_accum, dump_file)
        return all_scores, all_predictions
예제 #28
0
def _greedy_decode(model, src, src_lengths, tgt, pair_size, batch_size,
                   eos_id, max_length):
    """
    Greedily decode one batch, emitting `pair_size` tokens per step.

    Args:
        model: callable as `model(src, tgt, src_lengths, dec_state)`
            returning `(outputs, _, dec_state)`; must expose `generator`
        src: source features of the batch
        src_lengths: source lengths of the batch
        tgt (LongTensor): first decoder input, shaped
            `(pair_size, batch_size, 1)`
        pair_size (int): number of tokens produced per decoding step
        batch_size (int): number of examples in the batch
        eos_id (int): token id that stops decoding
        max_length (int): maximum number of decoding steps

    Returns:
        LongTensor of shape `(steps * pair_size, batch_size)`, or `None`
        when decoding stopped before any step was kept.
    """
    dec_state = None
    kept_steps = []
    for _ in range(max_length):
        outputs, _, dec_state = model(src, tgt, src_lengths, dec_state)
        scores = model.generator(outputs.view(-1, outputs.size(2)))
        indices = scores.argmax(dim=1)
        # The argmax of this step is fed back as the next decoder input.
        tgt = indices.view(pair_size, batch_size, 1)

        # Only the first token of the first example is inspected, which is
        # why the caller decodes with batch_size == 1.
        if tgt[0][0][0].item() == eos_id:
            break
        kept_steps.append(indices.view(pair_size, batch_size))

    if not kept_steps:
        return None
    # A single concatenation instead of one torch.cat per step.
    return torch.cat(kept_steps, 0)


def translate(opt):
    """
    Greedily translate `opt.src` and write one line per example to
    `opt.output`.

    Args:
        opt: options object; this function reads `output`, `gpu`, `src`,
            `tgt`, `src_dir` and `max_length`, and forwards `opt` to
            `onmt.model_builder.load_test_model`.
    """
    # Tokens that are dropped from the written sentence.
    skipped_tokens = ('<blank>', '<seos>', '</s>')

    # Open the output first (as before) so a bad path fails before the model
    # is loaded; the context manager closes it even if decoding raises.
    with codecs.open(opt.output, 'w+', 'utf-8') as out_file:
        if opt.gpu > -1:
            torch.cuda.set_device(opt.gpu)

        # Default model options, obtained by parsing an empty argument list.
        dummy_parser = argparse.ArgumentParser(description='train.py')
        opts.model_opts(dummy_parser)
        dummy_opt = dummy_parser.parse_known_args([])[0]

        fields, model, model_opt = \
            onmt.model_builder.load_test_model(opt, dummy_opt.__dict__)

        data = inputters.build_dataset(fields,
                                       'text',
                                       src_path=opt.src,
                                       src_data_iter=None,
                                       tgt_path=opt.tgt,
                                       tgt_data_iter=None,
                                       src_dir=opt.src_dir,
                                       sample_rate='16000',
                                       window_size=.02,
                                       window_stride=.01,
                                       window='hamming',
                                       use_filter_pred=False)

        device = torch.device('cuda' if opt.gpu > -1 else 'cpu')

        # The EOS test in _greedy_decode looks at a single example, so
        # examples are decoded one at a time.
        batch_size = 1

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        pair_size = model_opt.wpe_pair_size

        tgt_vocab = fields["tgt"].vocab
        s_id = tgt_vocab.stoi['<s>']
        # Fall back to '<unk>' / '</s>' when the vocabulary lacks the
        # '<sgo>' / '<seos>' markers.
        if '<sgo>' in tgt_vocab.stoi:
            ss_id = tgt_vocab.stoi['<sgo>']
        else:
            ss_id = tgt_vocab.stoi['<unk>']
        if '<seos>' in tgt_vocab.stoi:
            eos_id = tgt_vocab.stoi['<seos>']
        else:
            eos_id = tgt_vocab.stoi['</s>']

        # Count from 1 so the progress message reports batches completed.
        for num_done, batch in enumerate(data_iter, 1):
            # First decoder input: one row of '<s>' followed by
            # (pair_size - 1) rows of the '<sgo>' (or '<unk>') id.
            tgt = torch.LongTensor(
                [s_id] * batch_size +
                [ss_id] * ((pair_size - 1) * batch_size)
            ).view(pair_size, batch_size).unsqueeze(2).to(device)
            src = inputters.make_features(batch, 'src', 'text')
            _, src_lengths = batch.src

            # Inference only: do not record autograd history across steps.
            with torch.no_grad():
                result = _greedy_decode(model, src, src_lengths, tgt,
                                        pair_size, batch_size, eos_id,
                                        opt.max_length)

            if result is None:
                # EOS came first (or max_length is 0): still write an empty
                # line so output lines stay aligned with the input.
                sentences = [[] for _ in range(batch_size)]
            else:
                sentences = result.transpose(0, 1).tolist()

            for token_ids in sentences:
                words = [tgt_vocab.itos[idx] for idx in token_ids]
                words = [w for w in words if w not in skipped_tokens]
                out_file.write(' '.join(words) + '\n')

            print('Translated {} batches'.format(num_done))
예제 #29
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  rk_path=None,
                  rk_data_iter=None,
                  key_indicator_path=None,
                  key_indicator_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`.

        Predictions, their scores and (when available) selector
        probabilities are written to `self.out_file`,
        `self.scores_out_file` and `self.sel_probs_out_file`.
        Gold statistics are gathered and reported only when `tgt_path` is
        not None.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            rk_path (str): filepath of retrieved keyphrases
            rk_data_iter (iterator): an iterator generating retrieved
                keyphrases e.g. it may be a list or an opened file
            key_indicator_path (str): filepath of src keyword indicators
            key_indicator_iter (iterator): an iterator generating src
                keyword indicators e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores holds, per translation, its first `n_best` scores
            * all_predictions holds, per translation, its first `n_best`
                predictions joined into strings

        Raises:
            ValueError: if `batch_size` is None
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def emit(msg):
            # Report through the logger when one is configured, else stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       rk_path=rk_path,
                                       rk_data_iter=rk_data_iter,
                                       key_indicator_path=key_indicator_path,
                                       key_indicator_iter=key_indicator_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields,
            self.n_best, self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data, fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                top_scores = trans.pred_scores[:self.n_best]
                all_scores += [top_scores]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]

                # exp() of each score, rounded to 5 decimals for the
                # scores file.
                n_best_preds_scores = [round(sc.exp().item(), 5)
                                       for sc in top_scores]

                all_predictions += [n_best_preds]
                self.out_file.write(' ; '.join(n_best_preds) + '\n')
                self.out_file.flush()
                self.scores_out_file.write(
                    ' ; '.join([str(sc) for sc in n_best_preds_scores])
                    + '\n')

                if trans.selector_probs is not None:
                    # Exact zeros are dropped before rounding.
                    selector_probs = [
                        round(sp, 5)
                        for sp in trans.selector_probs.tolist()
                        if sp != 0.0]
                    self.sel_probs_out_file.write(
                        ' ; '.join([str(sp) for sp in selector_probs])
                        + '\n')

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    # Work on a copy so the translation's own token list is
                    # not modified by the debug output.
                    preds = list(trans.pred_sents[0]) + ['</s>']
                    attns = trans.attns[0].tolist()
                    plain_cell = "{:>10.7f} "
                    starred_cell = "{:*>10.7f} "
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    base_row_format = "{:>10.10} " + plain_cell * len(srcs)
                    lines = [header_format.format("", *srcs)]
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star the first max_index + 1 cells, then un-star
                        # the first max_index: only the cell holding the
                        # row maximum stays padded with '*'.
                        row_format = base_row_format.replace(
                            plain_cell, starred_cell, max_index + 1)
                        row_format = row_format.replace(
                            starred_cell, plain_cell, max_index)
                        lines.append(row_format.format(word, *row))
                    output = '\n'.join(lines) + '\n'
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            emit(self._report_score('PRED', pred_score_total,
                                    pred_words_total))
            if tgt_path is not None:
                emit(self._report_score('GOLD', gold_score_total,
                                        gold_words_total))
                if self.report_bleu:
                    emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Close the dump file deterministically instead of leaking the
            # handle.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)

        if self.opt is not None:
            evaluate_func(opts=self.opt, do_stem=True)
        return all_scores, all_predictions
예제 #30
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  node_type_seq=None,
                  atc=None):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`,
        decoding every example once per candidate node-type sequence.

        Gold statistics are gathered and reported only when `tgt_path` is
        not None.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
        Note: each batch must hold exactly one example, because the example
            index is read with `batch.indices.item()`

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            node_type_seq: required pair; `node_type_seq[0][i]` holds the
                candidate node-type sequences of example `i` and
                `node_type_seq[1][i]` their scores. Only the first
                `self.option.tree_count` candidates are used.
            atc: optional container indexed by example index; the selected
                item is forwarded to `translate_batch` as `atc`

        Returns:
            (`list`, `list`)

            * all_scores holds one list per batch: for every candidate
                sequence, the first `n_best` prediction scores, each with
                the candidate's score added
            * all_predictions holds one list per batch with the matching
                predictions joined into strings

        Raises:
            ValueError: if `batch_size` is None
        """
        assert src_data_iter is not None or src_path is not None
        assert node_type_seq is not None, 'Node Types must be provided'
        type_scores_by_example = node_type_seq[1]
        type_seqs_by_example = node_type_seq[0]
        if batch_size is None:
            raise ValueError("batch_size must be set")

        def emit(msg):
            # Report through the logger when one is configured, else stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        # NOTE(review): total_correct is accumulated but never returned or
        # reported; it is kept for debugging only.
        total_correct = 0

        for bidx, batch in enumerate(data_iter):
            # Only 1 item in this batch, guaranteed
            example_idx = batch.indices.item()
            if bidx % 20 == 0:
                debug('Current Example : ', example_idx)
            nt_sequences = type_seqs_by_example[example_idx]
            nt_scores = type_scores_by_example[example_idx]
            atc_item = atc[example_idx] if atc is not None else None
            scores = []
            predictions = []
            tree_count = self.option.tree_count
            for type_sequence, type_score in zip(nt_sequences[:tree_count],
                                                 nt_scores[:tree_count]):
                batch_data = self.translate_batch(batch,
                                                  data,
                                                  node_type_str=type_sequence,
                                                  fast=self.fast,
                                                  atc=atc_item)
                translations = builder.from_batch(batch_data)
                already_found = False
                for trans in translations:
                    # Combine each prediction score with the score of the
                    # node-type sequence that produced it.
                    scores += [
                        score + type_score
                        for score in trans.pred_scores[:self.n_best]
                    ]

                    pred_score_total += trans.pred_scores[0]
                    pred_words_total += len(trans.pred_sents[0])
                    if tgt_path is not None:
                        gold_score_total += trans.gold_score
                        gold_words_total += len(trans.gold_sent) + 1

                    n_best_preds = [
                        " ".join(pred)
                        for pred in trans.pred_sents[:self.n_best]
                    ]

                    # Without a gold sentence nothing can be compared;
                    # joining None would raise TypeError.
                    if trans.gold_sent is not None:
                        gold_sent = ' '.join(trans.gold_sent).strip()
                        correct = int(any(pred.strip() == gold_sent
                                          for pred in n_best_preds))
                    else:
                        correct = 0
                    # Only the first translation of each candidate
                    # sequence contributes to the tally.
                    if not already_found:
                        total_correct += correct
                        already_found = True
                    predictions += n_best_preds

                    if self.verbose:
                        sent_number = next(counter)
                        output = trans.log(sent_number)
                        if self.logger:
                            self.logger.info(output)
                        else:
                            os.write(1, output.encode('utf-8'))

                    if attn_debug:
                        srcs = trans.src_raw
                        # Work on a copy so the translation's own token
                        # list is not modified by the debug output.
                        preds = list(trans.pred_sents[0]) + ['</s>']
                        attns = trans.attns[0].tolist()
                        plain_cell = "{:>10.7f} "
                        starred_cell = "{:*>10.7f} "
                        header_format = ("{:>10.10} "
                                         + "{:>10.7} " * len(srcs))
                        base_row_format = ("{:>10.10} "
                                           + plain_cell * len(srcs))
                        lines = [header_format.format("", *srcs)]
                        for word, row in zip(preds, attns):
                            max_index = row.index(max(row))
                            # Star the first max_index + 1 cells, then
                            # un-star the first max_index: only the cell
                            # holding the row maximum stays padded with
                            # '*'.
                            row_format = base_row_format.replace(
                                plain_cell, starred_cell, max_index + 1)
                            row_format = row_format.replace(
                                starred_cell, plain_cell, max_index)
                            lines.append(row_format.format(word, *row))
                        output = '\n'.join(lines) + '\n'
                        os.write(1, output.encode('utf-8'))
            all_scores += [scores]
            all_predictions += [predictions]

        # Only gold statistics are reported here; the PRED totals above are
        # accumulated but not reported.
        if self.report_score:
            if tgt_path is not None:
                emit(self._report_score('GOLD', gold_score_total,
                                        gold_words_total))
                if self.report_bleu:
                    emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Close the dump file deterministically instead of leaking the
            # handle.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions