def evaluate(data_loader, lm_model, criterion, limited=76800):
    print('evaluating')
    lm_model.eval()

    iterator = data_loader.get_tqdm()

    lm_model.init_hidden()
    total_loss = 0
    total_len = 0
    for word_t, label_t in iterator:
        label_t = label_t.view(-1)
        tmp_len = label_t.size(0)
        output = lm_model.log_prob(word_t)
        # the criterion consumes the log-probabilities directly; the autograd.Variable
        # wrapper is a no-op since PyTorch 0.4
        total_loss += tmp_len * utils.to_scalar(criterion(output, label_t))
        total_len += tmp_len

        if limited >= 0 and total_len > limited:
            break

    ppl = math.exp(total_loss / total_len)
    print('PPL: ' + str(ppl))

    return ppl
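
A minimal usage sketch (hypothetical names: `my_eval_loader` and `my_model` stand in for the EvalDataset and LM objects that main() below constructs; everything else is defined in this file). Note that PPL here is exp of the average per-token NLL:

criterion = nn.NLLLoss()  # evaluate() feeds it log-probabilities from lm_model.log_prob
criterion.cuda()
my_model.cuda()
with torch.no_grad():  # inference only, so gradients are not needed
    ppl = evaluate(my_eval_loader, my_model, criterion, limited=-1)  # -1 disables the token cap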
Example #2
File: train_lm.py  Project: zpppy/LD-Net
                if 1 == train_loader.cur_idx:
                    lm_model.init_hidden()

                label_t = label_t.view(-1)

                lm_model.zero_grad()
                loss = lm_model(word_t, label_t)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(lm_model.parameters(), args.clip)
                optimizer.step()

                batch_index += 1
                if 0 == batch_index % args.interval:
                    s_loss = utils.to_scalar(loss)
                    pw.add_loss_vs_batch({'batch_loss': s_loss}, batch_index, use_logger=False)

                epoch_loss += utils.to_scalar(loss)
                if 0 == batch_index % args.epoch_size:
                    epoch_ppl = math.exp(epoch_loss / args.epoch_size)
                    pw.add_loss_vs_batch({'train_ppl': epoch_ppl}, batch_index, use_logger=True)
                    if epoch_ppl < best_train_ppl:
                        best_train_ppl = epoch_ppl
                        patience = 0
                    else:
                        patience += 1
                    epoch_loss = 0

                if patience > args.patience and cur_lr > 0:
                    patience = 0
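
The fragment above is cut off by the source page right after `patience = 0`. A plausible continuation (an assumption for illustration, not the project's verbatim code) decays the learning rate of the running optimizer once patience is exhausted:

                if patience > args.patience and cur_lr > 0:
                    patience = 0
                    cur_lr *= 0.1  # hypothetical decay factor; the real value is not shown in this excerpt
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = cur_lr  # standard way to change the LR of a live PyTorch optimizer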
def main():
    global best_ppl
    
    print('loading dataset')
    dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
    w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']

    cut_off = args.cut_off + [len(w_map) + 1]

    train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
    test_loader = EvalDataset(test_data, args.batch_size)

    print('building model')

    rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop=args.layer_drop)}
    rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)

    if args.label_dim > 0:
        soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()
    # lm_model.cuda()
    
    # set up optimizers
    optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta,
                 'SGD': functools.partial(optim.SGD, momentum=0.9),
                 'LSRAdam': LSRAdam, 'LSAdam': LSAdam, 'AdamW': AdamW, 'RAdam': RAdam,
                 'SRAdamW': SRAdamW, 'SRRAdam': SRRAdam}
    if args.update.lower() == 'lsradam' or args.update.lower() == 'lsadam':
        optimizer = optim_map[args.update](lm_model.parameters(),
                                           lr=args.lr * ((1. + 4. * args.sigma) ** 0.25),
                                           betas=(args.beta1, args.beta2),
                                           weight_decay=args.weight_decay,
                                           sigma=args.sigma)
    elif args.update.lower() == 'radam':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    elif args.update.lower() == 'adamw':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, warmup=args.warmup)
    elif args.update.lower() == 'sradamw':
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup=args.warmup, restarting_iter=args.restart_schedule[0])
    elif args.update.lower() == 'srradam':
        # NOTE: need to double-check this
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup=args.warmup, restarting_iter=args.restart_schedule[0])
    else:
        if args.lr > 0:
            optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
        else:
            optimizer = optim_map[args.update](lm_model.parameters())
            
    # Resume
    title = 'onebillionword-' + args.rnn_layer
    logger = Logger(os.path.join(args.checkpath, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Train PPL', 'Valid PPL'])
    
    if args.load_checkpoint:
        if os.path.isfile(args.load_checkpoint):
            print("loading checkpoint: '{}'".format(args.load_checkpoint))
            checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
            lm_model.load_state_dict(checkpoint_file['lm_model'], strict=False)
            optimizer.load_state_dict(checkpoint_file['opt'])
        else:
            print("no checkpoint found at: '{}'".format(args.load_checkpoint))

    test_lm = nn.NLLLoss()
    
    test_lm.cuda()
    lm_model.cuda()
    
    batch_index = 0
    epoch_loss = 0
    full_epoch_loss = 0
    best_train_ppl = float('inf')
    cur_lr = args.lr
    
    schedule_index = 1

    try:
        for indexs in range(args.epoch):

            print('#' * 89)
            print('Start: {}'.format(indexs))

            if args.update.lower() == 'sradamw':
                if indexs in args.schedule:
                    optimizer = SRAdamW(lm_model.parameters(), lr=args.lr * (args.gamma ** schedule_index), betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup=0, restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1

            elif args.update.lower() == 'srradam':
                if indexs in args.schedule:
                    optimizer = SRRAdam(lm_model.parameters(), lr=args.lr * (args.gamma ** schedule_index), betas=(args.beta1, args.beta2), iter_count=iter_count, weight_decay=args.weight_decay, warmup=0, restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1
            
            else:
                adjust_learning_rate(optimizer, indexs)
                
            logger.file.write('\nEpoch: [%d | %d] LR: %f' % (indexs + 1, args.epoch, cur_lr))
            
            iterator = train_loader.get_tqdm()
            full_epoch_loss = 0

            lm_model.train()

            for word_t, label_t in iterator:

                if 1 == train_loader.cur_idx:
                    lm_model.init_hidden()

                label_t = label_t.view(-1)

                lm_model.zero_grad()
                loss = lm_model(word_t, label_t)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(lm_model.parameters(), args.clip)
                optimizer.step()
                
                if args.update.lower() == 'sradamw' or args.update.lower() == 'srradam':
                    iter_count, iter_total = optimizer.update_iter()

                batch_index += 1

                if 0 == batch_index % args.interval:
                    s_loss = utils.to_scalar(loss)
                    writer.add_scalars('loss_tracking/train_loss', {args.model_name: s_loss}, batch_index)
                
                epoch_loss += utils.to_scalar(loss)
                full_epoch_loss += utils.to_scalar(loss)
                if 0 == batch_index % args.check_interval:
                    epoch_ppl = math.exp(epoch_loss / args.check_interval)
                    writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index)
                    print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    logger.file.write('epoch_ppl: {} lr: {} @ batch_index: {}\n'.format(epoch_ppl, cur_lr, batch_index))
                    epoch_loss = 0
    
            test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        
            is_best = test_ppl < best_ppl
            best_ppl = min(test_ppl, best_ppl)

            writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs)
            print('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
            logger.file.write('test_ppl: {} @ index: {}\n'.format(test_ppl, indexs))
            
            save_checkpoint({
                'epoch': indexs + 1,
                'schedule_index': schedule_index,
                'lm_model': lm_model.state_dict(),
                'ppl': test_ppl,
                'best_ppl': best_ppl,
                'opt': optimizer.state_dict(),
            }, is_best, indexs, checkpoint=args.checkpath)

    except KeyboardInterrupt:

        print('Exiting from training early')
        logger.file.write('Exiting from training early\n')
        test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch)
        
        is_best = False
        save_checkpoint({
            'epoch': indexs + 1,
            'schedule_index': schedule_index,
            'lm_model': lm_model.state_dict(),
            'ppl': test_ppl,
            'best_ppl': best_ppl,
            'opt': optimizer.state_dict(),
        }, is_best, indexs, checkpoint=args.checkpath)
    
    print('Best PPL: %f' % best_ppl)

    logger.file.write('Best PPL: %f\n' % best_ppl)
    logger.close()
    
    with open("./all_results.txt", "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write("%s\n"%args.checkpath)
        f.write("best_ppl %f\n\n"%best_ppl)
        fcntl.flock(f, fcntl.LOCK_UN)
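
`save_checkpoint` is called above but not defined in this excerpt. A minimal sketch consistent with its call sites, `save_checkpoint(state_dict, is_best, indexs, checkpoint=args.checkpath)` (the file names here are assumptions):

import os
import shutil
import torch

def save_checkpoint(state, is_best, epoch_index, checkpoint='checkpoint'):
    # persist the full training state so torch.load + load_state_dict can resume
    filepath = os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(epoch_index))
    torch.save(state, filepath)
    if is_best:
        # keep a separate copy of the best-PPL model
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))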