Example #1
# Imports assumed by this snippet (the original file header is not shown).
# SQuADDataset, RNet and evaluation are project-local modules; SummaryWriter
# may come from tensorboardX in the original rather than torch.utils.tensorboard.
import json
import logging
import os
import pickle
from math import log2

import numpy as np
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model_params, launch_params):
    with open(launch_params['word_emb_file'], "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(launch_params['char_emb_file'], "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(launch_params['train_eval_file'], "r") as fh:
        train_eval_file = json.load(fh)
    with open(launch_params['dev_eval_file'], "r") as fh:
        dev_eval_file = json.load(fh)

    writer = SummaryWriter(os.path.join(launch_params['log'], launch_params['prefix']))
    
    lr = launch_params['learning_rate']
    base_lr = 1.0
    warm_up = launch_params['lr_warm_up_num']
    model_params['word_mat'] = word_mat
    model_params['char_mat'] = char_mat
    
    logging.info('Load dataset and create model.')
    dev_dataset = SQuADDataset(launch_params['dev_record_file'], launch_params['test_num_batches'], 
                               launch_params['batch_size'], launch_params['word2ind_file'])
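    # Either resume a previously dumped checkpoint for fine-tuning, or build
    # a fresh RNet from model_params.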
    if launch_params['fine_tuning']:
        train_dataset = SQuADDataset(launch_params['train_record_file'], launch_params['fine_tuning_steps'], 
                                    launch_params['batch_size'], launch_params['word2ind_file'])
        model_args = pickle.load(open(launch_params['args_filename'], 'rb'))
        model = RNet(**model_args)
        model.load_state_dict(torch.load(launch_params['dump_filename']))
        model.to(device)
    else:
        train_dataset = SQuADDataset(launch_params['train_record_file'], launch_params['num_steps'], 
                                    launch_params['batch_size'], launch_params['word2ind_file'])
        model = RNet(**model_params).to(device)
        launch_params['fine_tuning_steps'] = 0
    
    params = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(params, lr=base_lr, betas=(launch_params['beta1'], launch_params['beta2']), eps=1e-7, weight_decay=3e-7)
    cr = lr / log2(warm_up)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ee: cr * log2(ee + 1) if ee < warm_up else lr)
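    # Log-warmup schedule: since the optimizer's base lr is fixed at 1.0,
    # LambdaLR's multiplier *is* the effective learning rate. It grows as
    # cr * log2(step + 1) and hits lr at the end of warm-up, e.g. with
    # lr = 1e-3 and warm_up = 1000:
    #   step 0   -> cr * log2(1)    = 0
    #   step 999 -> cr * log2(1000) = lr
    # and it stays constant at lr afterwards.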
    logging.info('Start training.')
    for step in range(launch_params['num_steps']):
        try:
            passage_w, passage_c, question_w, question_c, y1, y2, ids = train_dataset[step]
            passage_w, passage_c = passage_w.to(device), passage_c.to(device)
            question_w, question_c = question_w.to(device), question_c.to(device)
            y1, y2 = y1.to(device), y2.to(device)
            loss, p1, p2 = model.train_step([passage_w, passage_c, question_w, question_c], y1, y2, optimizer, scheduler)
            if step % launch_params['train_interval'] == 0:
                logging.info('Iteration %d; Loss: %f', step + launch_params['fine_tuning_steps'], loss)
                writer.add_scalar('Loss', loss, step + launch_params['fine_tuning_steps'])
            if step % launch_params['train_sample_interval'] == 0:
                start = torch.argmax(p1[0, :]).item()
                end = torch.argmax(p2[0, start:]).item()+start
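                # Greedy span decoding for the logged sample: take the best
                # start, then the best end at or after it. (A joint argmax
                # over p1[i] * p2[j] with j >= i would be the exact decoding.)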
                passage = train_dataset.decode(passage_w)
                question = train_dataset.decode(question_w)
                generated_answer = train_dataset.decode(passage_w[:, start:end+1])
                real_answer = train_dataset.decode(passage_w[:, y1[0]:y2[0]+1])
                logging.info('Train Sample:\n Passage: %s\nQuestion: %s\nOriginal answer: %s\nGenerated answer: %s',
                        passage, question, real_answer, generated_answer)
            if step % launch_params['test_interval'] == 0:
                metrics, _ = evaluation(model, dev_dataset, dev_eval_file, launch_params['test_num_batches'])
                logging.info("TEST loss %f F1 %f EM %f", metrics['loss'], metrics['f1'], metrics['exact_match'])
                writer.add_scalar('Test_loss', metrics['loss'], step)
                writer.add_scalar('Test_f1', metrics['f1'], step)
                writer.add_scalar('Test_em', metrics['exact_match'], step)
        except RuntimeError as e:
            logging.error(str(e))
        except KeyboardInterrupt:
            break
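    # .cpu() moves the model off the GPU so the dumped state_dict loads
    # cleanly on CPU-only machines.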
    torch.save(model.cpu().state_dict(), launch_params['dump_filename'])
    pickle.dump(model_params, open(launch_params['args_filename'], 'wb'))
    logging.info('Model has been saved.')
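
# A minimal sketch of how train() might be invoked. Every key below is read by
# the function above, but the concrete paths and values are illustrative
# assumptions, not the original configuration.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    launch_params = {
        'word_emb_file': 'data/word_emb.json',
        'char_emb_file': 'data/char_emb.json',
        'train_eval_file': 'data/train_eval.json',
        'dev_eval_file': 'data/dev_eval.json',
        'train_record_file': 'data/train_records.pkl',
        'dev_record_file': 'data/dev_records.pkl',
        'word2ind_file': 'data/word2ind.json',
        'log': 'logs',
        'prefix': 'rnet',
        'learning_rate': 1e-3,
        'lr_warm_up_num': 1000,
        'beta1': 0.8,
        'beta2': 0.999,
        'batch_size': 32,
        'num_steps': 60000,
        'test_num_batches': 100,
        'train_interval': 100,
        'train_sample_interval': 500,
        'test_interval': 1000,
        'fine_tuning': False,
        'fine_tuning_steps': 0,
        'dump_filename': 'rnet.pt',
        'args_filename': 'rnet_args.pkl',
    }
    model_params = {}  # RNet's remaining constructor kwargs go here
    train(model_params, launch_params)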
Example #2
# Imports assumed by this snippet (the original file header is not shown);
# RNet and the `opt` argparse namespace come from the surrounding project.
# The train_set line is inferred from the test_set call below, and the name
# get_train_set is an assumption.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = get_train_set(opt.upscale_factor, opt.full_size)
test_set = get_test_set(opt.upscale_factor, opt.full_size)
training_data_loader = DataLoader(dataset=train_set,
                                  num_workers=opt.threads,
                                  batch_size=opt.batchSize,
                                  shuffle=True)
testing_data_loader = DataLoader(dataset=test_set,
                                 num_workers=opt.threads,
                                 batch_size=opt.testBatchSize,
                                 shuffle=True)

print('===> Building model')
model = RNet(upscale_factor=opt.upscale_factor, full_size=opt.full_size)
model.to(device)
criterion = nn.MSELoss()
#Three optimizers, one for each output
optimizerLow = optim.Adam(model.parameters(), lr=opt.lr)
optimizerInt1 = optim.Adam(model.parameters(), lr=opt.lr)
optimizerInt2 = optim.Adam(model.parameters(), lr=opt.lr)
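# Note: each optimizer is built over the same model.parameters(), so stepping
# any one of them updates the whole network; the split only gives each
# output's backward pass its own Adam moment estimates.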


def train(epoch):
    low_loss = 0
    int1_loss = 0
    int2_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        inimg, int1, int2, target = batch[0].to(device), batch[1].to(
            device), batch[2].to(device), batch[3].to(device)
        epochloss = 0

        #Run through the model, optimizes for each output, from int2 to int1 and finally the lowest resolution input
        optimizerLow.zero_grad()
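        # --- The original snippet is truncated here; what follows is a hedged
        # sketch of the rest of the step. It assumes model(inimg) returns one
        # output per optimizer and that each output is regressed onto the
        # matching batch tensor; the real output/target pairing depends on
        # RNet's actual forward signature.
        optimizerInt1.zero_grad()
        optimizerInt2.zero_grad()
        outLow, outInt1, outInt2 = model(inimg)
        lossInt2 = criterion(outInt2, int2)
        lossInt1 = criterion(outInt1, int1)
        lossLow = criterion(outLow, target)
        # Backpropagate every head before any optimizer step, so no parameter
        # is modified in place while its graph is still needed.
        lossInt2.backward(retain_graph=True)
        lossInt1.backward(retain_graph=True)
        lossLow.backward()
        # Deepest head first, as the comment above describes.
        optimizerInt2.step()
        optimizerInt1.step()
        optimizerLow.step()
        low_loss += lossLow.item()
        int1_loss += lossInt1.item()
        int2_loss += lossInt2.item()
        epochloss = lossLow.item() + lossInt1.item() + lossInt2.item()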
Example #3
'''
@LastEditTime: 2019-11-06 15:37:05
@Update:
'''
import os
import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler

from config import configer
from dataset import MtcnnData
from model import RNet
from model import MtcnnLoss, LossFn
from trainer import MtcnnTrainer

net = RNet()
# state = torch.load('ckptdir/RNet_0025.pkl', map_location='cpu')['net_state']; net.load_state_dict(state)

params = net.parameters()
trainset = MtcnnData(configer.datapath, 24, 'train', save_in_memory=False)
validset = MtcnnData(configer.datapath, 24, 'valid', save_in_memory=False)
testset = MtcnnData(configer.datapath, 24, 'test', save_in_memory=False)
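# The `24` above is presumably the crop size: RNet in the MTCNN cascade
# operates on 24x24 face windows.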
# criterion = MtcnnLoss(1.0, 0.5, 0.0)
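# The three weights presumably scale the face-classification, bounding-box
# regression and landmark terms, following the usual MTCNN loss; LossFn's
# actual argument order is defined in the project's model module.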
criterion = LossFn(1.0, 0.5, 1.0)
# Classes (not instances) are handed to the trainer, which presumably
# constructs them internally.
optimizer = optim.Adam
scheduler = lr_scheduler.ExponentialLR  # avoid rebinding the lr_scheduler module name

trainer = MtcnnTrainer(configer, net, params, trainset, validset, testset,
                       criterion, optimizer, scheduler)
trainer.train()