def main():
    # set torch and numpy seeds for reproducibility
    torch.manual_seed(27)
    np.random.seed(27)

    # tensorboard writer
    writer = SummaryWriter(settings.TENSORBOARD_DIR)

    # make the snapshot directory
    makedir(settings.CHECKPOINT_DIR)

    # enable cudnn
    torch.backends.cudnn.enabled = True

    # create the segmentor network
    model_G = Segmentor(pretrained=settings.PRETRAINED,
                        num_classes=settings.NUM_CLASSES,
                        modality=settings.MODALITY)
    model_G.train()
    model_G.cuda()
    torch.backends.cudnn.benchmark = True

    # create the discriminator network
    model_D = Discriminator(settings.NUM_CLASSES)
    model_D.train()
    model_D.cuda()

    # dataset and dataloader
    dataset = TrainDataset()
    dataloader = data.DataLoader(dataset,
                                 batch_size=settings.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=settings.NUM_WORKERS,
                                 pin_memory=True,
                                 drop_last=True)
    test_dataset = TestDataset(data_root=settings.DATA_ROOT_VAL,
                               data_list=settings.DATA_LIST_VAL)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=settings.NUM_WORKERS,
                                      pin_memory=True)

    # optimizer for the generator network (segmentor)
    optim_G = optim.SGD(model_G.optim_parameters(settings.LR),
                        lr=settings.LR,
                        momentum=settings.LR_MOMENTUM,
                        weight_decay=settings.WEIGHT_DECAY)
    # polynomial lr scheduler for optim_G
    lr_lambda_G = lambda epoch: (1 - epoch / settings.EPOCHS)**settings.LR_POLY_POWER
    lr_scheduler_G = optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=lr_lambda_G)

    # optimizer for the discriminator network
    optim_D = optim.Adam(model_D.parameters(), settings.LR_D)
    # polynomial lr scheduler for optim_D
    lr_lambda_D = lambda epoch: (1 - epoch / settings.EPOCHS)**settings.LR_POLY_POWER
    lr_scheduler_D = optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=lr_lambda_D)

    # losses
    ce_loss = CrossEntropyLoss2d(ignore_index=settings.IGNORE_LABEL)  # for the segmentor
    bce_loss = BCEWithLogitsLoss2d()  # for the discriminator

    # upsampling for the network output
    upsample = nn.Upsample(size=(settings.CROP_SIZE, settings.CROP_SIZE),
                           mode='bilinear',
                           align_corners=True)

    # # labels for adversarial training
    # pred_label = 0
    # gt_label = 1

    # load the model to resume training
    last_epoch = -1
    if settings.RESUME_TRAIN:
        checkpoint = torch.load(settings.LAST_CHECKPOINT)
        model_G.load_state_dict(checkpoint['model_G_state_dict'])
        model_G.train()
        model_G.cuda()
        model_D.load_state_dict(checkpoint['model_D_state_dict'])
        model_D.train()
        model_D.cuda()
        optim_G.load_state_dict(checkpoint['optim_G_state_dict'])
        optim_D.load_state_dict(checkpoint['optim_D_state_dict'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G_state_dict'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D_state_dict'])
        last_epoch = checkpoint['epoch']
        # purge the logs written after last_epoch
        writer = SummaryWriter(settings.TENSORBOARD_DIR,
                               purge_step=(last_epoch + 1) * len(dataloader))

    for epoch in range(last_epoch + 1, settings.EPOCHS + 1):
        train_one_epoch(model_G, model_D, optim_G, optim_D, dataloader,
                        test_dataloader, epoch, upsample, ce_loss, bce_loss,
                        writer, print_freq=5, eval_freq=settings.EVAL_FREQ)

        if epoch % settings.CHECKPOINT_FREQ == 0 and epoch != 0:
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)

        # save the final model
        if epoch >= settings.EPOCHS:
            print('saving the final model')
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)
            writer.close()

        lr_scheduler_G.step()
        lr_scheduler_D.step()
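# For reference, a minimal sketch of the save_checkpoint helper called above.
# This is an assumption rather than the repo's actual implementation: the dict
# keys mirror the ones read back in the RESUME_TRAIN branch of main(), while
# the file-name scheme is hypothetical.
def save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                    lr_scheduler_G, lr_scheduler_D):
    torch.save({
        'epoch': epoch,
        'model_G_state_dict': model_G.state_dict(),
        'model_D_state_dict': model_D.state_dict(),
        'optim_G_state_dict': optim_G.state_dict(),
        'optim_D_state_dict': optim_D.state_dict(),
        'lr_scheduler_G_state_dict': lr_scheduler_G.state_dict(),
        'lr_scheduler_D_state_dict': lr_scheduler_D.state_dict(),
    }, os.path.join(settings.CHECKPOINT_DIR,
                    'checkpoint_epoch_{}.pth'.format(epoch)))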
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)

        # Define dataloaders
        self.train_loader, self.val_loader, self.test_loader, self.nclass = \
            make_data_loader(config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = \
            make_target_data_loader(config)

        # Define the network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{'params': self.model.get_1x_lr_params(),
                         'lr': config.lr},
                        {'params': self.model.get_10x_lr_params(),
                         'lr': config.lr * config.lr_ratio}]

        # Define optimizers
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define the criterion
        # whether to use class-balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()

        # Define the evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define the lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()
            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0

        # Resuming from a checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                # map_location belongs to torch.load, not load_state_dict
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum, \
            adv_loss_sum, d_loss_sum, ins_loss_sum = \
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        self.model.train()
        if self.config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            #if self.visdom:
            #    self.vis.line(X=torch.tensor([itr]), Y=torch.tensor([self.optimizer.param_groups[0]['lr']]),
            #                  win='lr', opts=dict(title='lr', xlabel='iter', ylabel='lr'),
            #                  update='append' if itr > 0 else None)
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from the target domain
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)
            B_image, B_target, B_image_pair = (target_sample['image'],
                                               target_sample['label'],
                                               target_sample['image_pair'])

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = \
                    B_image.cuda(), B_target.cuda(), B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            #B_output_pair, B_feat_pair, B_low_feat_pair = self.model(B_image_pair)
            #B_output_pair, B_feat_pair, B_low_feat_pair = flip(B_output_pair, dim=-1), flip(B_feat_pair, dim=-1), flip(B_low_feat_pair, dim=-1)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # Train the segmentation network: freeze the discriminator
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss on the source domain
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Unsupervised loss
            #ins_loss = 0.01 * self.instance_loss(B_output, B_output_pair)
            #main_loss += ins_loss

            # Adversarial loss: push target predictions to look like source ones
            D_out = self.D(prob_2_entropy(F.softmax(B_output, dim=1)))
            adv_loss = bce_loss(D_out, self.source_label)
            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # Train the discriminator
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source
            D_source = self.D(prob_2_entropy(F.softmax(A_output_detach, dim=1)))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target
            D_target = self.D(prob_2_entropy(F.softmax(B_output_detach, dim=1)))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            #ins_loss_sum += ins_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()
            #train_loss += seg_loss.item() + self.config.lambda_adv * adv_loss.item()
            train_loss += seg_loss.item()

            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(), itr)
            #self.summary.writer.add_scalar('Train/InsLoss', ins_loss.item(), itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(), itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss', d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        # Show the results of the last iteration
        #if i == len(self.train_loader)-1:
        print("Add Train images at epoch " + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        #print('Seg Loss: %.3f' % seg_loss_sum)
        #print('Ins Loss: %.3f' % ins_loss_sum)
        #print('BN Loss: %.3f' % bn_loss_sum)
        #print('Adv Loss: %.3f' % adv_loss_sum)
        #print('Discriminator Loss: %.3f' % d_loss_sum)

        #if self.visdom:
        #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([seg_loss_sum]), win='train_loss', name='Seg_loss',
        #                  opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #                  update='append' if epoch > 0 else None)
        #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([ins_loss_sum]), win='train_loss', name='Ins_loss',
        #                  opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #                  update='append' if epoch > 0 else None)
        #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([bn_loss_sum]), win='train_loss', name='BN_loss',
        #                  opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #                  update='append' if epoch > 0 else None)
        #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([adv_loss_sum]), win='train_loss', name='Adv_loss',
        #                  opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #                  update='append' if epoch > 0 else None)
        #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([d_loss_sum]), win='train_loss', name='Dis_loss',
        #                  opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #                  update='append' if epoch > 0 else None)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            #feat_mean, low_feat_mean, feat_var, low_feat_var = 0, 0, 0, 0
            #adv_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output, low_feat, feat = self.model(image)
                #low_feat = low_feat.cpu().numpy()
                #feat = feat.cpu().numpy()
                #if isinstance(feat, np.ndarray):
                #    feat_mean += feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean += low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var += feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var += low_feat.var(axis=0).var(axis=1).var(axis=1)
                #else:
                #    feat_mean = feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean = low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var = feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var = low_feat.var(axis=0).var(axis=1).var(axis=1)
                #d_output = self.D(prob_2_entropy(F.softmax(output)))
                #adv_loss += bce_loss(d_output, self.source_label).item()
                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()
                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                # Add the batch into the evaluator
                self.evaluator.add_batch(target_, pred)

            # Visualize the last batch
            if if_source:
                print("Add Validation-Source images at epoch " + str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset,
                                             image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch " + str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target,
                                             image, target, output, epoch, 5)
            #feat_mean /= (i+1)
            #low_feat_mean /= (i+1)
            #feat_var /= (i+1)
            #low_feat_var /= (i+1)
            #adv_loss /= (i+1)

            # Fast test during training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)

            # Draw on visdom
            #if self.visdom:
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([test_loss]), win='val_loss', name=names[0],
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([adv_loss]), win='val_loss', name='adv_loss',
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([Acc]), win='metrics', name=names[1],
            #                  opts=dict(title='metrics', xlabel='epoch', ylabel='performance'),
            #                  update='append' if epoch > 0 else None)
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([IoU]), win='metrics', name=names[2],
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([mIoU]), win='metrics', name=names[3],
            #                  update='append')

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        new_pred_source = s_iou
        new_pred_target = t_iou
        if new_pred_source > self.best_pred_source or \
                new_pred_target > self.best_pred_target:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
        print('Saving state, epoch:', epoch)
        torch.save(self.model.module.state_dict(),
                   self.args.save_folder + 'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {'s_Acc': s_acc, 's_IoU': s_iou, 's_mIoU': s_miou,
                     't_Acc': t_acc, 't_IoU': t_iou, 't_mIoU': t_miou}
        with open(os.path.join(self.args.save_folder, 'eval',
                               'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
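# Trainer.training() calls two module-level helpers that are not shown in this
# file. Below is a minimal sketch of both, assuming the ADVENT-style
# formulation (Vu et al., 2019): prob_2_entropy turns softmax probabilities
# into per-pixel weighted self-information maps, and bce_loss scores the
# discriminator output against a scalar domain label (0 = source, 1 = target).
def prob_2_entropy(prob):
    """prob: (N, C, H, W) softmax output -> weighted self-information map."""
    n, c, h, w = prob.size()
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)

def bce_loss(y_pred, y_label):
    """BCE of discriminator logits against a broadcast scalar domain label."""
    y_truth_tensor = torch.full_like(y_pred, y_label)
    return F.binary_cross_entropy_with_logits(y_pred, y_truth_tensor)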
# build the graph
feature_extractor = Inceptionv2()
dis = Discriminator()
if gv2_model_path is not None:
    fea_dict = torch.load(gv2_model_path)
    # fea_dict.pop('classifier.weight')
    # fea_dict.pop('classifier.bias')
    # fea_dict.pop('criterion2.center_feature')
    # fea_dict.pop('criterion2.all_labels')
    feature_extractor.load_state_dict(fea_dict)
if dis_model_path is not None:
    dis.load_state_dict(torch.load(dis_model_path))
if is_cuda:
    feature_extractor.cuda()
    dis.cuda()

# input pipeline
data_iter = DataProvider(batch_size, is_cuda=is_cuda)

# summary writer
if log_path:
    writer = SummaryWriter(log_path, 'comment test')
else:
    writer = None

# optimizers
opt_d = Adam(dis.parameters())
opt_fea = Adam(feature_extractor.parameters())
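# A minimal sketch of one adversarial round built from the objects above.
# Everything here is illustrative: the batch arguments, the real/fake label
# convention, and the discriminator input shape are assumptions, since
# neither Inceptionv2 nor Discriminator is defined in this file.
def adversarial_step(source_batch, target_batch):
    bce = torch.nn.BCEWithLogitsLoss()

    # discriminator step: push source features toward 1, target toward 0
    opt_d.zero_grad()
    src_out = dis(feature_extractor(source_batch).detach())
    tgt_out = dis(feature_extractor(target_batch).detach())
    d_loss = bce(src_out, torch.ones_like(src_out)) + \
             bce(tgt_out, torch.zeros_like(tgt_out))
    d_loss.backward()
    opt_d.step()

    # feature-extractor step: fool the discriminator on target features
    opt_fea.zero_grad()
    fea_out = dis(feature_extractor(target_batch))
    fea_loss = bce(fea_out, torch.ones_like(fea_out))
    fea_loss.backward()
    opt_fea.step()
    return d_loss.item(), fea_loss.item()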
batch_size = 64
# number of epochs
epoch_num = 10

# prepare the dataloader
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# define the models
model_G = Generator()
model_D = Discriminator()

cuda = torch.cuda.is_available()
if cuda:
    model_G.cuda()
    model_D.cuda()
    print('cuda is available!')
else:
    print('cuda is not available')

# optimizer settings
# params_G = optim.Adam(model_G.parameters(),
#                       lr=0.0002, betas=(0.5, 0.999))
# params_D = optim.Adam(model_D.parameters(),
#                       lr=0.0002, betas=(0.5, 0.999))
params_G = optim.Adam(model_G.parameters(), lr=0.01)
params_D = optim.Adam(model_D.parameters(), lr=0.01)
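# A minimal sketch of the per-batch update these objects typically feed into,
# using the standard non-saturating BCE GAN losses. The latent size nz and
# the assumption that model_D emits one logit per image are illustrative,
# not taken from this file.
criterion = torch.nn.BCEWithLogitsLoss()
nz = 100  # assumed latent dimension

def train_step(real_images):
    n = real_images.size(0)
    device = 'cuda' if cuda else 'cpu'
    real_images = real_images.to(device)
    ones = torch.ones(n, 1, device=device)
    zeros = torch.zeros(n, 1, device=device)

    # discriminator update: real -> 1, fake -> 0
    params_D.zero_grad()
    fake_images = model_G(torch.randn(n, nz, device=device)).detach()
    loss_D = criterion(model_D(real_images), ones) + \
             criterion(model_D(fake_images), zeros)
    loss_D.backward()
    params_D.step()

    # generator update: make the discriminator output 1 on fakes
    params_G.zero_grad()
    loss_G = criterion(model_D(model_G(torch.randn(n, nz, device=device))), ones)
    loss_G.backward()
    params_G.step()
    return loss_G.item(), loss_D.item()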
generator.load_state_dict(t.load('saved_models/trained_generator_' + args.model_name))
discriminator.load_state_dict(t.load('saved_models/trained_discriminator_' + args.model_name))

ce_result_valid = list(np.load('logs/{}/ce_result_valid.npy'.format(args.model_name)))
ce_result_train = list(np.load('logs/{}/ce_result_train.npy'.format(args.model_name)))
ce2_result_train = list(np.load('logs/{}/ce2_result_train.npy'.format(args.model_name)))
ce2_result_valid = list(np.load('logs/{}/ce2_result_valid.npy'.format(args.model_name)))
kld_result_valid = list(np.load('logs/{}/kld_result_valid.npy'.format(args.model_name)))
kld_result_train = list(np.load('logs/{}/kld_result_train.npy'.format(args.model_name)))
dg_result_train = list(np.load('logs/{}/dg_result_train.npy'.format(args.model_name)))
dg_result_valid = list(np.load('logs/{}/dg_result_valid.npy'.format(args.model_name)))
d_result_train = list(np.load('logs/{}/d_result_train.npy'.format(args.model_name)))
d_result_valid = list(np.load('logs/{}/d_result_valid.npy'.format(args.model_name)))

if args.use_cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()

g_optim = Adam(generator.learnable_parameters(), args.learning_rate)
d_optim = Adam(discriminator.learnable_parameters(), args.learning_rate)
# [generator, discriminator], [g_optim, d_optim] = amp.initialize([generator, discriminator], [g_optim, d_optim],
#                                                                 opt_level="O1", num_losses=2)

rollout = Rollout(generator, discriminator, 0.8, rollout_num)
# discriminator, d_optim = amp.initialize(discriminator, d_optim, opt_level="O1")
scaler = amp.GradScaler()

train_step = trainer(generator, g_optim, discriminator, d_optim, rollout,
                     batch_loader, scaler)
validate = validater(generator, discriminator, rollout, batch_loader)
# converge_criterion, converge_count = 1000, 0
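# For orientation, a minimal sketch of the mixed-precision update pattern the
# GradScaler above enables (torch.cuda.amp). The compute_loss callable is a
# placeholder assumption; the real objectives live inside trainer/validater.
def amp_step(optimizer, compute_loss):
    optimizer.zero_grad()
    with amp.autocast():
        loss = compute_loss()
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # unscales gradients, skips the step on inf/nan
    scaler.update()
    return loss.item()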
def main(): """Main function that trains and/or evaluates a model.""" params = interpret_args() if params.gan: assert params.max_gen_len == params.train_maximum_sql_length \ == params.eval_maximum_sql_length data = atis_data.ATISDataset(params) generator = SchemaInteractionATISModel(params, data.input_vocabulary, data.output_vocabulary, data.output_vocabulary_schema, None) generator = generator.cuda() generator.build_optim() if params.gen_from_ckp: gen_ckp_path = os.path.join(params.logdir, params.gen_pretrain_ckp) if params.fine_tune_bert: gen_epoch, generator, generator.trainer, \ generator.bert_trainer = \ load_ckp( gen_ckp_path, generator, generator.trainer, generator.bert_trainer ) else: gen_epoch, generator, generator.trainer, _ = \ load_ckp( gen_ckp_path, generator, generator.trainer ) else: gen_epoch = 0 print('====================Model Parameters====================') print('=======================Generator========================') for name, param in generator.named_parameters(): print(name, param.requires_grad, param.is_cuda, param.size()) assert param.is_cuda print('==================Optimizer Parameters==================') print('=======================Generator========================') for param_group in generator.trainer.param_groups: print(param_group.keys()) for param in param_group['params']: print(param.size()) if params.fine_tune_bert: print('=========================BERT===========================') for param_group in generator.bert_trainer.param_groups: print(param_group.keys()) for param in param_group['params']: print(param.size()) sys.stdout.flush() # Pre-train generator with MLE if params.train: print('=============== Pre-training generator! ================') train(generator, data, params, gen_epoch) print('=========== Pre-training generator complete! ===========') dis_filter_sizes = [i for i in range(1, params.max_gen_len, 4)] dis_num_filters = [(100 + i * 10) for i in range(1, params.max_gen_len, 4)] discriminator = Discriminator(params, data.dis_src_vocab, data.dis_tgt_vocab, params.max_gen_len, params.num_dis_classes, dis_filter_sizes, dis_num_filters, params.max_pos_emb, params.num_tok_type, params.dis_dropout) discriminator = discriminator.cuda() dis_criterion = nn.NLLLoss(reduction='mean') dis_criterion = dis_criterion.cuda() dis_optimizer = optim.Adam(discriminator.parameters()) if params.dis_from_ckp: dis_ckp_path = os.path.join(params.logdir, params.dis_pretrain_ckp) dis_epoch, discriminator, dis_optimizer, _ = load_ckp( dis_ckp_path, discriminator, dis_optimizer) else: dis_epoch = 0 print('====================Model Parameters====================') print('=====================Discriminator======================') for name, param in discriminator.named_parameters(): print(name, param.requires_grad, param.is_cuda, param.size()) assert param.is_cuda print('==================Optimizer Parameters==================') print('=====================Discriminator======================') for param_group in dis_optimizer.param_groups: print(param_group.keys()) for param in param_group['params']: print(param.size()) sys.stdout.flush() # Pre-train discriminator if params.pretrain_discriminator: print('============= Pre-training discriminator! ==============') pretrain_discriminator(params, generator, discriminator, dis_criterion, dis_optimizer, data, start_epoch=dis_epoch) print('========= Pre-training discriminator complete! =========') # Adversarial Training if params.adversarial_training: print('================ Adversarial training! 
        generator.build_optim()
        dis_criterion = nn.NLLLoss(reduction='mean')
        dis_optimizer = optim.Adam(discriminator.parameters())
        dis_criterion = dis_criterion.cuda()

        if params.adv_from_ckp and params.mle != "mixed_mle":
            adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
            if params.fine_tune_bert:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, \
                    generator.bert_trainer, _, _ = \
                    load_adv_ckp(adv_ckp_path, generator, discriminator,
                                 generator.trainer, dis_optimizer,
                                 generator.bert_trainer)
            else:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, _, _, _ = \
                    load_adv_ckp(adv_ckp_path, generator, discriminator,
                                 generator.trainer, dis_optimizer)
            adv_train(generator, discriminator, dis_criterion, dis_optimizer,
                      data, params, start_epoch=epoch, start_batches=batches,
                      start_pos_in_batch=pos_in_batch)
        elif params.adv_from_ckp and params.mle == "mixed_mle":
            adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
            if params.fine_tune_bert:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, \
                    generator.bert_trainer, clamp, length = \
                    load_adv_ckp(adv_ckp_path, generator, discriminator,
                                 generator.trainer, dis_optimizer,
                                 generator.bert_trainer, mle=True)
            else:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, _, clamp, length = \
                    load_adv_ckp(adv_ckp_path, generator, discriminator,
                                 generator.trainer, dis_optimizer, mle=True)
            mixed_mle(generator, discriminator, dis_criterion, dis_optimizer,
                      data, params, start_epoch=epoch, start_batches=batches,
                      start_pos_in_batch=pos_in_batch, start_clamp=clamp,
                      start_len=length)
        else:
            if params.mle == 'mixed_mle':
                mixed_mle(generator, discriminator, dis_criterion,
                          dis_optimizer, data, params)
            else:
                adv_train(generator, discriminator, dis_criterion,
                          dis_optimizer, data, params)

    if params.evaluate and 'valid' in params.evaluate_split:
        print("================== Evaluating! ===================")
        evaluate(generator, data, params, split='valid')
        print("============= Evaluation finished! ===============")
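# Worked example of the discriminator filter banks built in main(): with
# params.max_gen_len == 20 (an illustrative value, not taken from this file),
# range(1, 20, 4) gives dis_filter_sizes == [1, 5, 9, 13, 17] and
# dis_num_filters == [110, 150, 190, 230, 270], i.e. one bank of CNN filters
# per filter width, growing by 10 channels per step.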