Example #1
def build_save_dataset(corpus_type, fields, src_corpus, tgt_corpus, savepath,
                       args):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'dev', 'test']
    dataset = inputters.build_dataset(fields,
                                      data_type='text',
                                      src_path=src_corpus,
                                      tgt_path=tgt_corpus,
                                      src_dir='',
                                      src_seq_length=args.max_src_len,
                                      tgt_seq_length=args.max_tgt_len,
                                      src_seq_length_trunc=0,
                                      tgt_seq_length_trunc=0,
                                      dynamic_dict=True)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

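    # Attach a graph parsed from each example's source string as a new
    # 'graph' attribute, printing progress every 500 examples.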
    for i in range(len(dataset)):
        if i % 500 == 0:
            print(i)
        setattr(dataset.examples[i], 'graph',
                myutils.str2graph(dataset.examples[i].src))

    pt_file = "{:s}/{:s}.pt".format(savepath, corpus_type)
    # torch.save(dataset, pt_file)
    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)
    return [pt_file]
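Because the dataset is written with pickle.dump rather than torch.save, it has to be read back with pickle as well. A minimal loading sketch (the path is illustrative; the fields stripped above are assumed to be restored from the separately saved vocab file):

import pickle

with open('savepath/train.pt', 'rb') as f:
    dataset = pickle.load(f)
print(dataset.examples[0].graph)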
Example #2
def parrel_func(example, have_3d):
    # myutils.str2graph(dataset.examples[i].src)
    graph = None
    try:
        smile = example.src
        graph = myutils.str2graph(smile, have_3d)
    except Exception as e:
        print(e)
        # raise e
    setattr(example, 'graph', graph)
    return example
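parrel_func processes one example at a time, which suggests it is meant to be mapped over a dataset's examples in parallel. A hypothetical driver using only the standard library (the Pool usage, worker count, chunk size, and the dataset object are assumptions, not part of the original code):

from functools import partial
from multiprocessing import Pool

# Hypothetical: attach graphs to all examples using several worker processes.
with Pool(processes=8) as pool:
    dataset.examples = pool.map(partial(parrel_func, have_3d=False),
                                dataset.examples, chunksize=64)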
Example #3
def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt

    if (opt.shard_size > 0):
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type, opt)

    # For data_type == 'img' or 'audio', we currently don't do
    # preprocess sharding and only build a monolithic dataset.
    # Since the interfaces are uniform, it would not be hard
    # to add this should users need the feature.
    dataset = inputters.build_dataset(
        fields,
        "text",
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))

    for i in range(len(dataset)):
        if i % 500 == 0:
            print(i)
        setattr(dataset.examples[i], 'graph',
                myutils.str2graph(dataset.examples[i].src))

    # torch.save(dataset, pt_file)
    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)
    return [pt_file]
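A hedged sketch of how this function might be driven from a preprocessing script; the surrounding loop is an assumption, and fields and opt are assumed to have been built earlier in that script. Only the call signature comes from the code above:

for corpus_type in ('train', 'valid'):
    build_save_dataset(corpus_type, fields, opt)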
Example #4
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`
        and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.
        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data,
                e.g. a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): number of examples per mini-batch
            attn_debug (bool): enables the attention logging
        Returns:
            (`list`, `list`)
            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src_path=src_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size)

        # add the graph field
        for i in range(len(data)):
            if i % 500 == 0:
                print(i)
            setattr(data.examples[i], 'graph',
                    myutils.str2graph(data.examples[i].src))

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
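            # Pad the graphs in this batch to the longest source length
            # (batch.src[1] holds the source lengths) before translating.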
            batch.graph = myutils.pad_for_graph(batch.graph,
                                                torch.max(batch.src[1]).item())

            batch_data = self.translate_batch(batch, data, fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.log_probs_out_file is not None:
                    self.log_probs_out_file.write('\n'.join([
                        str(t.item()) for t in trans.pred_scores[:self.n_best]
                    ]) + '\n')
                    self.log_probs_out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)
            if tgt_path is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                if self.logger:
                    self.logger.info(msg)
                else:
                    print(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt_path)
                    if self.logger:
                        self.logger.info(msg)
                    else:
                        print(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt_path)
                    if self.logger:
                        self.logger.info(msg)
                    else:
                        print(msg)

        if self.dump_beam:
            import json
            json.dump(self.translator.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))
        return all_scores, all_predictions
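A minimal call sketch, assuming a translator object of this class has already been constructed elsewhere; the file path and batch size are illustrative:

all_scores, all_predictions = translator.translate(
    src_path='data/test_src.txt', batch_size=32)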
Example #5
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into smaller multiples
    src_copus and tgt corpus files, then build shards, each
    shard will have opt.shard_size samples except last shard.

    The reason we do this is to avoid taking up too much memory due
    to sucking in a huge corpus file.
    """

    with codecs.open(src_corpus, "r", encoding="utf-8") as fsrc:
        with codecs.open(tgt_corpus, "r", encoding="utf-8") as ftgt:
            logger.info("Reading source and target files: %s %s." %
                        (src_corpus, tgt_corpus))
            src_data = fsrc.readlines()
            tgt_data = ftgt.readlines()

            num_shards = int(len(src_data) / opt.shard_size)
            for x in range(num_shards):
                logger.info("Splitting shard %d." % x)
                f = codecs.open(src_corpus + ".{0}.txt".format(x),
                                "w",
                                encoding="utf-8")
                f.writelines(src_data[x * opt.shard_size:(x + 1) *
                                      opt.shard_size])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(x),
                                "w",
                                encoding="utf-8")
                f.writelines(tgt_data[x * opt.shard_size:(x + 1) *
                                      opt.shard_size])
                f.close()
            num_written = num_shards * opt.shard_size
            if len(src_data) > num_written:
                logger.info("Splitting shard %d." % num_shards)
                f = codecs.open(src_corpus + ".{0}.txt".format(num_shards),
                                'w',
                                encoding="utf-8")
                f.writelines(src_data[num_shards * opt.shard_size:])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(num_shards),
                                'w',
                                encoding="utf-8")
                f.writelines(tgt_data[num_shards * opt.shard_size:])
                f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []

    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        dataset = inputters.build_dataset(
            fields,
            opt.data_type,
            src_path=src,
            tgt_path=tgt_list[index],
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)

        pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s." %
                    (index, corpus_type, pt_file))
        for i in range(len(dataset)):
            if i % 500 == 0:
                print(i)
            setattr(dataset.examples[i], 'graph',
                    myutils.str2graph(dataset.examples[i].src))

        # torch.save(dataset, pt_file)
        with open(pt_file, 'wb') as f:
            pickle.dump(dataset, f)

        ret_list.append(pt_file)
        os.remove(src)
        os.remove(tgt_list[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
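As a worked example: with 12,500 source lines and opt.shard_size = 5000, num_shards is 2, so src_corpus + '.0.txt' and '.1.txt' each hold 5,000 lines, the remainder shard '.2.txt' holds the final 2,500, and three corresponding .pt files are written and returned.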
Example #6
import torch
import dgl
import pickle
import torchtext

import onmt.myutils as myutils
# dataset = torch.load('data/seqdata.train.0.pt')
# g = myutils.str2graph(dataset.examples[0].src)
# setattr(dataset.examples[0], 'graph', g)
# print(g)
# torch.save(dataset,'g.pt')
# a = torch.load('g.pt')
# print(a.examples[0].graph)
# print(a)
dataset = torch.load('data/seqdata.train.0.pt')
g = myutils.str2graph(dataset.examples[0].src)
setattr(dataset.examples[0], 'graph', g)
with open('g.pt', 'wb') as f:
    pickle.dump(dataset, f)
with open('g.pt', 'rb') as f:
    a = pickle.load(f)

print(a)
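To confirm the attached graph survives the pickle round trip, one can also print it back, mirroring the commented-out check above:

print(a.examples[0].graph)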