Example #1
    def __init__(
            self,
            md_help=True,
            config_file_parser_class=cfargparse.YAMLConfigFileParser,
            formatter_class=cfargparse.ArgumentDefaultsHelpFormatter,
            **kwargs):
        super(ArgumentParser, self).__init__(
            config_file_parser_class=config_file_parser_class,
            formatter_class=formatter_class,
            **kwargs)
        if md_help:
            opts.add_md_help_argument(self)
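For context, a minimal runnable sketch of where this __init__ lives (assuming cfargparse is configargparse imported under an alias, as in OpenNMT-py, with the opts hook stubbed out so the snippet runs standalone):

import configargparse as cfargparse


class ArgumentParser(cfargparse.ArgumentParser):
    def __init__(self, md_help=True, **kwargs):
        super(ArgumentParser, self).__init__(
            config_file_parser_class=cfargparse.YAMLConfigFileParser,
            formatter_class=cfargparse.ArgumentDefaultsHelpFormatter,
            **kwargs)
        if md_help:
            pass  # opts.add_md_help_argument(self) in the real code


parser = ArgumentParser(description='train.py')
opt = parser.parse_args([])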
Example #2
def parse_args():
    """ Parsing arguments """
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)

    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    return opt
Example #3
def parse_args():
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)

    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    check_existing_pt_files(opt)

    return opt
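check_existing_pt_files is defined elsewhere in the script; a plausible sketch of such a helper (hypothetical, modeled on OpenNMT-py's preprocess.py) globs for shards left by a previous run and aborts rather than overwrite them:

import glob
import sys


def check_existing_pt_files(opt):
    # Refuse to run if sharded .pt files from an earlier preprocessing
    # pass already exist under the same -save_data prefix.
    for t in ['train', 'valid', 'vocab']:
        pattern = opt.save_data + '.' + t + '*.pt'
        if glob.glob(pattern):
            sys.stderr.write("Please back up existing pt files: %s, "
                             "to avoid overwriting them!\n" % pattern)
            sys.exit(1)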
Example #4
def parse_args():
    parser = configargparse.ArgumentParser(
        description='train.py',
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter)

    opts.general_opts(parser)
    opts.config_opts(parser)
    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)

    opt = parser.parse_args()

    return opt
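Because the parser is built with YAMLConfigFileParser, the same options can also come from a YAML file. A self-contained sketch of that mechanism (plain configargparse with hypothetical option names, not OpenNMT's opts module; requires PyYAML):

import configargparse

parser = configargparse.ArgumentParser(
    config_file_parser_class=configargparse.YAMLConfigFileParser)
parser.add_argument('--config', is_config_file=True,
                    help='path to a YAML config file')
parser.add_argument('--batch_size', type=int, default=64)

with open('train.yml', 'w') as f:
    f.write('batch_size: 32\n')

opt = parser.parse_args(['--config', 'train.yml'])
print(opt.batch_size)  # 32: the YAML value overrides the coded default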
Example #5
def parse_args():
    """ Parsing arguments """
    parser = configargparse.ArgumentParser(
        description='preprocess.py',
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter)

    opts.config_opts(parser)
    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)

    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    check_existing_pt_files(opt)

    return opt
Example #6
def parse_args():
    """ Parsing arguments """
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # parser.add_argument("")
    #  group.add_argument('-train_dir', required=True, default='data/race_train.json',
    #                     help="Path to the training data")
    #  group.add_argument('-valid_dir', required=True, default='data/race_dev.json',
    #                     help="Path to the validation data")
    #  group.add_argument('-data_type', choices=["text"], help="""text""")
    #  group.add_argument('-save_data', required=True, default='data/processed',
    #                     help="Output file for the prepared data")

    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)
    #print (parser.parse_args())
    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    return opt
Example #7
    def __init__(self, gpu):
        parser = configargparse.ArgumentParser(
            description='translate.py',
            config_file_parser_class=configargparse.YAMLConfigFileParser,
            formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
        opts.config_opts(parser)
        opts.add_md_help_argument(parser)
        opts.translate_opts(parser)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        if len(gpu) > 1:
            print('do not try hacking')
            exit()
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        sys.argv = ["python"]
        opt = parser.parse_args()

        opt.models = [
            './modules/multi_summ/dataset_m2s2/korean_bert_8_single_new_economy_add_cls_sep_segment_eos_penalty_new_bloom_step_25000.pt'
        ]
        opt.segment = True
        opt.batch_size = 8
        opt.beam_size = 10
        opt.src = '.1'
        opt.output = '.1'
        opt.verbose = True
        opt.stepwise_penalty = True
        opt.coverage_penalty = 'summary'
        opt.beta = 5
        opt.length_penalty = 'wu'
        opt.alpha = 0.9
        opt.block_ngram_repeat = 3
        opt.ignore_when_blocking = [".", "</t>", "<t>", ",_", "%"]
        opt.max_length = 300
        opt.min_length = 35
        opt.gpu = 0

        logger = init_logger(opt.log_file)
        self.translator = build_translator(opt, report_score=True)
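The sys.argv = ["python"] line above is a trick for driving an argparse-based CLI from library code: it replaces the host process's argument list so parse_args() falls back to defaults, after which fields on the returned Namespace are overridden directly. A minimal sketch of the same pattern (the option name is illustrative):

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument('-beam_size', type=int, default=5)

sys.argv = ["python"]      # hide the real command-line arguments
opt = parser.parse_args()  # parses an empty argument list, so defaults apply
opt.beam_size = 10         # then override fields programmatically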
Example #8
def parse_args():
    """ Parsing arguments """
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)
    parser.add_argument('-parrel_run',
                        action='store_true',
                        default=False,
                        help='generate targets')
    parser.add_argument('-with_3d_confomer',
                        action='store_true',
                        default=False,
                        help='whether to append coordinates as the last '
                             '3 dimensions of the atom features')
    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    check_existing_pt_files(opt)

    return opt
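Note that action='store_true' already implies default=False, so the explicit defaults above are redundant. A quick demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-flag', action='store_true')  # default is already False
print(parser.parse_args([]).flag)         # False
print(parser.parse_args(['-flag']).flag)  # True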
Example #9
def class_weight(class_probs, e):
    ppos = class_probs['ppos']
    pneg = class_probs['pneg']

    return {
        'wpos': (1 - e) * 0.5 + e * (1 - ppos),
        'wneg': (1 - e) * 0.5 + e * (1 - pneg)
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='train.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()
    with open('%s.arg' % opt.exp, 'w') as f:
        f.write(' '.join(sys.argv[1:]))

    TEXT, LABEL, train_iter, valid_iter = \
        iters.build_iters_lm(ftrain=opt.ftrain, fvalid=opt.fvalid,
                             bsz=opt.batch_size, level=opt.level,
                             min_freq=opt.min_freq)

    class_probs = dataset_bias(train_iter)
    print('Class probs: ', class_probs)
    # cweights = class_weight(class_probs, opt.label_smoothing)
    cweights = {'wneg': 1 - opt.pos_weight, 'wpos': opt.pos_weight}
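For reference, a worked call to the commented-out class_weight smoothing above: with e = 0.1 and an 80/20 class split, the majority class is down-weighted:

w = class_weight({'ppos': 0.8, 'pneg': 0.2}, e=0.1)
# wpos = 0.9 * 0.5 + 0.1 * (1 - 0.8) = 0.47
# wneg = 0.9 * 0.5 + 0.1 * (1 - 0.2) = 0.53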
Example #10
import argparse

import torch

import onmt.opts as opts
from train_multi import main as multi_main
from train_single import main as single_main


def main(opt):

    if opt.rnn_type == "SRU" and not opt.gpuid:
        raise AssertionError("Using SRU requires -gpuid set.")

    if torch.cuda.is_available() and not opt.gpuid:
        print("WARNING: You have a CUDA device, should run with -gpuid 0")

    if len(opt.gpuid) > 1:
        multi_main(opt)
    else:
        single_main(opt)


if __name__ == "__main__":
    PARSER = argparse.ArgumentParser(
        description='train.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(PARSER)
    opts.model_opts(PARSER)
    opts.train_opts(PARSER)

    OPT = PARSER.parse_args()
    main(OPT)
Example #11
def main():

    parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
    group = parser.add_argument_group('Embedding')
    group.add_argument('-emb_file_enc',
                       default='/data/home/shuaipengju/data/glove.840B.300d.txt',  # required=True
                       help="source Embeddings from this file")
    group.add_argument('-emb_file_dec',
                       default='/data/home/shuaipengju/data/glove.840B.300d.txt',  # required=True
                       help="target Embeddings from this file")
    group.add_argument('-output_file',
                       default='/data/home/shuaipengju/distractor_code/data/processed.glove',  # required=True
                       help="Output file for the prepared data")
    group.add_argument('-dict_file',
                       default='/data/home/shuaipengju/distractor_code/data/processed.vocab.pt',  # required=True
                       help="Dictionary file")
    group.add_argument('-verbose', action="store_true", default=False)
    group.add_argument('-skip_lines', type=int, default=0,
                       help="Skip first lines of the embedding file")
    group.add_argument('-type', choices=TYPES, default="GloVe")
    opts.add_md_help_argument(parser)

    opt = parser.parse_args()

    enc_vocab, dec_vocab = get_vocabs(opt.dict_file)
    if opt.type == "word2vec":
        opt.skip_lines = 1

    embeddings_enc = get_embeddings(opt.emb_file_enc, opt, flag='enc')
    embeddings_dec = get_embeddings(opt.emb_file_dec, opt, flag='dec')

    filtered_enc_embeddings, enc_count = match_embeddings(enc_vocab,
                                                          embeddings_enc,
                                                          opt)
    filtered_dec_embeddings, dec_count = match_embeddings(dec_vocab,
                                                          embeddings_dec,
                                                          opt)
    logger.info("\nMatching: ")
    match_percent = [_['match'] / (_['match'] + _['miss']) * 100
                     for _ in [enc_count, dec_count]]
    logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                % (enc_count['match'],
                   enc_count['miss'],
                   match_percent[0]))
    logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
                % (dec_count['match'],
                   dec_count['miss'],
                   match_percent[1]))

    logger.info("\nFiltered embeddings:")
    logger.info("\t* enc: %s" % str(filtered_enc_embeddings.size()))
    logger.info("\t* dec: %s" % str(filtered_dec_embeddings.size()))

    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                % (enc_output_file, dec_output_file))
    torch.save(filtered_enc_embeddings, enc_output_file)
    torch.save(filtered_dec_embeddings, dec_output_file)
    logger.info("\nDone.")
Example #12
    def __init__(self, model_filename, cmdline_args):
        parser = argparse.ArgumentParser(
            description='translate.py',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        opts.add_md_help_argument(parser)
        opts.translate_opts(parser)
        opt = parser.parse_args(['-model', model_filename, '-src', ''] +
                                (cmdline_args or []))

        translator = make_translator(opt)
        model = translator.model
        fields = translator.fields
        tgt_vocab = fields["tgt"].vocab

        def encode_from_src(src):
            enc_states, memory_bank = model.encoder(src)
            return dict(enc_states=enc_states,
                        memory_bank=memory_bank,
                        src=src)

        @lru_cache(maxsize=32)
        def encode_text(in_text):
            text_preproc = fields['src'].preprocess(in_text)
            src, src_len = fields['src'].process([text_preproc],
                                                 device=-1,
                                                 train=False)
            src = src.unsqueeze(2)  # add the feature dimension: [src_len x batch x nfeats=1]
            return encode_from_src(src)

        @lru_cache(maxsize=32)
        def encode_img(image_idx):
            if isinstance(image_idx, str):
                image_idx = int(image_idx)
            src = Variable(torch.IntTensor([image_idx]), volatile=True)
            return encode_from_src(src)

        def encode(inp):
            if model.encoder.__class__.__name__ == 'VecsEncoder':
                return encode_img(inp)
            else:
                return encode_text(inp)

        @lru_cache(maxsize=128)
        def get_decoder_state(in_text, tokens_so_far):
            encoder_out = encode(in_text)
            enc_states = encoder_out['enc_states']
            memory_bank = encoder_out['memory_bank']
            src = encoder_out['src']

            if len(tokens_so_far) == 0:
                return None, translator.model.decoder.init_decoder_state(
                    src, memory_bank, enc_states)

            prev_out, prev_state = get_decoder_state(in_text,
                                                     tokens_so_far[:-1])

            tgt_in = Variable(torch.LongTensor(
                [tgt_vocab.stoi[tokens_so_far[-1]]]),
                              volatile=True)  # [tgt_len]
            tgt_in = tgt_in.unsqueeze(1)  # [tgt_len x batch=1]
            tgt_in = tgt_in.unsqueeze(1)  # [tgt_len x batch=1 x nfeats=1]

            # Prepare to call the decoder. Unfortunately the decoder mutates the state passed in!
            memory_bank = copy.deepcopy(memory_bank)
            assert isinstance(prev_state.hidden, tuple)
            prev_state.hidden = tuple(v.detach() for v in prev_state.hidden)
            prev_state = copy.deepcopy(prev_state)

            assert memory_bank.size()[1] == 1

            dec_out, dec_states, attn = translator.model.decoder(
                tgt_in, memory_bank, prev_state)

            assert dec_out.shape[0] == 1
            return dec_out[0], dec_states

        def generate_completions(in_text, tokens_so_far):
            tokens_so_far = [onmt.io.BOS_WORD] + tokens_so_far
            tokens_so_far = tuple(tokens_so_far)  # Make it hashable
            dec_out, dec_states = get_decoder_state(in_text, tokens_so_far)
            logits = model.generator.forward(dec_out).data
            vocab = tgt_vocab.itos

            assert logits.shape[0] == 1
            logits = logits[0]
            return logits, vocab

        def eval_logprobs(in_text, tokens, *, use_eos):
            encoder_out = encode(in_text)
            enc_states = encoder_out['enc_states']
            memory_bank = encoder_out['memory_bank']
            src = encoder_out['src']

            tokens = [onmt.io.BOS_WORD] + tokens
            if use_eos:
                tokens = tokens + [onmt.io.EOS_WORD]

            decoder_state = model.decoder.init_decoder_state(
                src, memory_bank=memory_bank, encoder_final=enc_states)
            tgt = Variable(
                torch.LongTensor([tgt_vocab.stoi[tok] for tok in tokens
                                  ]).unsqueeze(1).unsqueeze(1))
            dec_out, dec_states, attn = model.decoder(tgt[:-1], memory_bank,
                                                      decoder_state)
            logits = model.generator(dec_out)
            return F.nll_loss(logits.squeeze(1),
                              tgt[1:].squeeze(1).squeeze(1),
                              reduce=False,
                              size_average=False).data.numpy()

        self.model = model
        self.fields = fields
        self.translator = translator
        self.encode = encode
        self.get_decoder_state = get_decoder_state
        self.generate_completions = generate_completions
        self.eval_logprobs = eval_logprobs
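The tuple(tokens_so_far) conversion in generate_completions above matters because functools.lru_cache hashes its arguments, and lists are unhashable. A minimal illustration:

from functools import lru_cache

@lru_cache(maxsize=128)
def cached_len(tokens):  # arguments must be hashable
    return len(tokens)

cached_len(('a', 'b'))    # fine: tuples are hashable
# cached_len(['a', 'b'])  # TypeError: unhashable type: 'list'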
Example #13
    with open(os.path.join(temp, "data", "train_source.txt"),
              "wt") as sofd, open(os.path.join(temp, "data", "train_target.txt"),
                                  "wt") as tofd:
        sofd.write("\n".join([x["source"] for x in train_data]) + "\n")
        tofd.write("\n".join([x["target"] for x in train_data]) + "\n")

    with open(os.path.join(temp, "data", "dev_source.txt"),
              "wt") as sofd, open(os.path.join(temp, "data", "dev_target.txt"),
                                  "wt") as tofd:
        sofd.write("\n".join([x["source"] for x in dev_data]) + "\n")
        tofd.write("\n".join([x["target"] for x in dev_data]) + "\n")

    preproc_parser = argparse.ArgumentParser(
        description='vivisect example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(preproc_parser)
    opts.preprocess_opts(preproc_parser)

    preproc_args = preproc_parser.parse_args(args=[
        "-train_src",
        os.path.join(temp, "data", "train_source.txt"), "-train_tgt",
        os.path.join(temp, "data", "train_target.txt"), "-valid_src",
        os.path.join(temp, "data", "dev_source.txt"), "-valid_tgt",
        os.path.join(temp, "data", "dev_target.txt"), "-save_data",
        os.path.join(temp, "data", "out")
    ])
    preproc_args.shuffle = 0
    preproc_args.src_seq_length = source_max
    preproc_args.tgt_seq_length = target_max

    train_parser = argparse.ArgumentParser(