Example #1
File: train_lm.py  Project: zpppy/LD-Net
    pw.info('Loading dataset.')
    dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
    w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']
    train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
    test_loader = EvalDataset(test_data, args.batch_size)

    pw.info('Building models.')
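    # Pick the recurrent backbone; the LDNet variant additionally takes a layer_drop rate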
    rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = args.layer_drop)}
    rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)
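    # Adaptive softmax cut-offs: the user-supplied boundaries plus a final cluster covering the rest of the vocabulary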
    cut_off = args.cut_off + [len(w_map) + 1]
    if args.label_dim > 0:
        soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)
    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()

    pw.info('Building optimizer.')
    optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta}
    if args.lr > 0:
        optimizer=optim_map[args.update](lm_model.parameters(), lr=args.lr)
    else:
        optimizer=optim_map[args.update](lm_model.parameters())

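    # Optionally warm-start from a saved checkpoint; the second argument (False) makes the state-dict load non-strict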
    if args.restore_checkpoint:
        if os.path.isfile(args.restore_checkpoint):
            pw.info("loading checkpoint: '{}'".format(args.restore_checkpoint))
            model_file = wrapper.restore_checkpoint(args.restore_checkpoint)['model']
            lm_model.load_state_dict(model_file, False)
        else:
Example #2
    device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")
    if gpu_index >= 0:
        torch.cuda.set_device(gpu_index)

    pw.info('Loading data from {}.'.format(args.corpus))

    dataset = pickle.load(open(args.corpus, 'rb'))
    name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array', 'train_data', 'test_data', 'dev_data']
    flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = [dataset[tup] for tup in name_list ]

    pw.info('Building language models and sequence labeling models.')

    rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = 0)}
    flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
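    # Wrap the language models for the sequence labeler (the boolean argument presumably marks the backward model)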
    flm_model_seq = SparseSeqLM(flm_model, False, args.lm_droprate, False)
    blm_model_seq = SparseSeqLM(blm_model, True, args.lm_droprate, False)
    SL_map = {'vanilla':Vanilla_SeqLabel, 'lm-aug': SeqLabel}
    seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate, unit=args.seq_rnn_unit)

    pw.info('Loading pre-trained models from {}.'.format(args.load_seq))

    seq_file = wrapper.restore_checkpoint(args.load_seq)['model']
    seq_model.load_state_dict(seq_file)
    seq_model.to(device)
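    # CRF loss and decoder over the label map; the evaluator reports F1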
    crit = CRFLoss(y_map)
    decoder = CRFDecode(y_map)
    evaluator = eval_wc(decoder, 'f1')
Example #3
def main():
    global best_ppl
    
    print('loading dataset')
    dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
    w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']

    cut_off = args.cut_off + [len(w_map) + 1]

    train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
    test_loader = EvalDataset(test_data, args.batch_size)

    print('building model')

    rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = args.layer_drop)}
    rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)

    if args.label_dim > 0:
        soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()
    # lm_model.cuda()
    
    # set up optimizers
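    # Besides the standard torch.optim entries, the remaining optimizers appear to be project-specific variants;
    # the AdamW used here accepts a warmup argument, unlike torch.optim.AdamW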
    optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta, 'SGD': functools.partial(optim.SGD, momentum=0.9), 'LSRAdam':LSRAdam, 'LSAdam': LSAdam, 'AdamW': AdamW, 'RAdam': RAdam, 'SRAdamW': SRAdamW, 'SRRAdam': SRRAdam}
    if args.update.lower() == 'lsradam' or args.update.lower() == 'lsadam':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr*((1.+4.*args.sigma)**(0.25)),
                                           betas=(args.beta1, args.beta2),
                                           weight_decay=args.weight_decay,
                                           sigma=args.sigma)
    elif args.update.lower() == 'radam':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    elif args.update.lower() == 'adamw':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, warmup=args.warmup)
    elif args.update.lower() == 'sradamw':
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup = args.warmup, restarting_iter=args.restart_schedule[0]) 
    elif args.update.lower() == 'srradam':
        #NOTE: need to double-check this
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup = args.warmup, restarting_iter=args.restart_schedule[0]) 
    else:
        if args.lr > 0:
            optimizer=optim_map[args.update](lm_model.parameters(), lr=args.lr)
        else:
            optimizer=optim_map[args.update](lm_model.parameters())
            
    # Resume
    title = 'onebillionword-' + args.rnn_layer
    logger = Logger(os.path.join(args.checkpath, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Train PPL', 'Valid PPL'])
    
    if args.load_checkpoint:
        if os.path.isfile(args.load_checkpoint):
            print("loading checkpoint: '{}'".format(args.load_checkpoint))
            checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
            lm_model.load_state_dict(checkpoint_file['lm_model'], False)
            optimizer.load_state_dict(checkpoint_file['opt'])
        else:
            print("no checkpoint found at: '{}'".format(args.load_checkpoint))

    test_lm = nn.NLLLoss()
    
    test_lm.cuda()
    lm_model.cuda()
    
    batch_index = 0
    epoch_loss = 0
    full_epoch_loss = 0
    best_train_ppl = float('inf')
    cur_lr = args.lr
    
    schedule_index = 1

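    # Training loop: SRAdamW/SRRAdam are rebuilt with a decayed learning rate at the scheduled epochs;
    # the other optimizers are adjusted in place via adjust_learning_rate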
    try:
        for indexs in range(args.epoch):

            print('#' * 89)
            print('Start: {}'.format(indexs))

            if args.optimizer.lower() == 'sradamw':
                if indexs in args.schedule:
                    optimizer = SRAdamW(lm_model.parameters(), lr=args.lr * (args.gamma**schedule_index), betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup = 0, restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1

            elif args.optimizer.lower() == 'srradam':
                if indexs in args.schedule:
                    optimizer = SRRAdam(lm_model.parameters(), lr=args.lr * (args.gamma**schedule_index), betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup = 0, restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1
            
            else:
                adjust_learning_rate(optimizer, indexs)
                
            logger.file.write('\nEpoch: [%d | %d] LR: %f' % (indexs + 1, args.epoch, state['lr']))
            
            iterator = train_loader.get_tqdm()
            full_epoch_loss = 0

            lm_model.train()

            for word_t, label_t in iterator:

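                # Reset the recurrent hidden state when the loader's current index wraps back to 1 (a new pass over the shards)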
                if 1 == train_loader.cur_idx:
                    lm_model.init_hidden()

                label_t = label_t.view(-1)

                lm_model.zero_grad()
                loss = lm_model(word_t, label_t)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm(lm_model.parameters(), args.clip)
                optimizer.step()
                
                if args.optimizer.lower() == 'sradamw' or args.optimizer.lower() == 'srradam':
                    iter_count, iter_total = optimizer.update_iter()

                batch_index += 1 

                if 0 == batch_index % args.interval:
                    s_loss = utils.to_scalar(loss)
                    writer.add_scalars('loss_tracking/train_loss', {args.model_name:s_loss}, batch_index)
                
                epoch_loss += utils.to_scalar(loss)
                full_epoch_loss += utils.to_scalar(loss)
                if 0 == batch_index % args.check_interval:
                    epoch_ppl = math.exp(epoch_loss / args.check_interval)
                    writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index)
                    print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    logger.file.write('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    epoch_loss = 0
    
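            # End-of-epoch evaluation on the test split; track the best perplexity and checkpoint the model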
            test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        
            is_best = test_ppl < best_ppl
            best_ppl = min(test_ppl, best_ppl)

            writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs)
            print('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
            logger.file.write('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
            
            save_checkpoint({
                'epoch': indexs + 1,
                'schedule_index': schedule_index,
                'lm_model': lm_model.state_dict(),
                'ppl': test_ppl,
                'best_ppl': best_ppl,
                'opt':optimizer.state_dict(),
            }, is_best, indexs, checkpoint=args.checkpath)

    except KeyboardInterrupt:

        print('Exiting from training early')
        logger.file.write('Exiting from training early')
        test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch)
        
        is_best=False
        save_checkpoint({
                'epoch': indexs + 1,
                'schedule_index': schedule_index,
                'lm_model': lm_model.state_dict(),
                'ppl': test_ppl,
                'best_ppl': best_ppl,
                'opt':optimizer.state_dict(),
            }, is_best, indexs, checkpoint=args.checkpath)
    
    print('Best PPL:%f'%best_ppl)
    
    logger.file.write('Best PPL:%f'%best_ppl)  
    logger.close()
    
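    # Append the final result to a shared results file under an exclusive file lock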
    with open("./all_results.txt", "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write("%s\n"%args.checkpath)
        f.write("best_ppl %f\n\n"%best_ppl)
        fcntl.flock(f, fcntl.LOCK_UN)
Example #4
    device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")
    if gpu_index >= 0:
        torch.cuda.set_device(gpu_index)

    logger.info('Loading data')

    dataset = pickle.load(open(args.corpus, 'rb'))
    name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array', 'train_data', 'test_data', 'dev_data']
    flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = [dataset[tup] for tup in name_list ]

    logger.info('Loading language model')

    rnn_map = {'Basic': BasicRNN}
    flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
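    # Restore the pre-trained forward/backward language-model weights (non-strict load) and wrap them with ElmoLM for ELMo-style feature extraction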
    flm_file = wrapper.restore_checkpoint(args.forward_lm)['model']
    flm_model.load_state_dict(flm_file, False)
    blm_file = wrapper.restore_checkpoint(args.backward_lm)['model']
    blm_model.load_state_dict(blm_file, False)
    flm_model_seq = ElmoLM(flm_model, False, args.lm_droprate, True)
    blm_model_seq = ElmoLM(blm_model, True, args.lm_droprate, True)

    logger.info('Building model')

    SL_map = {'vanilla':Vanilla_SeqLabel, 'lm-aug': SeqLabel}
    seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate, unit=args.seq_rnn_unit)
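    # Randomly initialize the labeler, load the pre-trained word embeddings, and move the model to the selected device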
    seq_model.rand_init()
    seq_model.load_pretrained_word_embedding(torch.FloatTensor(emb_array))
    seq_model.to(device)