# --- DCGAN-style training setup: data pipeline, networks, losses, buffers ---

# Sanity-check the preprocessed training tensor before wrapping it.
print(data_new.size())
# NOTE(review): data_tensor=/target_tensor= keyword names belong to very old
# torch releases; modern TensorDataset takes positional tensors — confirm the
# torch version this was written for.
dataset = Data.TensorDataset(data_tensor=data_new, target_tensor=train_y)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=opt.batchSize, shuffle=True)

############### MODEL ####################
ndf = opt.ndf  # base discriminator feature-map count
ngf = opt.ngf  # base generator feature-map count
nc = 1         # number of image channels fed to D and produced by G
netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
#if(opt.cuda):
# Both networks are moved to the GPU unconditionally — the opt.cuda guard
# above was commented out, so this script requires CUDA.
netD.cuda()
netG.cuda()

########### LOSS & OPTIMIZER ##########
criterion = nn.BCELoss()
# Separate Adam optimizers for D and G, sharing lr/beta1 from the options.
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

########## GLOBAL VARIABLES ###########
#noise_all = torch.FloatTensor(20,opt.nz,1,1)
# Pre-allocated CPU buffers: latent noise for G and a real-image batch.
noise = torch.FloatTensor(opt.batchSize, opt.nz, 1, 1)
real = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize)
test_folder = 'output_' + opt.output_str ###### Definition of variables ###### # Networks model = model(fea_channel=opt.fea_channel) discriminator = Discriminator() if opt.load_model: load_path = opt.load_path model.load_state_dict(torch.load(load_path)) print('model loaded') torch.cuda.empty_cache() model = model.cuda() discriminator = discriminator.cuda() vgg_model = models.vgg16(pretrained=True) vgg_model.cuda() loss_network = utils.LossNetwork(vgg_model) loss_network.eval() # Optimizers & LR schedulers optimizer_G = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999)) optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(0.5, 0.999)) # Dataset loader trainloader = loaddata.getTrainingData(opt.batchSize, size=128)
# --- CycleGAN setup: two discriminators, two generators, losses, optimizers ---

batchSize = 1
cuda = 1  # flag: move every network to the GPU below

# D_A judges domain-A images; D_B judges domain-B images.
D_A = Discriminator(input_nc,ndf)
D_B = Discriminator(output_nc,ndf)
# G_AB maps A -> B, G_BA maps B -> A.
G_AB = Generator(input_nc, output_nc, ngf)
G_BA = Generator(output_nc, input_nc, ngf)

# Apply the custom weight initialisation to all four networks.
G_AB.apply(weights_init)
G_BA.apply(weights_init)
D_A.apply(weights_init)
D_B.apply(weights_init)

if(cuda):
    D_A.cuda()
    D_B.cuda()
    G_AB.cuda()
    G_BA.cuda()

########### LOSS & OPTIMIZER ##########
# NOTE(review): the names look swapped relative to their classes —
# criterionMSE is actually an L1 loss and criterion is an MSE loss.
# Confirm intent against the code that uses them.
criterionMSE = nn.L1Loss()
criterion = nn.MSELoss()
# chain is used to update two generators simultaneously
optimizerD_A = torch.optim.Adam(D_A.parameters(),lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
optimizerD_B = torch.optim.Adam(D_B.parameters(),lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
optimizerG = torch.optim.Adam(chain(G_AB.parameters(),G_BA.parameters()),lr=0.0002, betas=(0.5, 0.999))

# Pre-allocated CPU buffers for one batch of inputs.
# NOTE(review): both buffers use input_nc; if AB holds domain-B images this
# assumes input_nc == output_nc — confirm.
real_A = torch.FloatTensor(batchSize, input_nc, fineSize, fineSize)
AB = torch.FloatTensor(batchSize, input_nc, fineSize, fineSize)
def main(args):
    """Train a WGAN/DCGAN variant end to end on the selected dataset.

    Builds the dataloader, networks, criterion and optimizers from ``args``,
    optionally resumes from a checkpoint, then runs the epoch loop while
    saving results (pickle) and a checkpoint after every epoch.

    Args:
        args: parsed command-line namespace. Fields used here: model,
            dataset, epochs, obj_step, batch_size, lr, flag, num_workers,
            clipping, num_critic, resume.

    Raises:
        ValueError: if ``args.dataset`` or ``args.model`` is unsupported.
        FileNotFoundError: if ``--resume`` is given but no checkpoint exists.
    """
    #===========================================================================
    # Unique run identifier used to name every artifact of this run.
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step,
        args.batch_size, args.lr, args.flag)
    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)

    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        # Was `assert False, ...`: asserts are stripped under `python -O`,
        # so input validation must raise explicitly.
        raise ValueError("Please select the proper dataset.")

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model (BN / no-BN variants share the same architecture).
    if args.model in ['WGAN', 'DCGAN']:
        generator = Generator(BN=True)
        discriminator = Discriminator(BN=True)
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        generator = Generator(BN=False)
        discriminator = Discriminator(BN=False)
    else:
        raise ValueError("Please select the proper model.")

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set loss function and optimizer. WGAN variants use no criterion: the
    # Wasserstein objective is computed inside train() directly.
    if args.model in ['DCGAN', 'DCGAN_noBN']:
        criterion = nn.BCELoss()
    else:
        criterion = None
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    # Fixed noise reused at every validation step so samples are comparable.
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)

    # Initialize the result lists
    train_loss_G = []
    train_loss_D = []
    train_distance = []

    if args.resume:
        # Explicit raise instead of `assert` (see validation note above).
        if not os.path.exists(CHECKPOINT_FILE_PATH):
            raise FileNotFoundError('No checkpoint file!')
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch']
        step_counter.current_step = checkpoint['current_step']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D = checkpoint['train_loss_D']
        train_distance = checkpoint['train_distance']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Check the directory of the file path
    if not os.path.exists(os.path.dirname(RESULT_FILE_PATH)):
        os.makedirs(os.path.dirname(RESULT_FILE_PATH))
    if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH)):
        os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH))
    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    for epoch in range(args.epochs):
        # start after the checkpoint epoch
        if epoch < start_epoch:
            continue
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Calculate average loss. Guarded so an epoch cut short by the
        # iteration limit (empty lists) cannot raise ZeroDivisionError.
        avg_loss_G = sum(tloss_G) / max(len(tloss_G), 1)
        avg_loss_D = sum(tloss_D) / max(len(tloss_D), 1)
        avg_distance = sum(tdist) / max(len(tdist), 1)

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D'] = train_loss_D
        result_data['train_distance'] = train_distance

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data, pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint
        # if avg_distance < best_metric:
        #     best_metric = avg_distance
        #     torch.save({
        #         'epoch': epoch+1,
        #         'generator_state_dict': generator.state_dict(),
        #         'discriminator_state_dict': discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'current_step': step_counter.current_step,
        #         'best_metric': best_metric,
        #     }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint
        torch.save({
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'current_step': step_counter.current_step,
            'train_loss_G': train_loss_G,
            'train_loss_D': train_loss_D,
            'train_distance': train_distance,
            'best_metric': best_metric,
        }, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current - epoch_time))
        # NOTE(review): `start` is not defined in this function — presumably a
        # module-level timestamp set at program start; verify it exists.
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # If iteration step has been satisfied
        if step_counter.exit_signal:
            break

    print('==> Train done.')
    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
def train(**kwargs):
    """Adversarially train a GAN-augmented Faster R-CNN.

    The generator is the feature-enhancement branch inside
    ``FasterRCNNVGG16_GAN``; the discriminator ``netD`` learns to tell pooled
    RoI features of large (real) objects from super-resolved features of
    small/generated (fake) objects. Per epoch the model is evaluated on a
    small-object test set and the best-mAP weights are checkpointed.

    Keyword args are forwarded to ``opt._parse`` to override config values.
    """
    opt._parse(kwargs)

    # --- Data: large-object training set (the "real" feature source) ---
    id_file_dir = 'ImageSets/Main/trainval_big_64.txt'
    img_dir = 'JPEGImages'
    anno_dir = 'AnnotationsBig'
    large_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_large = data_.DataLoader(large_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    # --- Data: small/PCGAN-generated training set (the "fake" source) ---
    id_file_dir = 'ImageSets/Main/trainval_pcgan_generated_small.txt'
    img_dir = 'JPEGImagesPCGANGenerated'
    anno_dir = 'AnnotationsPCGANGenerated'
    small_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_small = data_.DataLoader(small_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    # --- Data: small-object test set used for per-epoch mAP evaluation ---
    small_test_dataset = SmallImageTestDataset(opt)
    dataloader_small_test = data_.DataLoader(small_test_dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=opt.test_num_workers)
    print('{:d} roidb large entries'.format(len(dataloader_large)))
    print('{:d} roidb small entries'.format(len(dataloader_small)))
    print('{:d} roidb small test entries'.format(len(dataloader_small_test)))

    # Two detectors: the GAN-augmented one being trained, and a plain one
    # used only as a source of pretrained weights.
    faster_rcnn = FasterRCNNVGG16_GAN()
    faster_rcnn_ = FasterRCNNVGG16()
    print('model construct completed')
    trainer_ = FasterRCNNTrainer(faster_rcnn_).cuda()
    netD = Discriminator()
    netD.apply(weights_init)
    faster_rcnn_.cuda()
    netD.cuda()

    # Discriminator optimizer: biases get 2x lr and no weight decay,
    # everything else gets base lr plus opt.weight_decay.
    lr = opt.LEARNING_RATE
    params_D = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params_D += [{'params': [value], 'lr': lr * 2,
                              'weight_decay': 0}]
            else:
                params_D += [{'params': [value], 'lr': lr,
                              'weight_decay': opt.weight_decay}]
    optimizerD = optim.SGD(params_D, momentum=0.9)
    # optimizerG = optim.Adam(faster_rcnn.parameters(), lr=lr, betas=(0.5, 0.999))

    # Fresh run: seed the GAN detector with plain pretrained Faster R-CNN
    # weights by merging its state dict into the GAN model's.
    if not opt.gan_load_path:
        trainer_.load(opt.load_path)
        print('load pretrained faster rcnn model from %s' % opt.load_path)
        # optimizer_ = trainer_.optimizer
        state_dict_ = faster_rcnn_.state_dict()
        state_dict = faster_rcnn.state_dict()
        # for k, i in state_dict_.items():
        #     icpu = i.cpu()
        #     b = icpu.data.numpy()
        #     sz = icpu.data.numpy().shape
        #     state_dict[k] = state_dict_[k]
        state_dict.update(state_dict_)
        faster_rcnn.load_state_dict(state_dict)

    faster_rcnn.cuda()
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    # Resume paths: generator (with optimizer) and discriminator.
    if opt.gan_load_path:
        trainer.load(opt.gan_load_path, load_optimizer=True)
        print('load pretrained generator model from %s' % opt.gan_load_path)
    if opt.disc_load_path:
        state_dict_d = torch.load(opt.disc_load_path)
        netD.load_state_dict(state_dict_d['model'])
        optimizerD.load_state_dict(state_dict_d['optimizer'])
        print('load pretrained discriminator model from %s' % opt.disc_load_path)

    real_label = 1
    fake_label = 0
    # rpn_loc_loss = []
    # rpn_cls_loss = []
    # roi_loc_loss = []
    # roi_cls_loss = []
    # total_loss = []
    test_map_list = []
    criterion = nn.BCELoss()
    # One "epoch" walks the shorter of the two loaders in lockstep.
    iters_per_epoch = min(len(dataloader_large), len(dataloader_small))
    best_map = 0
    # NOTE(review): labels are created on cuda:2 while the networks were moved
    # with .cuda() (default device) — BCELoss inputs/targets would then live on
    # different devices unless CUDA_VISIBLE_DEVICES remaps them. Confirm.
    device = torch.device("cuda:2" if (torch.cuda.is_available()) else "cpu")

    for epoch in range(1, opt.gan_epoch + 1):
        trainer.reset_meters()
        loss_temp_G = 0
        loss_temp_D = 0
        # Step-decay both optimizers every (lr_decay_step + 1) epochs.
        if epoch % (opt.lr_decay_step + 1) == 0:
            adjust_learning_rate(trainer.optimizer, opt.LEARNING_RATE_DECAY_GAMMA)
            adjust_learning_rate(optimizerD, opt.LEARNING_RATE_DECAY_GAMMA)
            lr *= opt.LEARNING_RATE_DECAY_GAMMA

        data_iter_large = iter(dataloader_large)
        data_iter_small = iter(dataloader_small)
        for step in tqdm(range(iters_per_epoch)):
            #####(1) Update Perceptual branch + generator(zero mapping)
            #### Discriminator network: maximize log(D(x))+ log(1-D(G(z)))
            ##### Train with all_real batch
            ##### Format batch
            netD.zero_grad()
            data_large = next(data_iter_large)
            img, bbox_, label_, scale_ = data_large
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            ##### Forward pass real batch through D
            # faster_rcnn.zero_grad()
            # trainer.optimizer.zero_grad()
            # trainer.optimizer.zero_grad()
            losses, pooled_feat, rois_label, conv1_feat = trainer.train_step_gan(img, bbox, label, scale)
            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/large_orig_%s' % str(epoch))
            #     custom_viz(pooled_feat.cpu().detach(), 'results-gan/features/large_scaled_%s' % str(epoch))
            # Keep only foreground RoIs (label != 0) as "real" samples for D.
            keep = rois_label != 0
            pooled_feat = pooled_feat[keep]
            real_b_size = pooled_feat.size(0)
            # NOTE(review): torch.full with an int fill value yields an integer
            # tensor on newer torch versions, which BCELoss rejects — may need
            # a float fill value / dtype depending on the torch version.
            real_labels = torch.full((real_b_size,), real_label, device=device)
            # detach(): D's real pass must not backprop into the detector.
            output = netD(pooled_feat.detach()).view(-1)
            # print(output)
            ##### Calculate loss on all-real batch
            errD_real = criterion(output, real_labels)
            errD_real.backward()
            D_x = output.mean().item()

            ##### Train with all_fake batch
            # Generate batch of fake images with G
            data_small = next(data_iter_small)
            img, bbox_, label_, scale_ = data_small
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.optimizer.zero_grad()
            losses, fake_pooled_feat, rois_label, conv1_feat = trainer.train_step_gan_second(img, bbox, label, scale)
            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/small_orig_%s' % str(epoch))
            #     custom_viz(fake_pooled_feat.cpu().detach(), 'results-gan/features/small_scaled_%s' % str(epoch))
            # select fg rois
            keep = rois_label != 0
            fake_pooled_feat = fake_pooled_feat[keep]
            # print(fake_pooled_feat)
            # print(torch.nonzero(torch.isnan(fake_pooled_feat.view(-1))))
            fake_b_size = fake_pooled_feat.size(0)
            fake_labels = torch.full((fake_b_size,), fake_label, device=device)
            # optimizerD.zero_grad()
            output = netD(fake_pooled_feat.detach()).view(-1)
            # calculate D's loss on the all_fake batch
            errD_fake = criterion(output, fake_labels)
            # retain_graph so the generator pass below can reuse this graph.
            errD_fake.backward(retain_graph=True)
            D_G_Z1 = output.mean().item()
            # add the gradients from the all-real and all-fake batches
            errD = errD_fake + errD_real
            # Update D
            optimizerD.step()

            ################################################
            #####(2) Update G network: maximize log(D(G(z)))
            ################################################
            faster_rcnn.zero_grad()
            # Fakes are labelled "real" for the generator's loss.
            fake_labels.fill_(real_label)
            # No detach here: gradients must flow back into the generator.
            output = netD(fake_pooled_feat).view(-1)
            # calculate gradients for G
            errG = criterion(output, fake_labels)
            # Adversarial loss plus the standard detection losses.
            errG += losses.total_loss
            errG.backward()
            D_G_Z2 = output.mean().item()
            clip_gradient(faster_rcnn, 10.)
            trainer.optimizer.step()

            loss_temp_G += errG.item()
            loss_temp_D += errD.item()

            # Periodic console report of the running (averaged) losses.
            if step % opt.plot_every == 0:
                if step > 0:
                    loss_temp_G /= (opt.plot_every + 1)
                    loss_temp_D /= (opt.plot_every + 1)
                # losses_dict = trainer.get_meter_data()
                #
                # rpn_loc_loss.append(losses_dict['rpn_loc_loss'])
                # roi_loc_loss.append(losses_dict['roi_loc_loss'])
                # rpn_cls_loss.append(losses_dict['rpn_cls_loss'])
                # roi_cls_loss.append(losses_dict['roi_cls_loss'])
                # total_loss.append(losses_dict['total_loss'])
                #
                # save_losses('rpn_loc_loss', rpn_loc_loss, epoch)
                # save_losses('roi_loc_loss', roi_loc_loss, epoch)
                # save_losses('rpn_cls_loss', rpn_cls_loss, epoch)
                # save_losses('total_loss', total_loss, epoch)
                # save_losses('roi_cls_loss', roi_cls_loss, epoch)
                print("[epoch %2d] lossG: %.4f lossD: %.4f, lr: %.2e"
                      % (epoch, loss_temp_G, loss_temp_D, lr))
                print("\t\t\trcnn_cls: %.4f, rcnn_box %.4f"
                      % (losses.roi_cls_loss, losses.roi_loc_loss))
                print("\t\t\trpn_cls: %.4f, rpn_box %.4f"
                      % (losses.rpn_cls_loss, losses.rpn_loc_loss))
                print('\t\t\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (D_x, D_G_Z1, D_G_Z2))
                loss_temp_D = 0
                loss_temp_G = 0

        # End-of-epoch evaluation on the small-object test set.
        # NOTE: `eval` here is a project evaluation function that shadows the
        # Python builtin within this module's namespace.
        eval_result = eval(dataloader_small_test, faster_rcnn, test_num=opt.test_num)
        test_map_list.append(eval_result['map'])
        save_map(test_map_list, epoch)
        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{}'.format(str(lr_), str(eval_result['map']))
        print(log_info)

        # Checkpoint generator + discriminator whenever test mAP improves.
        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            timestr = time.strftime('%m%d%H%M')
            trainer.save(best_map=best_map,
                         save_path='checkpoints-pcgan-generated/gan_fasterrcnn_%s' % timestr)
            save_dict = dict()
            save_dict['model'] = netD.state_dict()
            save_dict['optimizer'] = optimizerD.state_dict()
            save_path = 'checkpoints-pcgan-generated/discriminator_%s' % timestr
            torch.save(save_dict, save_path)