예제 #1
0
def _load_fields(dataset, data_type, opt, checkpoint):
    """
    Load the `Field` objects for training and keep those used by `dataset`.

    Args:
        dataset: dataset whose first example decides which fields are kept
            (only keys present in ``dataset.examples[0].__dict__`` survive).
        data_type: data type string forwarded to `load_fields_from_vocab`;
            'text' additionally logs the source vocabulary size.
        opt: options object; reads ``opt.train_from`` (for the log message)
            and ``opt.data`` (prefix of the ``.vocab.pt`` file).
        checkpoint: ``None``, or a mapping with a 'vocab' entry to load the
            vocabulary from instead of the file on disk.

    Returns:
        dict mapping field name to field, after `myutils.add_more_field`
        has been applied to it.
    """
    if checkpoint is not None:
        logger.info('Loading vocab from checkpoint at %s.', opt.train_from)
        fields = load_fields_from_vocab(checkpoint['vocab'], data_type)
    else:
        # NOTE: pickle.load can execute arbitrary code; only point opt.data
        # at vocabulary files that come from a trusted source.
        with open(opt.data + '.vocab.pt', 'rb') as f:
            voc = pickle.load(f)
        fields = load_fields_from_vocab(voc, data_type)

    # Drop every field the dataset's examples do not carry.
    fields = {k: f for k, f in fields.items()
              if k in dataset.examples[0].__dict__}
    myutils.add_more_field(fields)

    if data_type == 'text':
        logger.info(' * vocabulary size. source = %d; target = %d',
                    len(fields['src'].vocab), len(fields['tgt'].vocab))
    else:
        logger.info(' * vocabulary size. target = %d',
                    len(fields['tgt'].vocab))

    return fields
예제 #2
0
def main():
    """
    Preprocessing entry point: build and save datasets and the vocabulary.

    Parses the command-line options, builds the `Fields` object, then saves
    the training data, the validation data and finally the vocabulary.

    Raises:
        AssertionError: if the deprecated ``-max_shard_size`` option is set
            to a positive value.
    """
    opt = parse_args()

    if opt.max_shard_size > 0:
        # Adjacent literals keep the message on one line; a backslash
        # continuation inside the string would embed the indentation.
        raise AssertionError("-max_shard_size is deprecated, please use "
                             "-shard_size (number of examples) instead.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    # The code below is an attempt to work around failures when preparing
    # data with multiple processes; according to the original author it did
    # not have any effect.
    torch.multiprocessing.set_sharing_strategy('file_system')
    import resource
    _, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    soft_limit = 65535
    if hard_limit != resource.RLIM_INFINITY:
        # setrlimit raises ValueError when the soft limit exceeds the hard
        # limit, so never ask for more than the hard limit allows.
        soft_limit = min(soft_limit, hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
    # END

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d.", src_nfeats)
    logger.info(" * number of target features: %d.", tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)
    myutils.add_more_field(fields)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
예제 #3
0
def build_translator(opt,
                     report_score=True,
                     logger=None,
                     out_file=None,
                     log_probs_out_file=None):
    """
    Create a `Translator` from translation options.

    Args:
        opt: translation options; the model list, scorer settings, output
            path and the decoding attributes listed below are read from it.
        report_score: forwarded unchanged to `Translator`.
        logger: forwarded unchanged to `Translator`.
        out_file: output stream; when ``None`` a file is opened at
            ``opt.output``.
        log_probs_out_file: stream for log probabilities; only opened here
            (at ``opt.output + '_log_probs'``) when `out_file` was ``None``
            and ``opt.log_probs`` is truthy.

    Returns:
        The constructed `Translator`.
    """
    if out_file is None:
        out_file = codecs.open(opt.output, 'w+', 'utf-8')
        if opt.log_probs:
            log_probs_path = opt.output + '_log_probs'
            log_probs_out_file = codecs.open(log_probs_path, 'w+', 'utf-8')

    # Parse an empty argument list to obtain the default model options.
    default_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(default_parser)
    default_model_opt = default_parser.parse_known_args([])[0]

    # More than one model means ensemble decoding.
    if len(opt.models) > 1:
        load_test_model = onmt.decoders.ensemble.load_test_model
    else:
        load_test_model = onmt.model_builder.load_test_model
    fields, model, model_opt = load_test_model(opt,
                                               default_model_opt.__dict__)

    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)

    # Options copied verbatim from `opt` into the Translator constructor.
    forwarded_names = (
        "beam_size", "n_best", "max_length", "min_length",
        "stepwise_penalty", "block_ngram_repeat", "ignore_when_blocking",
        "dump_beam", "report_bleu", "data_type", "replace_unk", "gpu",
        "verbose", "fast", "sample_rate", "window_size", "window_stride",
        "window", "image_channel_size", "mask_from",
    )
    kwargs = {}
    for name in forwarded_names:
        kwargs[name] = getattr(opt, name)

    myutils.add_more_field(fields)
    return Translator(model,
                      fields,
                      global_scorer=scorer,
                      out_file=out_file,
                      report_score=report_score,
                      copy_attn=model_opt.copy_attn,
                      logger=logger,
                      log_probs_out_file=log_probs_out_file,
                      **kwargs)
예제 #4
0
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Build Field objects and attach the vocabularies from a `vocab.pt` file.

    `vocab` is converted with ``dict(vocab)``; each value is assigned as the
    ``vocab`` attribute of the field with the same key. The returned fields
    have also been passed through `myutils.add_more_field`.
    """
    vocab_by_name = dict(vocab)
    src_feature_count = len(collect_features(vocab_by_name, 'src'))
    tgt_feature_count = len(collect_features(vocab_by_name, 'tgt'))
    fields = get_fields(data_type, src_feature_count, tgt_feature_count)

    for name in vocab_by_name:
        field_vocab = vocab_by_name[name]
        # Hack from the original author: a defaultdict could not be
        # pickled, so `stoi` is re-wrapped here with a default index of 0.
        field_vocab.stoi = defaultdict(lambda: 0, field_vocab.stoi)
        fields[name].vocab = field_vocab

    myutils.add_more_field(fields)
    return fields
예제 #5
0
def build_dataset(fields,
                  data_type,
                  src_data_iter=None,
                  src_path=None,
                  src_dir=None,
                  tgt_data_iter=None,
                  tgt_path=None,
                  src_seq_length=0,
                  tgt_seq_length=0,
                  src_seq_length_trunc=0,
                  tgt_seq_length_trunc=0,
                  dynamic_dict=True,
                  sample_rate=0,
                  window_size=0,
                  window_stride=0,
                  window=None,
                  normalize_audio=True,
                  use_filter_pred=True,
                  image_channel_size=3):
    """
    Build src/tgt examples iterator from corpus files, also extract
    number of features.

    The source side is read according to `data_type` ('text', 'img' or
    'audio'); the target side is always read as text. The resulting
    dataset's ``fields`` are passed through `myutils.add_more_field`.

    Returns:
        A `TextDataset`, `ImageDataset` or `AudioDataset`, matching
        `data_type`.

    Raises:
        ValueError: if `data_type` is not one of the supported values, or
            if `data_type` is 'audio' and either a source iterator is given
            or `src_path` is ``None``.
    """
    supported_types = ('text', 'img', 'audio')
    if data_type not in supported_types:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down when no branch assigns the result.
        raise ValueError("Unsupported data_type %r; expected one of %s"
                         % (data_type, ", ".join(supported_types)))

    def _make_examples_nfeats_tpl(data_type,
                                  src_data_iter,
                                  src_path,
                                  src_dir,
                                  src_seq_length_trunc,
                                  sample_rate,
                                  window_size,
                                  window_stride,
                                  window,
                                  normalize_audio,
                                  image_channel_size=3):
        """
        Process the corpus into (example_dict iterator, num_feats) tuple
        on source side for different 'data_type'.
        """

        if data_type == 'text':
            src_examples_iter, num_src_feats = \
                TextDataset.make_text_examples_nfeats_tpl(
                    src_data_iter, src_path, src_seq_length_trunc, "src")

        elif data_type == 'img':
            src_examples_iter, num_src_feats = \
                ImageDataset.make_image_examples_nfeats_tpl(
                    src_data_iter, src_path, src_dir, image_channel_size)

        elif data_type == 'audio':
            if src_data_iter:
                raise ValueError(
                    "Data iterator for AudioDataset isn't implemented")

            if src_path is None:
                raise ValueError("AudioDataset requires a non None path")
            src_examples_iter, num_src_feats = \
                AudioDataset.make_audio_examples_nfeats_tpl(
                    src_path, src_dir, sample_rate,
                    window_size, window_stride, window,
                    normalize_audio)

        return src_examples_iter, num_src_feats

    src_examples_iter, num_src_feats = \
        _make_examples_nfeats_tpl(data_type, src_data_iter, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride,
                                  window, normalize_audio,
                                  image_channel_size=image_channel_size)

    # For all data types, the tgt side corpus is in form of text.
    tgt_examples_iter, num_tgt_feats = \
        TextDataset.make_text_examples_nfeats_tpl(
            tgt_data_iter, tgt_path, tgt_seq_length_trunc, "tgt")

    if data_type == 'text':
        dataset = TextDataset(fields,
                              src_examples_iter,
                              tgt_examples_iter,
                              num_src_feats,
                              num_tgt_feats,
                              src_seq_length=src_seq_length,
                              tgt_seq_length=tgt_seq_length,
                              dynamic_dict=dynamic_dict,
                              use_filter_pred=use_filter_pred)

    elif data_type == 'img':
        dataset = ImageDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               num_src_feats,
                               num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred,
                               image_channel_size=image_channel_size)

    else:  # 'audio' -- the only remaining supported type
        dataset = AudioDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred)

    # Attach the extra (graph) fields to the dataset's field mapping.
    myutils.add_more_field(dataset.fields)

    return dataset