示例#1 (Example #1)
0
文件 (File): Train.py · 项目 (Project): HMJW/DepSAWR
    print("\nGPU using status: ", config.use_cuda)

    train_files = config.train_files.strip().split(' ')
    train_srcs, train_tgts = read_training_corpus(train_files[0], train_files[1], \
                                                  config.max_src_length, config.max_tgt_length)

    src_vocab, tgt_vocab = creat_vocabularies(train_srcs, train_tgts,
                                              config.src_vocab_size,
                                              config.tgt_vocab_size)
    if args.tgt_word_file is not None:
        tgt_words = read_tgt_words(args.tgt_word_file)
        tgt_vocab = TGTVocab(tgt_words)
    pickle.dump(src_vocab, open(config.save_src_vocab_path, 'wb'))
    pickle.dump(tgt_vocab, open(config.save_tgt_vocab_path, 'wb'))

    print("Sentence Number: #train = %d" % (len(train_srcs)))

    # model
    nmt_model = eval(config.model_name)(config, src_vocab.word_size, src_vocab.rel_size, \
                                        tgt_vocab.size, config.use_cuda)
    critic = NMTCritierion(label_smoothing=config.label_smoothing)

    if config.use_cuda:
        #torch.backends.cudnn.enabled = False
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    nmt = NMTHelper(nmt_model, critic, src_vocab, tgt_vocab, config)

    train(nmt, train_srcs, train_tgts, config)
示例#2 (Example #2)
0
    else:
        model_path = config.load_model_path + '.' + args.model_id
        print("####LOAD pretrain model from: " + model_path + "####")
        src_vocab = pickle.load(open(config.save_src_vocab_path, 'rb+'))
        tgt_vocab = pickle.load(open(config.save_tgt_vocab_path, 'rb+'))

        ext_src_emb = src_vocab.load_pretrained_embs(config.src_emb)
        ext_tgt_emb = tgt_vocab.load_pretrained_embs(config.tgt_emb)
        nmt_model = eval(config.model_name)(config, parser_config, src_vocab.size, tgt_vocab.size,
                                            ext_src_emb, ext_tgt_emb, config.use_cuda)

        nmt_model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))

    print("Sentence Number: #train = %d" %(len(train_srcs)))

    # model
    critic = NMTCritierion(label_smoothing=config.label_smoothing)


    if config.use_cuda:
        #torch.backends.cudnn.enabled = False
        parser_model = parser_model.cuda()
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    nmt = NMTHelper(nmt_model, parser_model, critic, src_vocab, tgt_vocab,  parser_vocab, config, parser_config)


    train(nmt, train_srcs, train_tgts, config)