import os
import sys

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

import Data   # local module: datasets and batch generation
import Model  # local module: BERTModel
import utils  # local module: logger and loss helpers


def main(args, tokenize):
    args.log_path = args.log_path + args.data_name + '_' + args.model_name + '.log'
    data_path = args.data_path + args.data_name + '.pt'
    standard_data_path = args.data_path + args.data_name + '_standard.pt'

    # init logger
    logger = utils.get_logger(args.log_path)

    # load data
    logger.info('loading data......')
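    # total_data is a dict keyed by split ('train'/'dev'/'test');
    # standard_data carries the reference ("standard") answers per split,
    # which test() uses for scoring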
    total_data = torch.load(data_path)
    standard_data = torch.load(standard_data_path)
    train_data = total_data['train']
    dev_data = total_data['dev']
    test_data = total_data['test']
    dev_standard = standard_data['dev']
    test_standard = standard_data['test']

    # init model
    logger.info('initializing model......')
    model = Model.BERTModel(args)
    if args.ifgpu:
        model = model.cuda()

    # print args
    logger.info(args)

    if args.mode == 'test':
        logger.info('start testing......')
        test_dataset = Data.ReviewDataset(train_data, dev_data, test_data,
                                          'test')
        # load checkpoint
        logger.info('loading checkpoint......')
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()

        batch_generator_test = Data.generate_fi_batches(dataset=test_dataset,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        ifgpu=args.ifgpu)
        # eval
        logger.info('evaluating......')
        f1 = test(model, tokenize, batch_generator_test, test_standard,
                  args.inference_beta, logger)

    elif args.mode == 'train':
        args.save_model_path = args.save_model_path + args.data_name + '_' + args.model_name + '.pth'
        train_dataset = Data.ReviewDataset(train_data, dev_data, test_data,
                                           'train')
        dev_dataset = Data.ReviewDataset(train_data, dev_data, test_data,
                                         'dev')
        test_dataset = Data.ReviewDataset(train_data, dev_data, test_data,
                                          'test')
        batch_num_train = train_dataset.get_batch_num(args.batch_size)

        # optimizer: BERT parameters are tuned at args.tuning_bert_rate (the
        # optimizer's default lr below); task-specific parameters get their
        # own learning rate, args.learning_rate
        logger.info('initializing optimizer......')
        param_optimizer = list(model.named_parameters())
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if '_bert' in n],
                'weight_decay': 0.01,
            },
            {
                'params': [p for n, p in param_optimizer if '_bert' not in n],
                'lr': args.learning_rate,
                'weight_decay': 0.01,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.tuning_bert_rate,
                          correct_bias=False)
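        # Note: transformers' AdamW with correct_bias=False matches the
        # optimizer of the original BERT release; recent transformers
        # versions deprecate this class in favor of torch.optim.AdamW (which
        # has no correct_bias switch - bias correction is always applied)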

        # load saved model, optimizer and epoch num
        if args.reload and os.path.exists(args.checkpoint_path):
            checkpoint = torch.load(args.checkpoint_path)
            model.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            logger.info(
                'Reload model and optimizer after training epoch {}'.format(
                    checkpoint['epoch']))
        else:
            start_epoch = 1
            logger.info('New model and optimizer, training from epoch 1')

        # scheduler
        training_steps = args.epoch_num * batch_num_train
        warmup_steps = int(training_steps * args.warm_up)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=training_steps)

        # training
        logger.info('begin training......')
        best_dev_f1 = 0.
        for epoch in range(start_epoch, args.epoch_num + 1):
            model.train()
            model.zero_grad()

            batch_generator = Data.generate_fi_batches(
                dataset=train_dataset,
                batch_size=args.batch_size,
                ifgpu=args.ifgpu)

            for batch_index, batch_dict in enumerate(batch_generator):

                optimizer.zero_grad()
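
                # Five queries per batch, in a bidirectional MRC scheme:
                # the forward chain asks for aspects (q1_a), then the opinions
                # of each aspect (q2_a); the backward chain asks for opinions
                # (q1_b), then the aspects of each opinion (q2_b); q_3 then
                # classifies the sentiment of each aspect-opinion pair.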

                # q1_a: forward step 1 - aspect span extraction
                # (the final argument 0 selects the model's span head)
                f_aspect_start_scores, f_aspect_end_scores = model(
                    batch_dict['forward_asp_query'],
                    batch_dict['forward_asp_query_mask'],
                    batch_dict['forward_asp_query_seg'], 0)
                f_asp_loss = utils.calculate_entity_loss(
                    f_aspect_start_scores, f_aspect_end_scores,
                    batch_dict['forward_asp_answer_start'],
                    batch_dict['forward_asp_answer_end'])
                # q1_b: backward step 1 - opinion span extraction
                b_opi_start_scores, b_opi_end_scores = model(
                    batch_dict['backward_opi_query'],
                    batch_dict['backward_opi_query_mask'],
                    batch_dict['backward_opi_query_seg'], 0)
                b_opi_loss = utils.calculate_entity_loss(
                    b_opi_start_scores, b_opi_end_scores,
                    batch_dict['backward_opi_answer_start'],
                    batch_dict['backward_opi_answer_end'])
                # q2_a: forward step 2 - opinion spans conditioned on each
                # aspect; per-aspect sub-queries are flattened into the batch
                # dimension with view()
                f_opi_start_scores, f_opi_end_scores = model(
                    batch_dict['forward_opi_query'].view(
                        -1, batch_dict['forward_opi_query'].size(-1)),
                    batch_dict['forward_opi_query_mask'].view(
                        -1, batch_dict['forward_opi_query_mask'].size(-1)),
                    batch_dict['forward_opi_query_seg'].view(
                        -1, batch_dict['forward_opi_query_seg'].size(-1)), 0)
                f_opi_loss = utils.calculate_entity_loss(
                    f_opi_start_scores, f_opi_end_scores,
                    batch_dict['forward_opi_answer_start'].view(
                        -1, batch_dict['forward_opi_answer_start'].size(-1)),
                    batch_dict['forward_opi_answer_end'].view(
                        -1, batch_dict['forward_opi_answer_end'].size(-1)))
                # q2_b: backward step 2 - aspect spans conditioned on each
                # opinion
                b_asp_start_scores, b_asp_end_scores = model(
                    batch_dict['backward_asp_query'].view(
                        -1, batch_dict['backward_asp_query'].size(-1)),
                    batch_dict['backward_asp_query_mask'].view(
                        -1, batch_dict['backward_asp_query_mask'].size(-1)),
                    batch_dict['backward_asp_query_seg'].view(
                        -1, batch_dict['backward_asp_query_seg'].size(-1)), 0)
                b_asp_loss = utils.calculate_entity_loss(
                    b_asp_start_scores, b_asp_end_scores,
                    batch_dict['backward_asp_answer_start'].view(
                        -1, batch_dict['backward_asp_answer_start'].size(-1)),
                    batch_dict['backward_asp_answer_end'].view(
                        -1, batch_dict['backward_asp_answer_end'].size(-1)))
                # q_3: sentiment classification of each aspect-opinion pair
                # (the final argument 1 selects the classification head)
                sentiment_scores = model(
                    batch_dict['sentiment_query'].view(
                        -1, batch_dict['sentiment_query'].size(-1)),
                    batch_dict['sentiment_query_mask'].view(
                        -1, batch_dict['sentiment_query_mask'].size(-1)),
                    batch_dict['sentiment_query_seg'].view(
                        -1, batch_dict['sentiment_query_seg'].size(-1)), 1)
                sentiment_loss = utils.calculate_sentiment_loss(
                    sentiment_scores, batch_dict['sentiment_answer'].view(-1))

                # joint loss: the four span losses plus the beta-weighted
                # sentiment loss
                loss_sum = f_asp_loss + f_opi_loss + b_opi_loss + b_asp_loss + args.beta * sentiment_loss
                loss_sum.backward()
                optimizer.step()
                scheduler.step()

                # train logger
                if batch_index % 10 == 0:
                    logger.info(
                        'Epoch:[{}/{}]\t Batch:[{}/{}]\t Loss Sum:{}\t '
                        'forward Loss:{};{}\t backward Loss:{};{}\t Sentiment Loss:{}'
                        .format(epoch, args.epoch_num, batch_index,
                                batch_num_train, round(loss_sum.item(), 4),
                                round(f_asp_loss.item(), 4),
                                round(f_opi_loss.item(), 4),
                                round(b_asp_loss.item(), 4),
                                round(b_opi_loss.item(), 4),
                                round(sentiment_loss.item(), 4)))

            # validation
            batch_generator_dev = Data.generate_fi_batches(dataset=dev_dataset,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           ifgpu=args.ifgpu)
            dev_f1 = test(model, tokenize, batch_generator_dev, dev_standard,
                          args.inference_beta, logger)
            # save model and optimizer whenever the dev F1 improves
            if dev_f1 > best_dev_f1:
                best_dev_f1 = dev_f1
                logger.info('Model saved after epoch {}'.format(epoch))
                state = {
                    'net': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch
                }
                torch.save(state, args.save_model_path)

            # per-epoch test evaluation (for monitoring only; model selection
            # above uses dev F1)
            batch_generator_test = Data.generate_fi_batches(
                dataset=test_dataset,
                batch_size=1,
                shuffle=False,
                ifgpu=args.ifgpu)
            test_f1 = test(model, tokenize, batch_generator_test,
                           test_standard, args.inference_beta, logger)

    else:
        logger.error('Unknown mode: {} (expected "train" or "test")'.format(
            args.mode))
        sys.exit(1)
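

# A minimal driver sketch, not part of the original file: the argument names
# below are inferred from their uses in main(), and every default value is an
# assumption; `tokenize` is assumed to be a transformers BertTokenizer.
if __name__ == '__main__':
    import argparse

    from transformers import BertTokenizer

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='./data/')
    parser.add_argument('--data_name', type=str, required=True)
    parser.add_argument('--model_name', type=str, default='bert')
    parser.add_argument('--log_path', type=str, default='./log/')
    parser.add_argument('--save_model_path', type=str, default='./model/')
    parser.add_argument('--checkpoint_path', type=str,
                        default='./model/checkpoint.pth')
    parser.add_argument('--mode', type=str, choices=['train', 'test'],
                        default='train')
    parser.add_argument('--reload', action='store_true')
    parser.add_argument('--ifgpu', action='store_true')
    parser.add_argument('--epoch_num', type=int, default=40)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--tuning_bert_rate', type=float, default=1e-5)
    parser.add_argument('--warm_up', type=float, default=0.1)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--inference_beta', type=float, default=0.8)
    args = parser.parse_args()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    main(args, tokenizer)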