Example #1
def train_val_model(pipeline_cfg, model_cfg, train_cfg):
    data_pipeline = DataPipeline(**pipeline_cfg)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model_cfg['cxt_emb_pretrained'] is not None:
        model_cfg['cxt_emb_pretrained'] = torch.load(
            model_cfg['cxt_emb_pretrained'])
    # Place the model on the selected device so it matches the validation mask below.
    bidaf = BiDAF(word_emb=data_pipeline.word_type.vocab.vectors,
                  **model_cfg).to(device)
    ema = EMA(train_cfg['exp_decay_rate'])
    for name, param in bidaf.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, bidaf.parameters())
    optimizer = optim.Adadelta(parameters, lr=train_cfg['lr'])
    criterion = nn.CrossEntropyLoss()

    result = {'best_f1': 0.0, 'best_model': None}

    num_epochs = train_cfg['num_epochs']
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        for phase in ['train', 'val']:
            val_answers = dict()
            val_f1 = 0
            val_em = 0
            val_cnt = 0
            val_r = 0

            if phase == 'train':
                bidaf.train()
            else:
                bidaf.eval()
                # Back up the raw weights and swap in the EMA weights for validation.
                backup_params = EMA(0)
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        backup_params.register(name, param.data)
                        param.data.copy_(ema.get(name))

            with torch.set_grad_enabled(phase == 'train'):
                for batch_num, batch in enumerate(
                        data_pipeline.data_iterators[phase]):
                    optimizer.zero_grad()
                    p1, p2 = bidaf(batch)
                    loss = criterion(p1, batch.s_idx) + criterion(
                        p2, batch.e_idx)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        for name, param in bidaf.named_parameters():
                            if param.requires_grad:
                                ema.update(name, param.data)
                        if batch_num % train_cfg['batch_per_disp'] == 0:
                            batch_loss = loss.item()
                            print('batch %d: loss %.3f' %
                                  (batch_num, batch_loss))

                    if phase == 'val':
                        batch_size, c_len = p1.size()
                        val_cnt += batch_size
                        ls = nn.LogSoftmax(dim=1)
                        # Joint span decoding: score[b, i, j] = log p_start(i) + log p_end(j),
                        # with -inf where i > j so the end index can never precede the start.
                        mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1). \
                            unsqueeze(0).expand(batch_size, -1, -1)
                        score = (ls(p1).unsqueeze(2) +
                                 ls(p2).unsqueeze(1)) + mask
                        # For each end position take the best start, then the best end overall,
                        # and recover the start index that pairs with it.
                        score, s_idx = score.max(dim=1)
                        score, e_idx = score.max(dim=1)
                        s_idx = torch.gather(s_idx, 1,
                                             e_idx.view(-1, 1)).squeeze()

                        for i in range(batch_size):
                            answer = (s_idx[i], e_idx[i])
                            gt = (batch.s_idx[i], batch.e_idx[i])
                            val_f1 += f1_score(answer, gt)
                            val_em += exact_match_score(answer, gt)
                            val_r += r_score(answer, gt)

            if phase == 'val':
                val_f1 = val_f1 * 100 / val_cnt
                val_em = val_em * 100 / val_cnt
                val_r = val_r * 100 / val_cnt
                print('Epoch %d: %s f1 %.3f | %s em %.3f |  %s rouge %.3f' %
                      (epoch, phase, val_f1, phase, val_em, phase, val_r))
                if val_f1 > result['best_f1']:
                    result['best_f1'] = val_f1
                    result['best_em'] = val_em
                    result['best_model'] = copy.deepcopy(bidaf.state_dict())
                    torch.save(result, train_cfg['ckpoint_file'])
                    # with open(train_cfg['val_answers'], 'w', encoding='utf-8') as f:
                    #     print(json.dumps(val_answers), file=f)
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        param.data.copy_(backup_params.get(name))
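
Every example in this section builds an EMA helper with register / update / get methods that is never defined in the snippets themselves. A minimal sketch of such a class, written here as an assumption rather than copied from any of the original repositories, could look like this:

class EMA:
    """Keeps exponential moving averages of registered parameter tensors."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, value):
        # Store an initial copy of the parameter.
        self.shadow[name] = value.clone()

    def update(self, name, value):
        # shadow = decay * shadow + (1 - decay) * current value
        self.shadow[name] = (self.decay * self.shadow[name] +
                             (1.0 - self.decay) * value).clone()

    def get(self, name):
        return self.shadow[name]

With a decay of 0, as used for the temporary parameter backups above, this simply stores and returns the registered tensors, since update() is never called on those instances.

The validation loop also relies on f1_score, exact_match_score and r_score helpers that are not shown and that receive (start, end) index pairs rather than answer strings. Purely as an illustration, a token-index overlap F1 under that assumption could be computed like this (a sketch, not the original helper):

def f1_score(pred_span, gold_span):
    """Overlap F1 between predicted and gold (start, end) token-index spans."""
    pred = set(range(int(pred_span[0]), int(pred_span[1]) + 1))
    gold = set(range(int(gold_span[0]), int(gold_span[1]) + 1))
    overlap = len(pred & gold)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

exact_match_score would then reduce to comparing the two index pairs directly.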

Example #2
# Resume training from a checkpoint if one was supplied.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
if os.path.isfile(args.resume):
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(args.resume))

ema = EMA(0.999)
for name, param in model.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)

print('parameters-----')
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data.size())

if args.test == 1:
    print('Test mode')
    test(model, test_data)
else:
    print('Train mode')
    train(model, train_data, optimizer, ema, start_epoch=args.start_epoch)
print('finish')
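
The resume block above expects the checkpoint file to hold 'epoch', 'state_dict' and 'optimizer' entries. A matching save call inside the (not shown) training loop might look like the following sketch, where epoch is the loop's epoch counter and the file name is a placeholder:

torch.save({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pth.tar')  # placeholder path; pass the same value as args.resume when resuming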
Example #3
class SOLVER():
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda:{}".format(self.args.GPU) if torch.
                                   cuda.is_available() else "cpu")
        self.data = READ(self.args)
        glove = self.data.WORD.vocab.vectors
        char_size = len(self.data.CHAR.vocab)

        self.model = BiDAF(self.args, char_size, glove).to(self.device)
        self.optimizer = optim.Adadelta(self.model.parameters(),
                                        lr=self.args.Learning_Rate)
        self.ema = EMA(self.args.Exp_Decay_Rate)

        if APEX_AVAILABLE:  # Mixed Precision
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O2')

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.ema.register(name, param.data)

        self.parameters = filter(lambda p: p.requires_grad,
                                 self.model.parameters())

    def train(self):

        criterion = nn.NLLLoss()
        criterion = criterion.to(self.device)

        self.model.train()

        max_dev_em, max_dev_f1 = -1, -1
        num_batches = len(self.data.train_iter)

        logging.info("Begin Training")

        self.model.zero_grad()

        loss = 0.0

        for epoch in range(self.args.Epoch):

            self.model.train()

            for i, batch in enumerate(self.data.train_iter, start=1):

                p1, p2 = self.model(batch)
                batch_loss = criterion(
                    p1, batch.start_idx.to(self.device)) + criterion(
                        p2, batch.end_idx.to(self.device))

                if APEX_AVAILABLE:
                    with amp.scale_loss(batch_loss,
                                        self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    batch_loss.backward()
                loss = batch_loss.item()

                self.optimizer.step()
                del p1, p2, batch_loss

                for name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self.ema.update(name, param.data)

                self.model.zero_grad()

                logging.info("Epoch [{}/{}] Step [{}/{}] Train Loss {}".format(epoch+1, self.args.Epoch, \
                                                                               i, int(num_batches) +1, round(loss,3)))

                if epoch > 7:
                    if i % 100 == 0:
                        dev_em, dev_f1 = self.evaluate()
                        logging.info("Epoch [{}/{}] Dev EM {} Dev F1 {}".format(epoch + 1, self.args.Epoch, \
                                                                                        round(dev_em,3), round(dev_f1,3)))
                        self.model.train()

                        if dev_f1 > max_dev_f1:
                            max_dev_f1 = dev_f1
                            max_dev_em = dev_em

            dev_em, dev_f1 = self.evaluate()
            logging.info("Epoch [{}/{}] Dev EM {} Dev F1 {}".format(epoch + 1, self.args.Epoch, \
                                                                               round(dev_em,3), round(dev_f1,3)))
            self.model.train()

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_em = dev_em

        logging.info('Max Dev EM: {} Max Dev F1: {}'.format(
            round(max_dev_em, 3), round(max_dev_f1, 3)))

    def evaluate(self):

        logging.info("Evaluating on Dev Dataset")
        answers = dict()

        self.model.eval()

        temp_ema = EMA(0)

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                temp_ema.register(name, param.data)
                param.data.copy_(self.ema.get(name))

        with torch.no_grad():
            for _, batch in enumerate(self.data.dev_iter):

                p1, p2 = self.model(batch)
                batch_size, _ = p1.size()

                # Independent argmax over start and end logits; unlike Example #1,
                # nothing forbids e_idx from landing before s_idx here.
                _, s_idx = p1.max(dim=1)
                _, e_idx = p2.max(dim=1)

                for i in range(batch_size):
                    qid = batch.qid[i]
                    answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                    answer = ' '.join(
                        [self.data.WORD.vocab.itos[idx] for idx in answer])
                    answers[qid] = answer

            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    param.data.copy_(temp_ema.get(name))

        # Calls the module-level evaluate() scoring helper, not this method.
        results = evaluate(self.args, answers)

        return results['exact_match'], results['f1']
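
For completeness, the class above would typically be driven by a few lines like the sketch below; get_args() stands in for whatever argparse wrapper supplies the GPU, Learning_Rate, Exp_Decay_Rate and Epoch attributes the constructor reads:

args = get_args()      # hypothetical argument parser
solver = SOLVER(args)
solver.train()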
Example #4
def train_bidaf(args, data):
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(logdir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    best_model = copy.deepcopy(model)  # fallback so the return value is defined

    iterator = data.train_iter
    for i, batch in tqdm(enumerate(iterator)):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)

            loss = 0
            model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
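
A minimal driver for train_bidaf, assuming args and data were prepared the way the function expects, might be:

best_model = train_bidaf(args, data)
torch.save(best_model.state_dict(), 'bidaf_best.pt')  # placeholder output path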