コード例 #1
0
ファイル: train.py プロジェクト: scape1989/MAGNET
def main():
    """Build the data, model and optimizer, then start training.

    Steps, in order:
      1. Configure ``onlinePreprocess`` from ``opt`` and build the training
         set from the source / target / LDA files.
      2. If resuming (``opt.train_from`` or ``opt.train_from_state_dict``),
         load the checkpoint and reuse its dictionaries.
      3. Build encoder, topic encoder, decoder, decoder initializer and the
         output generator, and wrap them in an ``NMTModel``.
      4. Either restore weights and optimizer from the checkpoint, or
         initialise parameters from scratch and create a new optimizer.
      5. Load optional dev / test data and call ``trainModel``.

    Reads the module-level ``opt`` and ``logger``; returns nothing.
    """
    import onlinePreprocess
    # The preprocessing module is configured through module-level globals,
    # so they must be set before prepare_data_online is called.
    onlinePreprocess.seq_length = opt.max_sent_length
    onlinePreprocess.MAX_LDA_WORDS = opt.max_lda_words
    onlinePreprocess.shuffle = 1 if opt.process_shuffle else 0
    from onlinePreprocess import prepare_data_online
    dataset = prepare_data_online(opt.train_src, opt.src_vocab, opt.train_tgt,
                                  opt.tgt_vocab, opt.train_lda, opt.lda_vocab)

    # train_from takes precedence over train_from_state_dict when both are set.
    dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
    if dict_checkpoint:
        logger.info('Loading dicts from checkpoint at %s' % dict_checkpoint)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(dict_checkpoint)
        # Reuse the checkpoint's vocabularies so indices match the weights.
        dataset['dicts'] = checkpoint['dicts']

    trainData = s2s.Dataset(dataset['train']['src'],
                            dataset['train']['eq_mask'],
                            dataset['train']['lda'], dataset['train']['tgt'],
                            opt.batch_size, opt.gpus)
    dicts = dataset['dicts']
    logger.info(' * vocabulary size. source = %d; target = %d' %
                (dicts['src'].size(), dicts['tgt'].size()))
    logger.info(' * number of training sentences. %d' %
                len(dataset['train']['src']))
    logger.info(' * maximum batch size. %d' % opt.batch_size)

    logger.info('Building model...')

    encoder = s2s.Models.Encoder(opt, dicts['src'])
    topic_encoder = s2s.Models.TopicEncoder(opt, dicts['lda'])
    decoder = s2s.Models.MPGDecoder(opt, dicts['tgt'])
    decIniter = s2s.Models.DecInit(opt)

    # Projects the (maxout-pooled) decoder output onto the target vocabulary
    # and applies log-softmax over dim 1.
    generator = nn.Sequential(
        nn.Linear(opt.dec_rnn_size // opt.maxout_pool_size,
                  dicts['tgt'].size()),  # TODO: fix here
        nn.LogSoftmax(dim=1))

    model = s2s.Models.NMTModel(encoder, topic_encoder, decoder, decIniter)
    model.generator = generator
    translator = s2s.Translator(opt, model, dataset)

    if opt.train_from:
        # Checkpoint stores the whole model object: split its weights into
        # the generator part and everything else.
        logger.info('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        model_state_dict = {
            k: v
            for k, v in chk_model.state_dict().items() if 'generator' not in k
        }
        # NOTE(review): model.generator is already attached above, so a
        # strict load without the 'generator.*' keys may complain about
        # missing keys - confirm against NMTModel / the checkpoint format.
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        # Checkpoint stores separate state dicts for model and generator.
        logger.info('Loading model from checkpoint at %s' %
                    opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    # if len(opt.gpus) > 1:
    #     model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
    #     generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    if not opt.train_from_state_dict and not opt.train_from:
        # Fresh run: initialise every parameter, then overwrite embeddings
        # with pretrained vectors where the options request it.
        for pr_name, p in model.named_parameters():
            logger.info(pr_name)
            # p.data.uniform_(-opt.param_init, opt.param_init)
            if p.dim() == 1:
                # 1-D parameters: zero-mean normal, std = sqrt(6 / (1 + n)).
                # p.data.zero_()
                p.data.normal_(0, math.sqrt(6 / (1 + p.size(0))))
            else:
                # Matrices: Xavier-normal with gain sqrt(3).
                nn.init.xavier_normal_(p, math.sqrt(3))

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = s2s.Optim(opt.optim,
                          opt.learning_rate,
                          max_grad_norm=opt.max_grad_norm,
                          max_weight_value=opt.max_weight_value,
                          lr_decay=opt.learning_rate_decay,
                          start_decay_at=opt.start_decay_at,
                          decay_bad_count=opt.halve_lr_bad_count)
    else:
        logger.info('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        logger.info(optim)

    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        # Restore the saved optimizer state after set_parameters has bound
        # the optimizer to the current model's parameters.
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    validData = None
    if opt.dev_input_src and opt.dev_ref:
        validData = load_dev_data(translator, opt.dev_input_src,
                                  opt.dev_input_lda, opt.dev_ref)
    # Default to None so trainModel can be called when no test set is
    # configured; previously this raised UnboundLocalError.
    testData = None
    if opt.test_input_src and opt.test_ref:
        testData = load_dev_data(translator, opt.test_input_src,
                                 opt.test_input_lda, opt.test_ref)
    trainModel(model, translator, trainData, validData, testData, dataset,
               optim)
コード例 #2
0
def main():
    """Assemble data, model and optimizer from ``opt`` and run training.

    Configures ``onlinePreprocess``, builds the training set, constructs the
    encoder / decoder / generator, initialises all parameters, creates the
    optimizer, optionally loads dev data, and hands everything to
    ``trainModel``. Returns nothing.
    """
    import onlinePreprocess as preprocess

    # Preprocessing is driven by module-level settings; set them first.
    preprocess.lower = opt.lower_input
    preprocess.seq_length = opt.max_sent_length
    if opt.process_shuffle:
        preprocess.shuffle = 1
    else:
        preprocess.shuffle = 0

    dataset = preprocess.prepare_data_online(
        opt.train_src, opt.src_vocab,
        opt.train_bio, opt.bio_vocab,
        opt.train_feats, opt.feat_vocab,
        opt.train_tgt, opt.tgt_vocab,
    )

    train_split = dataset['train']
    train_set = s2s.Dataset(
        train_split['src'],
        train_split['bio'],
        train_split['feats'],
        train_split['tgt'],
        train_split['switch'],
        train_split['c_tgt'],
        opt.batch_size,
        opt.gpus,
    )

    vocabularies = dataset['dicts']
    vocab_sizes = (vocabularies['src'].size(), vocabularies['tgt'].size())
    logger.info(' * vocabulary size. source = %d; target = %d' % vocab_sizes)
    logger.info(' * number of training sentences. %d' %
                len(dataset['train']['src']))
    logger.info(' * maximum batch size. %d' % opt.batch_size)

    logger.info('Building model...')

    enc = s2s.Models.Encoder(opt, vocabularies['src'])
    dec = s2s.Models.Decoder(opt, vocabularies['tgt'])
    dec_init = s2s.Models.DecInit(opt)

    # Output head: linear projection onto the target vocabulary followed by
    # a softmax over dim 1.
    vocab_projection = nn.Linear(
        opt.dec_rnn_size // opt.maxout_pool_size,
        vocabularies['tgt'].size(),
    )  # TODO: fix here
    output_head = nn.Sequential(vocab_projection, nn.Softmax(dim=1))

    net = s2s.Models.NMTModel(enc, dec, dec_init)
    net.generator = output_head
    translator = s2s.Translator(opt, net, dataset)

    # Place the model and its output head on GPU when any GPU is configured.
    on_gpu = len(opt.gpus) >= 1
    for module in (net, output_head):
        if on_gpu:
            module.cuda()
        else:
            module.cpu()

    # Parameter initialisation: 1-D tensors get a zero-mean normal with
    # std sqrt(6 / (1 + n)); everything else gets Xavier-normal, gain sqrt(3).
    xavier_gain = math.sqrt(3)
    for param_name, param in net.named_parameters():
        logger.info(param_name)
        if param.dim() == 1:
            param.data.normal_(0, math.sqrt(6 / (1 + param.size(0))))
            continue
        nn.init.xavier_normal_(param, xavier_gain)

    # Pretrained embeddings (if requested via opt) are loaded after the
    # generic initialisation above.
    enc.load_pretrained_vectors(opt)
    dec.load_pretrained_vectors(opt)

    optimizer = s2s.Optim(
        opt.optim,
        opt.learning_rate,
        max_grad_norm=opt.max_grad_norm,
        max_weight_value=opt.max_weight_value,
        lr_decay=opt.learning_rate_decay,
        start_decay_at=opt.start_decay_at,
        decay_bad_count=opt.halve_lr_bad_count,
    )
    optimizer.set_parameters(net.parameters())

    # Dev data is optional; it needs both the input and the reference file.
    if opt.dev_input_src and opt.dev_ref:
        valid_set = load_dev_data(translator, opt.dev_input_src, opt.dev_bio,
                                  opt.dev_feats, opt.dev_ref)
    else:
        valid_set = None

    trainModel(net, translator, train_set, valid_set, dataset, optimizer)
コード例 #3
0
ファイル: train.py プロジェクト: xdong2ps/InjType
def main():
    """Assemble data, model heads and optimizer from ``opt`` and run training.

    Configures ``onlinePreprocess``, builds the training set (including the
    guide source), constructs the encoder / decoder and three softmax heads
    (generator, classifier, nlu_generator), initialises all parameters,
    creates the optimizer, optionally loads dev and test data, and calls
    ``trainModel``. Returns nothing.
    """
    import onlinePreprocess as preprocess

    # Preprocessing is driven by module-level settings; set them first.
    preprocess.lower = opt.lower_input
    preprocess.seq_length = opt.max_sent_length
    if opt.process_shuffle:
        preprocess.shuffle = 1
    else:
        preprocess.shuffle = 0

    # Input files and what a line of each looks like:
    #   opt.train_src   source sequences, e.g. 'it is a replica of the grotto
    #                   at lourdes , france where the virgin mary reputedly
    #                   appeared to saint bernadette soubirous in 1858 .'
    #   opt.src_vocab   source vocab: 'the(word) 4(index) 256272(frequency)
    #                   0.06749202214022335'
    #   opt.train_bio   answer position tags:
    #                   'O O O O O O O O O O O O O O O O O O B I I O O O'
    #   opt.bio_vocab   answer position vocab: 'O(bio) 4(index)
    #                   2525015(frequency) 0.8958601572376024'
    #   opt.train_feats postag/ner/case features: 'PERSON/UPCASE/NN ...'
    #                   (3 different embeddings)
    #   opt.feat_vocab  answer feature vocab
    #   opt.train_tgt   questions, e.g. 'to whom did the virgin mary allegedly
    #                   appear in 1858 in lourdes france ?'
    #   opt.tgt_vocab   target vocab; same file as opt.src_vocab !!
    dataset = preprocess.prepare_data_online(
        opt.train_src, opt.src_vocab,
        opt.train_bio, opt.bio_vocab,
        opt.train_feats, opt.feat_vocab,
        opt.train_tgt, opt.tgt_vocab,
        opt.train_guide_src, opt.guide_src_vocab,
    )

    train_split = dataset['train']
    train_set = s2s.Dataset(
        train_split['src'],
        train_split['bio'],
        train_split['feats'],
        train_split['tgt'],
        train_split['switch'],
        train_split['c_tgt'],
        opt.batch_size,
        opt.gpus,
        train_split['guide_src'],
    )

    vocabularies = dataset['dicts']
    vocab_sizes = (vocabularies['src'].size(), vocabularies['tgt'].size())
    logger.info(' * vocabulary size. source = %d; target = %d' % vocab_sizes)
    logger.info(' * number of training sentences. %d' %
                len(dataset['train']['src']))
    logger.info(' * maximum batch size. %d' % opt.batch_size)

    logger.info('Building Model ...')
    enc = s2s.Models.Encoder(opt, vocabularies['src'],
                             vocabularies['guide_src'])
    dec = s2s.Models.Decoder(opt, vocabularies['tgt'])
    dec_init = s2s.Models.DecInit(opt)

    # Generator: map the output embedding to a vocab-size vector, then softmax.
    vocab_projection = nn.Linear(opt.dec_rnn_size // opt.maxout_pool_size,
                                 vocabularies['tgt'].size())
    output_head = nn.Sequential(vocab_projection, nn.Softmax(dim=1))

    # Classifier head over the guide-source vocabulary; input width is the
    # decoder size plus a fixed 300.
    classifier_projection = nn.Linear(opt.dec_rnn_size + 300,
                                      vocabularies['guide_src'].size())
    classifier_head = nn.Sequential(classifier_projection, nn.Softmax(dim=1))

    # NLU head over the guide-source vocabulary; input width is twice the
    # decoder size.
    nlu_projection = nn.Linear(opt.dec_rnn_size * 2,
                               vocabularies['guide_src'].size())
    nlu_head = nn.Sequential(nlu_projection, nn.Softmax(dim=1))

    net = s2s.Models.NMTModel(enc, dec, dec_init)
    net.generator = output_head
    net.classifier = classifier_head
    net.nlu_generator = nlu_head
    translator = s2s.Translator(opt, net, dataset)

    # Place the model and each head on GPU when any GPU is configured.
    on_gpu = len(opt.gpus) >= 1
    for module in (net, output_head, classifier_head, nlu_head):
        if on_gpu:
            module.cuda()
        else:
            module.cpu()

    # Parameter initialisation: 1-D tensors get a zero-mean normal with
    # std sqrt(6 / (1 + n)); everything else gets Xavier-normal, gain sqrt(3).
    xavier_gain = math.sqrt(3)
    for param_name, param in net.named_parameters():
        logger.info(param_name)
        if param.dim() == 1:
            param.data.normal_(0, math.sqrt(6 / (1 + param.size(0))))
            continue
        nn.init.xavier_normal_(param, xavier_gain)

    # Pretrained embeddings (if requested via opt) are loaded after the
    # generic initialisation above.
    enc.load_pretrained_vectors(opt)
    dec.load_pretrained_vectors(opt)

    optimizer = s2s.Optim(
        opt.optim,
        opt.learning_rate,
        max_grad_norm=opt.max_grad_norm,
        max_weight_value=opt.max_weight_value,
        lr_decay=opt.learning_rate_decay,
        start_decay_at=opt.start_decay_at,
        decay_bad_count=opt.halve_lr_bad_count,
    )
    optimizer.set_parameters(net.parameters())

    # Dev and test sets are optional; each needs both input and reference.
    if opt.dev_input_src and opt.dev_ref:
        valid_set = load_dev_data(translator, opt.dev_input_src, opt.dev_bio,
                                  opt.dev_feats, opt.dev_ref,
                                  opt.dev_guide_src)
    else:
        valid_set = None

    if opt.test_input_src and opt.test_ref:
        test_set = load_dev_data(translator, opt.test_input_src, opt.test_bio,
                                 opt.test_feats, opt.test_ref,
                                 opt.test_guide_src)
    else:
        test_set = None

    trainModel(net, translator, train_set, valid_set, test_set, dataset,
               optimizer)