def main(args, tokenize):
    """Train or evaluate the bidirectional-MRC BERT model.

    Args:
        args: parsed command-line namespace. Must provide: log_path, data_path,
            data_name, model_name, mode ('train' or 'test'), checkpoint_path,
            save_model_path, batch_size, epoch_num, learning_rate,
            tuning_bert_rate, warm_up, beta, inference_beta, reload, ifgpu.
        tokenize: tokenizer handle, passed through to ``test`` for decoding
            predicted spans back to text.

    Side effects: mutates ``args.log_path`` (and, in train mode,
    ``args.save_model_path``) by appending dataset/model-specific suffixes,
    writes a log file, and saves checkpoints to ``args.save_model_path``.
    """
    # NOTE(review): plain string concatenation (not os.path.join) is kept on
    # purpose — data_path/log_path appear to be filename prefixes, not
    # directories, so joining would change the resulting paths.
    args.log_path = args.log_path + args.data_name + '_' + args.model_name + '.log'
    data_path = args.data_path + args.data_name + '.pt'
    standard_data_path = args.data_path + args.data_name + '_standard.pt'

    # init logger
    logger = utils.get_logger(args.log_path)

    # load data: total_data holds the tokenized/query-formatted splits,
    # standard_data holds the gold triplets used by the evaluator.
    logger.info('loading data......')
    total_data = torch.load(data_path)
    standard_data = torch.load(standard_data_path)
    train_data = total_data['train']
    dev_data = total_data['dev']
    test_data = total_data['test']
    dev_standard = standard_data['dev']
    test_standard = standard_data['test']

    # init model
    logger.info('initial model......')
    model = Model.BERTModel(args)
    if args.ifgpu:
        model = model.cuda()

    # print args
    logger.info(args)

    if args.mode == 'test':
        logger.info('start testing......')
        test_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'test')
        # load checkpoint
        logger.info('loading checkpoint......')
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        batch_generator_test = Data.generate_fi_batches(
            dataset=test_dataset, batch_size=1, shuffle=False, ifgpu=args.ifgpu)
        # eval
        logger.info('evaluating......')
        # BUGFIX: pass the inference threshold (inference_beta), matching the
        # train-mode dev/test evaluation below. The original passed args.beta,
        # which is the sentiment-loss weight used only during training.
        f1 = test(model, tokenize, batch_generator_test, test_standard,
                  args.inference_beta, logger)
    elif args.mode == 'train':
        args.save_model_path = (args.save_model_path + args.data_name + '_'
                                + args.model_name + '.pth')
        train_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'train')
        dev_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'dev')
        test_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'test')
        batch_num_train = train_dataset.get_batch_num(args.batch_size)

        # optimizer: BERT parameters (named "*_bert*") are fine-tuned at
        # tuning_bert_rate; all other parameters use learning_rate.
        logger.info('initial optimizer......')
        param_optimizer = list(model.named_parameters())
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if "_bert" in n],
            'weight_decay': 0.01
        }, {
            'params': [p for n, p in param_optimizer if "_bert" not in n],
            'lr': args.learning_rate,
            'weight_decay': 0.01
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.tuning_bert_rate, correct_bias=False)

        # load saved model, optimizer and epoch num (resume training)
        if args.reload and os.path.exists(args.checkpoint_path):
            checkpoint = torch.load(args.checkpoint_path)
            model.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            logger.info(
                'Reload model and optimizer after training epoch {}'.format(
                    checkpoint['epoch']))
        else:
            start_epoch = 1
            logger.info('New model and optimizer from epoch 0')

        # scheduler: linear warmup then linear decay over all training steps.
        # NOTE(review): a resumed run rebuilds the scheduler from step 0 —
        # the scheduler state is not checkpointed; confirm this is intended.
        training_steps = args.epoch_num * batch_num_train
        warmup_steps = int(training_steps * args.warm_up)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps,
            num_training_steps=training_steps)

        # training
        logger.info('begin training......')
        best_dev_f1 = 0.
        for epoch in range(start_epoch, args.epoch_num + 1):
            model.train()
            model.zero_grad()
            batch_generator = Data.generate_fi_batches(
                dataset=train_dataset, batch_size=args.batch_size,
                ifgpu=args.ifgpu)
            for batch_index, batch_dict in enumerate(batch_generator):
                optimizer.zero_grad()

                # q1_a: forward round 1 — extract aspect spans from sentence
                f_aspect_start_scores, f_aspect_end_scores = model(
                    batch_dict['forward_asp_query'],
                    batch_dict['forward_asp_query_mask'],
                    batch_dict['forward_asp_query_seg'], 0)
                f_asp_loss = utils.calculate_entity_loss(
                    f_aspect_start_scores, f_aspect_end_scores,
                    batch_dict['forward_asp_answer_start'],
                    batch_dict['forward_asp_answer_end'])

                # q1_b: backward round 1 — extract opinion spans from sentence
                b_opi_start_scores, b_opi_end_scores = model(
                    batch_dict['backward_opi_query'],
                    batch_dict['backward_opi_query_mask'],
                    batch_dict['backward_opi_query_seg'], 0)
                b_opi_loss = utils.calculate_entity_loss(
                    b_opi_start_scores, b_opi_end_scores,
                    batch_dict['backward_opi_answer_start'],
                    batch_dict['backward_opi_answer_end'])

                # q2_a: forward round 2 — opinions conditioned on each aspect.
                # Queries are stacked per aspect, so flatten to 2-D before the
                # BERT forward pass.
                f_opi_start_scores, f_opi_end_scores = model(
                    batch_dict['forward_opi_query'].view(
                        -1, batch_dict['forward_opi_query'].size(-1)),
                    batch_dict['forward_opi_query_mask'].view(
                        -1, batch_dict['forward_opi_query_mask'].size(-1)),
                    batch_dict['forward_opi_query_seg'].view(
                        -1, batch_dict['forward_opi_query_seg'].size(-1)), 0)
                f_opi_loss = utils.calculate_entity_loss(
                    f_opi_start_scores, f_opi_end_scores,
                    batch_dict['forward_opi_answer_start'].view(
                        -1, batch_dict['forward_opi_answer_start'].size(-1)),
                    batch_dict['forward_opi_answer_end'].view(
                        -1, batch_dict['forward_opi_answer_end'].size(-1)))

                # q2_b: backward round 2 — aspects conditioned on each opinion
                b_asp_start_scores, b_asp_end_scores = model(
                    batch_dict['backward_asp_query'].view(
                        -1, batch_dict['backward_asp_query'].size(-1)),
                    batch_dict['backward_asp_query_mask'].view(
                        -1, batch_dict['backward_asp_query_mask'].size(-1)),
                    batch_dict['backward_asp_query_seg'].view(
                        -1, batch_dict['backward_asp_query_seg'].size(-1)), 0)
                b_asp_loss = utils.calculate_entity_loss(
                    b_asp_start_scores, b_asp_end_scores,
                    batch_dict['backward_asp_answer_start'].view(
                        -1, batch_dict['backward_asp_answer_start'].size(-1)),
                    batch_dict['backward_asp_answer_end'].view(
                        -1, batch_dict['backward_asp_answer_end'].size(-1)))

                # q_3: sentiment classification per (aspect, opinion) pair
                # (mode flag 1 selects the classification head)
                sentiment_scores = model(
                    batch_dict['sentiment_query'].view(
                        -1, batch_dict['sentiment_query'].size(-1)),
                    batch_dict['sentiment_query_mask'].view(
                        -1, batch_dict['sentiment_query_mask'].size(-1)),
                    batch_dict['sentiment_query_seg'].view(
                        -1, batch_dict['sentiment_query_seg'].size(-1)), 1)
                sentiment_loss = utils.calculate_sentiment_loss(
                    sentiment_scores, batch_dict['sentiment_answer'].view(-1))

                # joint loss: span losses plus beta-weighted sentiment loss
                loss_sum = (f_asp_loss + f_opi_loss + b_opi_loss + b_asp_loss
                            + args.beta * sentiment_loss)
                loss_sum.backward()
                optimizer.step()
                scheduler.step()

                # train logger
                if batch_index % 10 == 0:
                    logger.info(
                        'Epoch:[{}/{}]\t Batch:[{}/{}]\t Loss Sum:{}\t '
                        'forward Loss:{};{}\t backward Loss:{};{}\t Sentiment Loss:{}'
                        .format(epoch, args.epoch_num, batch_index,
                                batch_num_train, round(loss_sum.item(), 4),
                                round(f_asp_loss.item(), 4),
                                round(f_opi_loss.item(), 4),
                                round(b_asp_loss.item(), 4),
                                round(b_opi_loss.item(), 4),
                                round(sentiment_loss.item(), 4)))

            # validation on the dev split after every epoch
            batch_generator_dev = Data.generate_fi_batches(
                dataset=dev_dataset, batch_size=1, shuffle=False,
                ifgpu=args.ifgpu)
            f1 = test(model, tokenize, batch_generator_dev, dev_standard,
                      args.inference_beta, logger)
            # save model and optimizer when dev F1 improves
            if f1 > best_dev_f1:
                best_dev_f1 = f1
                logger.info('Model saved after epoch {}'.format(epoch))
                state = {
                    'net': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch
                }
                torch.save(state, args.save_model_path)

            # test-set evaluation every epoch (monitoring only; does not
            # influence checkpoint selection)
            batch_generator_test = Data.generate_fi_batches(
                dataset=test_dataset, batch_size=1, shuffle=False,
                ifgpu=args.ifgpu)
            f1 = test(model, tokenize, batch_generator_test, test_standard,
                      args.inference_beta, logger)
    else:
        logger.info('Error mode!')
        exit(1)