class DamageDetection:
    """Compute per-spot damage indices from a trained GAN pair.

    Loads a trained Generator/Discriminator, then for every monitored spot
    combines the generator reconstruction residual (MSE against the measured
    signal) with the discriminator score into a weighted damage index, and
    writes the result to a JSON file.
    """

    def __init__(self, args):
        self.args = args
        # Fix RNG seeds so the sampled latent vectors are reproducible.
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        print('{} detection...'.format(args.dataset))
        # NOTE(review): data_path / info_path / save_path are assumed to be
        # module-level constants defined elsewhere in this file — confirm.
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       len_seg=self.args.len_seg
                                       )
        # torch.from_numpy(...).float() replaces the original
        # torch.tensor(torch.from_numpy(...), dtype=...), which re-wrapped an
        # existing tensor (extra copy and a UserWarning in modern PyTorch).
        self.testset = torch.from_numpy(white_noise.dataset_).float()
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.Generator = Generator(args)          # Generator
        self.Discriminator = Discriminator(args)  # Discriminator

    def __call__(self, *args, **kwargs):
        self.test()

    def file_name(self):
        """Return the file-name stem that encodes the experiment settings."""
        return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                          self.args.net_name,
                                          self.args.len_seg,
                                          self.args.optimizer,
                                          self.args.learning_rate,
                                          self.args.num_epoch
                                          )

    def test(self):
        """Evaluate every spot and dump the damage indices to JSON."""
        path_gen = '{}/models/{}_Gen.model'.format(save_path, self.file_name())
        path_dis = '{}/models/{}_Dis.model'.format(save_path, self.file_name())
        self.Generator.load_state_dict(torch.load(path_gen))      # Load Generator
        self.Discriminator.load_state_dict(torch.load(path_dis))  # Load Discriminator
        self.Generator.eval()
        self.Discriminator.eval()
        damage_indices = {}
        beta = 0.5  # weight between residual and discriminator terms
        with torch.no_grad():
            for i, spot in enumerate(self.spots):
                damage_indices[spot] = {}
                # NOTE(review): latent dimension 50 is hard-coded — confirm it
                # matches the Generator's expected input size.
                z = torch.randn(self.testset.shape[1], 50)
                data_gen = self.Generator(z)
                data_real = self.testset[i]
                res = ((data_gen - data_real) ** 2).mean()  # reconstruction residual (MSE)
                dis = self.Discriminator(data_gen).mean() - 1
                loss = beta * res.item() + (1 - beta) * np.abs(dis.item())
                damage_indices[spot]['Generate residual'] = res.item()
                damage_indices[spot]['Discriminate loss'] = np.abs(dis.item())
                damage_indices[spot]['Loss'] = loss
                print('[{}]\tGenerate residual: {:5f}\tDiscriminate loss: {:5f}\tLoss: {:5f}'.
                      format(spot, res.item(), np.abs(dis.item()), loss)
                      )
        damage_indices = json.dumps(damage_indices, indent=2)
        with open('{}/damage index/{}_{}.json'.format(save_path,
                                                      self.args.dataset,
                                                      self.file_name()
                                                      ), 'w') as f:
            f.write(damage_indices)
new_state_dict[name] = v return new_state_dict #torch.manual_seed(44) os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" device = "cuda" if torch.cuda.is_available() else "cpu" G = Generator(z_dim=20, image_size=64) D = Discriminator(z_dim=20, image_size=64) '''-------load weights-------''' G_load_weights = torch.load('./checkpoints/G_AnoGAN_300.pth') G.load_state_dict(fix_model_state_dict(G_load_weights)) D_load_weights = torch.load('./checkpoints/D_AnoGAN_300.pth') D.load_state_dict(fix_model_state_dict(D_load_weights)) G.to(device) D.to(device) """use GPU in parallel""" if device == 'cuda': G = torch.nn.DataParallel(G) D = torch.nn.DataParallel(D) print("parallel mode") batch_size = 8 z_dim = 20 fixed_z = torch.randn(batch_size, z_dim) fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1) fake_images = G(fixed_z.to(device))
class Trainer(object):
    """Adversarial style-transfer trainer.

    Alternates between discriminator and generator updates based on a running
    "discriminator success" score: while the score is below the configured win
    rate the discriminator is trained, otherwise the generator is trained.
    Progress goes to the console, the logging report file and (optionally)
    TensorBoard.
    """

    def __init__(self, style_data_loader, content_data_loader, config):
        # Log file paths; the report file backs the `logging` module output.
        self.log_file = os.path.join(config.log_path, config.version, config.version+"_log.log")
        self.report_file = os.path.join(config.log_path, config.version, config.version+"_report.log")
        logging.basicConfig(filename=self.report_file,
                            format='[%(asctime)s-%(levelname)s:%(message)s]',
                            level=logging.DEBUG, filemode='w',
                            datefmt='%Y-%m-%d%I:%M:%S %p')
        self.Experiment_description = config.experiment_description
        logging.info("Experiment description: \n%s"%self.Experiment_description)
        # Data loaders for the style (target) and content (source) domains.
        self.style_data_loader = style_data_loader
        self.content_data_loader = content_data_loader
        # exact loss
        self.adv_loss = config.adv_loss
        logging.info("loss: %s"%self.adv_loss)
        # Model hyper-parameters — each one is echoed to the report log.
        self.imsize = config.imsize
        logging.info("image size: %d"%self.imsize)
        self.batch_size = config.batch_size
        logging.info("Batch size: %d"%self.batch_size)
        logging.info("Is shuffle: {}".format(config.is_shuffle))
        logging.info("Image center crop size: {}".format(config.center_crop))
        self.res_num = config.res_num
        logging.info("resblock number: %d"%self.res_num)
        self.g_conv_dim = config.g_conv_dim
        logging.info("generator convolution initial channel: %d"%self.g_conv_dim)
        self.d_conv_dim = config.d_conv_dim
        logging.info("discriminator convolution initial channel: %d"%self.d_conv_dim)
        self.parallel = config.parallel
        logging.info("Is multi-GPU parallel: %s"%str(self.parallel))
        self.gpus = config.gpus
        logging.info("GPU number: %s"%self.gpus)
        self.total_step = config.total_step
        logging.info("Total step: %d"%self.total_step)
        # Alternation budget between D and G updates (see train()).
        self.d_iters = config.d_iters
        self.g_iters = config.g_iters
        self.total_iters_ratio = config.total_iters_ratio
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        logging.info("Generator learning rate: %f"%self.g_lr)
        self.d_lr = config.d_lr
        logging.info("Discriminator learning rate: %f"%self.d_lr)
        self.lr_decay = config.lr_decay
        logging.info("Learning rate decay: %f"%self.lr_decay)
        # Adam optimizer moment coefficients.
        self.beta1 = config.beta1
        logging.info("Adam opitimizer beta1: %f"%self.beta1)
        self.beta2 = config.beta2
        logging.info("Adam opitimizer beta2: %f"%self.beta2)
        self.pretrained_model = config.pretrained_model
        self.use_pretrained_model = config.use_pretrained_model
        logging.info("Use pretrained model: %s"%str(self.pretrained_model))
        self.use_tensorboard = config.use_tensorboard
        logging.info("Use tensorboard: %s"%str(self.use_tensorboard))
        # Output locations for checkpoints, image samples and TB summaries.
        self.check_point_path = config.check_point_path
        self.sample_path = config.sample_path
        self.summary_path = config.summary_path
        self.validation_path = config.validation
        # val_dataloader = Validation_Data_Loader(self.validation_path,self.imsize)
        # self.validation_data = val_dataloader.load_validation_images()
        # valres_path = os.path.join(config.log_path, config.version, "valres")
        # if not os.path.exists(valres_path):
        #     os.makedirs(valres_path)
        # self.valres_path = valres_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        # Per-scale weights applied to the multi-scale discriminator outputs.
        self.prep_weights = [1.0, 1.0, 1.0, 1.0, 1.0]
        self.transform_loss_w = config.transform_loss_w
        logging.info("transform loss weight: %f"%self.transform_loss_w)
        self.feature_loss_w = config.feature_loss_w
        logging.info("feature loss weight: %f"%self.feature_loss_w)
        self.style_class = config.style_class
        self.real_prep_threshold = config.real_prep_threshold
        logging.info("real label threshold: %f"%self.real_prep_threshold)
        # self.TVLossWeight = config.TV_loss_weight
        # logging.info("TV loss weight: %f"%self.TVLossWeight)
        self.discr_success_rate = config.discr_success_rate
        logging.info("discriminator success rate: %f"%self.discr_success_rate)
        logging.info("Is conditional generating: %s"%str(config.condition_model))
        self.device = torch.device('cuda:%s'%config.default_GPU if torch.cuda.is_available() else 'cpu')
        print('build_model...')
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()
        # Start with trained model
        # NOTE(review): this branch only prints — it never calls
        # self.load_pretrained_model(); confirm the load call was not lost.
        if self.use_pretrained_model:
            print('load_pretrained_model...')

    def train(self):
        """Run the alternating D/G training loop for ``total_step`` steps."""
        # Data iterator
        style_iter = iter(self.style_data_loader)
        content_iter = iter(self.content_data_loader)
        step_per_epoch = len(self.style_data_loader)
        model_save_step = int(self.model_save_step)
        # Fixed input for debugging
        # Start with trained model
        if self.use_pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0
        alternately_iter = 0
        self.d_iters = self.d_iters * self.total_iters_ratio
        max_alternately_iter = self.d_iters + self.total_iters_ratio * self.g_iters
        d_acc = 0
        real_acc = 0
        photo_acc = 0
        fake_acc = 0
        # The D/G switch: train D while its running success score is below
        # win_rate, otherwise train G.  `alpha` is the EMA smoothing factor.
        win_rate = self.discr_success_rate
        discr_success = self.discr_success_rate
        alpha = 0.05
        # Pre-built all-ones / all-zeros targets, one per discriminator scale.
        real_labels = []
        fake_labels = []
        # size = [[self.batch_size,122*122],[self.batch_size,58*58],[self.batch_size,10*10],[self.batch_size,2*2],[self.batch_size,2*2]]
        size = [[self.batch_size,1,760,760],[self.batch_size,1,371,371],[self.batch_size,1,83,83],[self.batch_size,1,11,11],[self.batch_size,1,6,6]]
        for i in range(5):
            real_label = torch.ones(size[i], device=self.device)
            fake_label = torch.zeros(size[i], device=self.device)
            # threshold = torch.zeros(size[i], device=self.device)
            real_labels.append(real_label)
            fake_labels.append(fake_label)
        # Start time
        print('Start ====== training...')
        start_time = time.time()
        for step in range(start, self.total_step):
            self.Discriminator.train()
            self.Generator.train()
            # self.Decoder.train()
            # Pull the next batch; on exhaustion restart both iterators.
            # NOTE(review): bare except also hides non-StopIteration errors.
            try:
                content_images = next(content_iter)
                style_images = next(style_iter)
            except:
                style_iter = iter(self.style_data_loader)
                content_iter = iter(self.content_data_loader)
                style_images = next(style_iter)
                content_images = next(content_iter)
            style_images = style_images.to(self.device)
            content_images = content_images.to(self.device)
            # ================== Train D ================== #
            # Compute loss with real images
            if discr_success < win_rate:
                # BCE of each discriminator scale against the "real" target;
                # accuracy = fraction of logits on the correct side of 0.
                real_out = self.Discriminator(style_images)
                d_loss_real = 0
                real_acc = 0
                for i in range(len(real_out)):
                    temp = self.C_loss(real_out[i], real_labels[i]).mean()
                    real_acc += torch.gt(real_out[i], 0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_real += temp
                real_acc /= len(real_out)
                # Content photos are treated as "fake" for the discriminator.
                d_loss_photo = 0
                photo_out = self.Discriminator(content_images)
                photo_acc = 0
                for i in range(len(photo_out)):
                    # NOTE(review): unlike the real/fake terms this loss is not
                    # reduced with .mean() — confirm whether that is intended.
                    temp = self.C_loss(photo_out[i], fake_labels[i])
                    photo_acc += torch.lt(photo_out[i], 0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_photo += temp
                photo_acc /= len(photo_out)
                # Generated images, detached so only D receives gradients.
                fake_image, _ = self.Generator(content_images)
                fake_out = self.Discriminator(fake_image.detach())
                d_loss_fake = 0
                fake_acc = 0
                for i in range(len(fake_out)):
                    temp = self.C_loss(fake_out[i], fake_labels[i]).mean()
                    fake_acc += torch.lt(fake_out[i], 0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_fake += temp
                fake_acc /= len(fake_out)
                d_acc = ((real_acc + photo_acc + fake_acc)/3).item()
                # EMA update of the discriminator success score.
                discr_success = discr_success * (1. - alpha) + alpha * d_acc
                # Backward + Optimize
                d_loss = d_loss_real + d_loss_photo + d_loss_fake
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
            else:
                # ================== Train G ================== #
                # NOTE(review): the generator call below is commented out, so
                # `fake_image` and `real_feature` are undefined the first time
                # this branch runs (NameError) — `real_feature` is never
                # assigned anywhere.  Confirm against the intended source.
                # fake_image, real_feature= self.Generator(content_images)
                fake_feature = self.Generator(fake_image, get_feature=True)
                fake_out = self.Discriminator(fake_image)
                g_feature_loss = self.L1_loss(fake_feature, real_feature)
                g_transform_loss = self.MSE_loss(self.Transform(content_images), self.Transform(fake_image))
                # Adversarial term: G tries to push every scale toward "real".
                g_loss_fake = 0
                g_acc = 0
                for i in range(len(fake_out)):
                    temp = self.C_loss(fake_out[i], real_labels[i]).mean()
                    g_acc += torch.gt(fake_out[i], 0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    g_loss_fake += temp
                g_acc /= len(fake_out)
                g_loss_fake = g_loss_fake + g_feature_loss*self.feature_loss_w + \
                    g_transform_loss*self.transform_loss_w
                # G fooling D lowers the discriminator success score.
                discr_success = discr_success * (1. - alpha) + alpha * (1.0 - g_acc)
                self.reset_grad()
                g_loss_fake.backward()
                self.g_optimizer.step()
                # self.decoder_optimizer.step()
            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_loss_fake: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step, d_loss_real.item(),
                             d_loss_fake.item(), g_loss_fake.item()))
                if self.use_tensorboard:
                    self.writer.add_scalar('data/d_loss_real', d_loss_real.item(), (step + 1))
                    self.writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), (step + 1))
                    self.writer.add_scalar('data/d_loss', d_loss.item(), (step + 1))
                    self.writer.add_scalar('data/g_loss', g_loss_fake.item(), (step + 1))
                    self.writer.add_scalar('data/g_feature_loss', g_feature_loss, (step + 1))
                    self.writer.add_scalar('data/g_transform_loss', g_transform_loss, (step + 1))
                    # self.writer.add_scalar('data/g_tv_loss', g_tv_loss, (step + 1))
                    self.writer.add_scalar('acc/real_acc', real_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/photo_acc', photo_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/fake_acc', fake_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/disc_acc', d_acc, (step + 1))
                    self.writer.add_scalar('acc/g_acc', g_acc, (step + 1))
                    self.writer.add_scalar("acc/discr_success", discr_success, (step+1))
            # Sample images: content|fake over style|fake grids.
            if (step + 1) % self.sample_step == 0:
                print('Sample images {}_fake.png'.format(step + 1))
                fake_images, _ = self.Generator(content_images)
                saved_image1 = torch.cat([denorm(content_images), denorm(fake_images.data)], 3)
                saved_image2 = torch.cat([denorm(style_images), denorm(fake_images.data)], 3)
                wocao = torch.cat([saved_image1, saved_image2], 2)
                save_image(wocao, os.path.join(self.sample_path, '{}_fake.jpg'.format(step + 1)))
                # print("Transfer validation images")
                # num = 1
                # for val_img in self.validation_data:
                #     print("testing no.%d img"%num)
                #     val_img = val_img.to(self.device)
                #     fake_images,_ = self.Generator(val_img)
                #     saved_val_image = torch.cat([denorm(val_img),denorm(fake_images)],3)
                #     save_image(saved_val_image, os.path.join(self.valres_path, '%d_%d.jpg'%((step+1),num)))
                #     num += 1
            # save_image(denorm(displaymask.data),os.path.join(self.sample_path, '{}_mask.png'.format(step + 1)))
            # Periodic checkpointing of both networks.
            if (step+1) % model_save_step == 0:
                torch.save(self.Generator.state_dict(),
                           os.path.join(self.check_point_path, '{}_Generator.pth'.format(step + 1)))
                torch.save(self.Discriminator.state_dict(),
                           os.path.join(self.check_point_path, '{}_Discriminator.pth'.format(step + 1)))
            # alternately_iter += 1
            # alternately_iter %= max_alternately_iter

    def build_model(self):
        """Instantiate networks, wrap for multi-GPU, create optimizers/losses."""
        # code_dim=100, n_class=1000
        self.Generator = Generator(chn=self.g_conv_dim, k_size=3, res_num=self.res_num).to(self.device)
        self.Discriminator = Discriminator(chn=self.d_conv_dim, k_size=3).to(self.device)
        self.Transform = Transform_block().to(self.device)
        if self.parallel:
            print('use parallel...')
            print('gpuids ', self.gpus)
            gpus = [int(i) for i in self.gpus.split(',')]
            self.Generator = nn.DataParallel(self.Generator, device_ids=gpus)
            self.Discriminator = nn.DataParallel(self.Discriminator, device_ids=gpus)
            self.Transform = nn.DataParallel(self.Transform, device_ids=gpus)
        # self.G.apply(weights_init)
        # self.D.apply(weights_init)
        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.Generator.parameters()),
                                            self.g_lr, [self.beta1, self.beta2])
        # self.decoder_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.Decoder.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.Discriminator.parameters()),
                                            self.d_lr, [self.beta1, self.beta2])
        # self.L1_loss = torch.nn.L1Loss()
        self.MSE_loss = torch.nn.MSELoss()
        self.L1_loss = torch.nn.SmoothL1Loss()
        self.C_loss = torch.nn.BCEWithLogitsLoss()
        # self.TV_loss = TVLoss(self.TVLossWeight,self.imsize,self.batch_size)
        # print networks
        logging.info("Generator structure:")
        logging.info(self.Generator)
        # print(self.Decoder)
        logging.info("Discriminator structure:")
        logging.info(self.Discriminator)

    def build_tensorboard(self):
        """Create the TensorBoard summary writer (import deferred on purpose)."""
        from tensorboardX import SummaryWriter
        # from logger import Logger
        # self.logger = Logger(self.log_path)
        self.writer = SummaryWriter(log_dir=self.summary_path)

    def load_pretrained_model(self):
        """Restore G and D weights saved at step ``self.pretrained_model``."""
        self.Generator.load_state_dict(torch.load(os.path.join(
            self.check_point_path, '{}_Generator.pth'.format(self.pretrained_model))))
        self.Discriminator.load_state_dict(torch.load(os.path.join(
            self.check_point_path, '{}_Discriminator.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients of both optimizers before a backward pass."""
        self.g_optimizer.zero_grad()
        # self.decoder_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one denormalized batch from ``data_iter`` as 'real.png'."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
def main(args):
    """Train a CycleGAN pair end-to-end with resumable checkpointing.

    Builds A<->B generators/discriminators from ``args``, trains for
    ``args.epochs`` epochs, periodically validates, and saves losses (pkl)
    plus full training state (ckpt) after every epoch.
    """
    # Step0 ====================================================================
    # Set GPU ids
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
    # Set the file name format
    FILE_NAME_FORMAT = "{0}_{1}_{2}_{3:d}{4}".format(args.model, args.dataset,
                                                     args.loss, args.epochs,
                                                     args.flag)
    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULT_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, BEST_CHECKPOINT_FILE_NAME)
    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Step1 ====================================================================
    # Load dataset
    train_dataloader = CycleGAN_Dataloader(name=args.dataset,
                                           num_workers=args.num_workers)
    test_dataloader = CycleGAN_Dataloader(name=args.dataset, train=False,
                                          num_workers=args.num_workers)
    print('==> DataLoader ready.')
    # Step2 ====================================================================
    # Make the model (cityscapes uses the shallower 6-resblock generators)
    if args.dataset == 'cityscapes':
        A_generator = Generator(num_resblock=6)
        B_generator = Generator(num_resblock=6)
        A_discriminator = Discriminator()
        B_discriminator = Discriminator()
    else:
        A_generator = Generator(num_resblock=9)
        B_generator = Generator(num_resblock=9)
        A_discriminator = Discriminator()
        B_discriminator = Discriminator()
    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        A_generator = nn.DataParallel(A_generator)
        B_generator = nn.DataParallel(B_generator)
        A_discriminator = nn.DataParallel(A_discriminator)
        B_discriminator = nn.DataParallel(B_discriminator)
    # Check CUDA available
    if torch.cuda.is_available():
        A_generator.cuda()
        B_generator.cuda()
        A_discriminator.cuda()
        B_discriminator.cuda()
    print('==> Model ready.')
    # Step3 ====================================================================
    # Set each loss function
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_feature = nn.L1Loss()
    # Set each optimizer — one Adam over both generators, one over both
    # discriminators.
    optimizer_G = optim.Adam(itertools.chain(A_generator.parameters(),
                                             B_generator.parameters()),
                             lr=args.lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(itertools.chain(A_discriminator.parameters(),
                                             B_discriminator.parameters()),
                             lr=args.lr, betas=(0.5, 0.999))
    # Set learning rate scheduler: constant for the first half of training,
    # then linear decay to zero over the second half.
    def lambda_rule(epoch):
        epoch_decay = args.epochs / 2
        lr_linear_scale = 1.0 - max(0, epoch + 1 - epoch_decay) \
            / float(epoch_decay + 1)
        return lr_linear_scale
    scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
    scheduler_D = lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)
    print('==> Criterion and optimizer ready.')
    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    # Initialize the result lists
    train_loss_G = []
    train_loss_D_A = []
    train_loss_D_B = []
    # Set image buffer (history of generated images for D updates)
    A_buffer = ImageBuffer(args.buffer_size)
    B_buffer = ImageBuffer(args.buffer_size)
    # Resume full training state (weights, optimizers, schedulers, losses).
    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        A_generator.load_state_dict(checkpoint['A_generator_state_dict'])
        B_generator.load_state_dict(checkpoint['B_generator_state_dict'])
        A_discriminator.load_state_dict(
            checkpoint['A_discriminator_state_dict'])
        B_discriminator.load_state_dict(
            checkpoint['B_discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
        scheduler_D.load_state_dict(checkpoint['scheduler_D_state_dict'])
        start_epoch = checkpoint['epoch']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D_A = checkpoint['train_loss_D_A']
        train_loss_D_B = checkpoint['train_loss_D_B']
        best_metric = checkpoint['best_metric']
    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['loss'] = args.loss
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size
    # Check the directory of the file path
    if not os.path.exists(os.path.dirname(RESULT_FILE_PATH)):
        os.makedirs(os.path.dirname(RESULT_FILE_PATH))
    if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH)):
        os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH))
    print('==> Train ready.')
    for epoch in range(args.epochs):
        # start after the checkpoint epoch
        if epoch < start_epoch:
            continue
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train and validate the model
        tloss_G, tloss_D = train(
            train_dataloader, A_generator, B_generator, A_discriminator,
            B_discriminator, criterion_GAN, criterion_cycle,
            criterion_identity, optimizer_G, optimizer_D, A_buffer, B_buffer,
            args.loss, args.lambda_cycle, args.lambda_identity,
            criterion_feature, args.lambda_feature, args.attention)
        train_loss_G.append(tloss_G)
        train_loss_D_A.append(tloss_D['A'])
        train_loss_D_B.append(tloss_D['B'])
        # Validate (render test images) every 10 epochs.
        if (epoch + 1) % 10 == 0:
            val(test_dataloader, A_generator, B_generator, A_discriminator,
                B_discriminator, epoch + 1, FILE_NAME_FORMAT, args.attention)
        # Update the optimizer's learning rate
        current_lr = optimizer_G.param_groups[0]['lr']
        scheduler_G.step()
        scheduler_D.step()
        #=======================================================================
        current = time.time()
        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D_A'] = train_loss_D_A
        result_data['train_loss_D_B'] = train_loss_D_B
        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data, pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)
        # Save the best checkpoint
        # if train_loss_G < best_metric:
        #     best_metric = train_loss_G
        #     torch.save({
        #         'epoch': epoch+1,
        #         'A_generator_state_dict': A_generator.state_dict(),
        #         'B_generator_state_dict': B_generator.state_dict(),
        #         'A_discriminator_state_dict': A_discriminator.state_dict(),
        #         'B_discriminator_state_dict': B_discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'scheduler_G_state_dict': scheduler_G.state_dict(),
        #         'scheduler_D_state_dict': scheduler_D.state_dict(),
        #         'train_loss_G': train_loss_G,
        #         'train_loss_D_A': train_loss_D_A,
        #         'train_loss_D_B': train_loss_D_B,
        #         'best_metric': best_metric,
        #     }, BEST_CHECKPOINT_FILE_PATH)
        # Save the current checkpoint
        torch.save(
            {
                'epoch': epoch + 1,
                'A_generator_state_dict': A_generator.state_dict(),
                'B_generator_state_dict': B_generator.state_dict(),
                'A_discriminator_state_dict': A_discriminator.state_dict(),
                'B_discriminator_state_dict': B_discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'scheduler_G_state_dict': scheduler_G.state_dict(),
                'scheduler_D_state_dict': scheduler_D.state_dict(),
                'train_loss_G': train_loss_G,
                'train_loss_D_A': train_loss_D_A,
                'train_loss_D_B': train_loss_D_B,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)
        # Keep a per-epoch snapshot every 10 epochs as well.
        if (epoch + 1) % 10 == 0:
            # NOTE(review): the '_{0}.ckpt' placeholder is never .format()ed
            # with the epoch number, so every snapshot gets the same literal
            # file name and overwrites the previous one — confirm intent.
            CHECKPOINT_FILE_NAME_epoch = FILE_NAME_FORMAT + '_{0}.ckpt'
            CHECKPOINT_FILE_PATH_epoch = os.path.join(
                CHECKPOINT_PATH, FILE_NAME_FORMAT, CHECKPOINT_FILE_NAME_epoch)
            if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH_epoch)):
                os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH_epoch))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'A_generator_state_dict': A_generator.state_dict(),
                    'B_generator_state_dict': B_generator.state_dict(),
                    'A_discriminator_state_dict': A_discriminator.state_dict(),
                    'B_discriminator_state_dict': B_discriminator.state_dict(),
                    'optimizer_G_state_dict': optimizer_G.state_dict(),
                    'optimizer_D_state_dict': optimizer_D.state_dict(),
                    'scheduler_G_state_dict': scheduler_G.state_dict(),
                    'scheduler_D_state_dict': scheduler_D.state_dict(),
                    'train_loss_G': train_loss_G,
                    'train_loss_D_A': train_loss_D_A,
                    'train_loss_D_B': train_loss_D_B,
                    'best_metric': best_metric,
                }, CHECKPOINT_FILE_PATH_epoch)
        # Print the information on the console
        print("model : {}".format(args.model))
        print("dataset : {}".format(args.dataset))
        print("loss : {}".format(args.loss))
        print("batch_size : {}".format(args.batch_size))
        print("current lrate : {:f}".format(current_lr))
        print("G loss : {:f}".format(tloss_G))
        print("D A/B loss : {:f}/{:f}".format(
            tloss_D['A'], tloss_D['B']))
        print("epoch time : {0:.3f} sec".format(current - epoch_time))
        # NOTE(review): `start` is not defined anywhere in this function — this
        # line raises NameError unless `start` is a module-level timestamp
        # defined elsewhere in the file.  Confirm.
        print("Current elapsed time : {0:.3f} sec".format(current - start))
    print('==> Train done.')
    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
def visualize_test_images(ckpt_list):
    """Render test-set comparison images for each checkpoint in *ckpt_list*.

    Every checkpoint file name encodes model/dataset/loss/flag as
    underscore-separated fields.  For each file the networks are rebuilt
    from those fields, the saved weights are restored, and ``val`` renders
    the images.  A failure for one checkpoint is printed and the loop
    continues with the next.
    """
    #===========================================================================
    for ckpt_name in ckpt_list:
        try:
            # Step0: recover the hyper-parameters from the file-name stem.
            stem = ckpt_name.split('.')[0]
            FILE_NAME_FORMAT = stem
            fields = stem.split('_')
            model_name = fields[0]
            dataset_name = fields[1]
            loss_type = fields[2]
            flag = fields[-1]
            attention = 'attention' in flag

            # Step1: test-split data loader.
            test_dataloader = CycleGAN_Dataloader(name=dataset_name,
                                                  train=False,
                                                  num_workers=8)
            print('==> DataLoader ready.')

            # Step2: rebuild the CycleGAN networks; cityscapes uses the
            # shallower 6-resblock generators, every other dataset uses 9.
            num_resblock = 6 if dataset_name == 'cityscapes' else 9
            A_generator = Generator(num_resblock=num_resblock)
            B_generator = Generator(num_resblock=num_resblock)
            A_discriminator = Discriminator()
            B_discriminator = Discriminator()
            if torch.cuda.device_count() > 1:
                A_generator = nn.DataParallel(A_generator)
                B_generator = nn.DataParallel(B_generator)
                A_discriminator = nn.DataParallel(A_discriminator)
                B_discriminator = nn.DataParallel(B_discriminator)
            if torch.cuda.is_available():
                for net in (A_generator, B_generator,
                            A_discriminator, B_discriminator):
                    net.cuda()
            print('==> Model ready.')

            # Step3: restore the trained weights and run the validation pass.
            checkpoint = torch.load(os.path.join(CHECKPOINT_PATH, ckpt_name))
            A_generator.load_state_dict(checkpoint['A_generator_state_dict'])
            B_generator.load_state_dict(checkpoint['B_generator_state_dict'])
            A_discriminator.load_state_dict(checkpoint['A_discriminator_state_dict'])
            B_discriminator.load_state_dict(checkpoint['B_discriminator_state_dict'])
            train_epoch = checkpoint['epoch']
            val(test_dataloader, A_generator, B_generator,
                A_discriminator, B_discriminator, train_epoch,
                FILE_NAME_FORMAT, attention)
            #-------------------------------------------------------------------
            # Report which checkpoint was rendered.
            print("model : {}".format(model_name))
            print("dataset : {}".format(dataset_name))
            print("loss : {}".format(loss_type))
            print('-'*50)
        except Exception as e:
            print(e)
    print('==> Visualize test images done.')
class Model:
    """Conditional-GAN image colorization harness.

    Owns a generator/discriminator pair plus their optimizers and LR
    schedulers, and provides training, evaluation, inference-on-dataset and
    checkpointing. All artifacts are written under ``base_path`` in the
    folders created by :meth:`create_folder_structure`.
    """

    def __init__(self, base_path='', epochs=10, learning_rate=0.0002, image_size=256,
                 leaky_relu=0.2, betas=(0.5, 0.999), lamda=100, image_format='png'):
        """Store hyper-parameters, detect the device and create output folders.

        Args:
            base_path: root directory for checkpoints / images / losses.
            epochs: number of training epochs.
            learning_rate: Adam learning rate for both networks.
            image_size: square input/output image resolution.
            leaky_relu: negative slope for the discriminator's LeakyReLU.
            betas: Adam (beta1, beta2).
            lamda: weight of the L1 and perceptual terms in the generator loss.
            image_format: extension used for every saved image.
        """
        self.image_size = image_size
        self.leaky_relu_threshold = leaky_relu
        self.epochs = epochs
        self.lr = learning_rate
        self.betas = betas
        self.lamda = lamda
        self.base_path = base_path
        self.image_format = image_format
        self.count = 1  # running index for files written by evaluate_model()
        # Populated later by initialize_model() / load_checkpoint()
        self.gen = None
        self.dis = None
        self.gen_optim = None
        self.dis_optim = None
        self.model_type = None
        self.residual_blocks = 9
        self.layer_size = 64
        self.lr_policy = None
        self.lr_schedule_gen = None
        self.lr_schedule_dis = None
        self.device = self.get_device()
        self.create_folder_structure()

    def create_folder_structure(self):
        """Create the output folder hierarchy under ``base_path`` (idempotent)."""
        for folder in ('/checkpoints', '/Loss_Checkpoints',
                       '/Training Images', '/Test Images'):
            os.makedirs(self.base_path + folder, exist_ok=True)

    def get_device(self):
        """Return the CUDA device (and print memory stats), or None on CPU.

        BUGFIX: ``torch.cuda.get_device_name(0)`` was previously called
        unconditionally and raised on CPU-only machines; it is now guarded by
        the device check.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Using device:', device)
        if device.type == 'cuda':
            print(torch.cuda.get_device_name(0))
            print('Memory Usage -')
            print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
            # memory_reserved() is the non-deprecated name of memory_cached().
            print('Cached: ', round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1), 'GB')
            return device
        return None

    def initialize_model(self, lr_schedular_options, model_type='unet',
                         residual_blocks=9, layer_size=64):
        """Build generator + discriminator, their optimizers and LR schedules.

        Args:
            lr_schedular_options: dict consumed by get_learning_schedule().
            model_type: one of unet/resnet/inception/unet2/unet_large/unet_fusion.
            residual_blocks: depth for the ResNet generator.
            layer_size: base channel count (ngf) of the generator.

        Raises:
            Exception: if *model_type* is not a known architecture.
        """
        all_models = ['unet', 'resnet', 'inception', 'unet2', 'unet_large', 'unet_fusion']
        if model_type not in all_models:
            raise Exception('This model type is not available!')
        self.dis = Discriminator(image_size=self.image_size,
                                 leaky_relu=self.leaky_relu_threshold)
        if model_type == 'unet':
            self.gen = Generator_Unet(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'resnet':
            self.gen = Generator_RESNET(residual_blocks=residual_blocks, ngf=layer_size)
        elif model_type == 'inception':
            self.gen = Generator_InceptionNet(ngf=layer_size)
        elif model_type == 'unet2':
            self.gen = Generator_Unet_2(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'unet_large':
            self.gen = Generator_Unet_Large(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'unet_fusion':
            self.gen = Generator_Unet_Fusion(image_size=self.image_size, ngf=layer_size)
        if self.device is not None:
            self.gen.cuda()
            self.dis.cuda()
        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)
        # BUGFIX: the schedules were cross-wired (the "dis" schedule stepped the
        # generator optimizer and vice versa). Harmless while both share one
        # policy, but wrong for per-network policies (e.g. 'plateau').
        self.lr_schedule_gen = self.get_learning_schedule(self.gen_optim, lr_schedular_options)
        self.lr_schedule_dis = self.get_learning_schedule(self.dis_optim, lr_schedular_options)
        self.model_type = model_type
        self.layer_size = layer_size
        self.residual_blocks = residual_blocks
        self.lr_policy = lr_schedular_options
        print('Model Initialized !\nGenerator Model Type : {} and Layer Size : {}'.format(model_type, layer_size))
        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def train_model(self, trainloader, average_loss, eval=(False, None, None),
                    save_model=(False, 25), display_test_image=(False, None, 25)):
        """Run the adversarial training loop.

        Args:
            trainloader: yields (gray_img, real_img) batches.
            average_loss: loss-tracking object with add_loss()/save().
            eval: (enabled, eval_loader, every_n_epochs) — periodic L1 evaluation.
            save_model: (enabled, every_n_epochs) — periodic checkpointing.
            display_test_image: (enabled, loader, every_n_epochs) — periodic
                sample-image dump to the Training Images folder.
        """
        print('We will be using L1 loss with perpetual loss (L1)!')
        mean_loss = nn.BCELoss()
        l1_loss = nn.L1Loss()
        vgg16 = models.vgg16()
        # NOTE(review): list(vgg16.children()) is [features, avgpool, classifier],
        # so [:-3] produces an EMPTY Sequential (an identity map) and the
        # "perceptual" term degenerates into a second plain L1 loss. The intent
        # was probably nn.Sequential(*list(vgg16.features.children())[:-3]) with
        # pretrained weights. Kept as-is to preserve training behaviour —
        # confirm before changing.
        vgg16_conv = nn.Sequential(*list(vgg16.children())[:-3])
        vgg16_conv.eval()  # feature extractor only; never trained here
        if self.device is not None:
            # BUGFIX: keep the loss network on the same device as the images
            # (previously left on CPU, a device-mismatch hazard).
            vgg16_conv.cuda()
        self.gen.train()
        self.dis.train()
        batches = len(trainloader)
        print('Total number of batches in an epoch are : {}'.format(batches))
        sample_img_test = None
        if display_test_image[0]:
            # Grab one fixed sample so progress images are comparable.
            sample_img_test, rgb_test_images = next(iter(display_test_image[1]))
            save_image((rgb_test_images[0].detach().cpu() + 1) / 2,
                       '{}/Training Images/real_img.{}'.format(self.base_path, self.image_format))
            if self.device is not None:
                sample_img_test = sample_img_test.cuda()
        for i in range(self.epochs):
            if eval[0] and (i % eval[2] == 0):
                self.evaluate_L1_loss_dataset(eval[1], train=False)
                self.evaluate_L1_loss_dataset(trainloader, train=True)
                self.gen.train()  # evaluate_* leaves the generator in eval mode
            running_gen_loss = 0
            running_dis_loss = 0
            for gray_img, real_img in trainloader:
                batch_size = len(gray_img)
                zero_label = torch.zeros(batch_size)
                one_label = torch.ones(batch_size)
                if self.device is not None:
                    gray_img = gray_img.cuda()
                    real_img = real_img.cuda()
                    zero_label = zero_label.cuda()
                    one_label = one_label.cuda()
                # Discriminator loss: real -> 1, generated -> 0
                self.dis_optim.zero_grad()
                fake_img = self.gen(gray_img)
                dis_real_loss = mean_loss(self.dis(real_img), one_label)
                dis_fake_loss = mean_loss(self.dis(fake_img), zero_label)
                total_dis_loss = dis_fake_loss + dis_real_loss
                total_dis_loss.backward()
                self.dis_optim.step()
                # Generator loss: fool the discriminator + lamda-weighted
                # L1 and perceptual reconstruction terms.
                self.gen_optim.zero_grad()
                fake_img = self.gen(gray_img)
                gen_adv_loss = mean_loss(self.dis(fake_img), one_label)
                gen_l1_loss = l1_loss(fake_img.view(batch_size, -1),
                                      real_img.view(batch_size, -1))
                gen_pre_train = l1_loss(vgg16_conv(fake_img), vgg16_conv(real_img))
                total_gen_loss = gen_adv_loss + self.lamda * gen_l1_loss + self.lamda * gen_pre_train
                total_gen_loss.backward()
                self.gen_optim.step()
                running_dis_loss += total_dis_loss.item()
                running_gen_loss += total_gen_loss.item()
            running_dis_loss /= (batches * 1.0)
            running_gen_loss /= (batches * 1.0)
            print('Epoch : {}, Generator Loss : {} and Discriminator Loss : {}'.format(i + 1, running_gen_loss, running_dis_loss))
            if display_test_image[0] and i % display_test_image[2] == 0:
                self.gen.eval()
                out_result = self.gen(sample_img_test)
                out_result = out_result.detach().cpu()
                out_result = (out_result[0] + 1) / 2  # [-1,1] -> [0,1] for saving
                save_image(out_result, '{}/Training Images/epoch_{}.{}'.format(self.base_path, i, self.image_format))
                self.gen.train()
            save_tuple = ([running_gen_loss], [running_dis_loss])
            average_loss.add_loss(save_tuple)
            if save_model[0] and i % save_model[1] == 0:
                self.save_checkpoint('checkpoint_epoch_{}'.format(i), self.model_type)
                average_loss.save('checkpoint_avg_loss', save_index=0)
            self.lr_schedule_gen.step()
            self.lr_schedule_dis.step()
            for param_grp in self.dis_optim.param_groups:
                print('Learning rate after {} epochs is : {}'.format(i + 1, param_grp['lr']))
        self.save_checkpoint('checkpoint_train_final', self.model_type)
        average_loss.save('checkpoint_avg_loss_final', save_index=0)

    def get_learning_schedule(self, optimizer, option):
        """Build an LR scheduler for *optimizer* from an options dict.

        Supported ``option['lr_policy']`` values: 'linear', 'plateau',
        'step', 'cosine'.

        Raises:
            Exception: for an unknown policy name.
        """
        if option['lr_policy'] == 'linear':
            def lambda_rule(epoch):
                # Constant for n_epochs, then linear decay to 0 over n_epoch_decay.
                lr_l = 1.0 - max(0, epoch - option['n_epochs']) / float(option['n_epoch_decay'] + 1)
                return lr_l
            schedular = lr_schedular.LambdaLR(optimizer, lr_lambda=lambda_rule)
        elif option['lr_policy'] == 'plateau':
            schedular = lr_schedular.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.2, threshold=0.01, patience=5)
        elif option['lr_policy'] == 'step':
            schedular = lr_schedular.StepLR(optimizer, step_size=option['step_size'], gamma=0.1)
        elif option['lr_policy'] == 'cosine':
            schedular = lr_schedular.CosineAnnealingLR(optimizer, T_max=option['n_epochs'], eta_min=0)
        else:
            raise Exception('LR Policy not implemented!')
        return schedular

    def evaluate_model(self, loader, save_filename, no_of_images=1):
        """Save *no_of_images* (generated, real, gray) triples to Test Images.

        Assumes a batch size of 1 on the test loader.

        Raises:
            Exception: if the model has not been initialized yet.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be saved!')
        counter_images_generated = 0
        while counter_images_generated < no_of_images:
            gray, rgb = next(iter(loader))
            if self.device is not None:
                gray = gray.cuda()
            filename = '{}/Test Images/{}_{}.{}'.format(self.base_path, save_filename, self.count, self.image_format)
            real_filename = '{}/Test Images/{}_{}_real.{}'.format(self.base_path, save_filename, self.count, self.image_format)
            real_gray_filename = '{}/Test Images/{}_{}_real_gray.{}'.format(self.base_path, save_filename, self.count, self.image_format)
            self.count += 1
            self.gen.eval()
            out = self.gen(gray)
            out = out[0].detach().cpu()
            out = (out + 1) / 2  # [-1,1] -> [0,1] for saving
            save_image(out, filename)
            gray_img = gray[0].detach().cpu()
            save_image(gray_img, real_gray_filename)
            real_img = (rgb[0].detach().cpu() + 1) / 2
            save_image(real_img, real_filename)
            counter_images_generated += 1

    def evaluate_L1_loss_dataset(self, loader, train=False):
        """Return (and print) the mean per-batch L1 loss over *loader*.

        Args:
            loader: yields (gray, real) batches.
            train: label used only in the printed message.

        Raises:
            Exception: if the model has not been initialized yet.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be evaluated!')
        loss_function = nn.L1Loss()
        self.gen.eval()
        total_loss = 0.0
        iterations = 0
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            for gray, real in loader:
                iterations += 1
                if self.device is not None:
                    gray = gray.cuda()
                    real = real.cuda()
                gen_out = self.gen(gray)
                iteration_loss = loss_function(gen_out, real)
                total_loss += iteration_loss.item()
        total_loss = total_loss / (iterations * 1.0)
        train_test = 'train' if train else 'test'
        print('Total L1 loss over {} set is : {}'.format(train_test, total_loss))
        return total_loss

    def change_params(self, epochs=None, learning_rate=None, leaky_relu=None,
                      betas=None, lamda=None):
        """Update selected hyper-parameters; rebuild optimizers when lr/betas change.

        NOTE(review): rebuilding the optimizers orphans any existing LR
        schedulers (they keep stepping the old optimizer objects) — call
        initialize_model() again if schedules must track the new optimizers.
        """
        if epochs is not None:
            self.epochs = epochs
            print('Changed the number of epochs to {}!'.format(self.epochs))
        if learning_rate is not None:
            self.lr = learning_rate
            print('Changed the learning rate to {}!'.format(self.lr))
        if leaky_relu is not None:
            self.leaky_relu_threshold = leaky_relu
            print('Changed the threshold for leaky relu to {}!'.format(self.leaky_relu_threshold))
        if betas is not None:
            self.betas = betas
            print('Changed the betas for Adams Optimizer!')
        if betas is not None or learning_rate is not None:
            self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
            self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)
        if lamda is not None:
            self.lamda = lamda
            print('Lamda value has been changed to {}!'.format(self.lamda))

    def set_all_params(self, epochs, lr, leaky_thresh, lamda, beta):
        """Overwrite every hyper-parameter at once and rebuild both optimizers."""
        self.epochs = epochs
        self.lr = lr
        self.leaky_relu_threshold = leaky_thresh
        self.lamda = lamda
        self.betas = beta
        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)
        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def run_model_on_dataset(self, loader, save_folder, save_path=None):
        """Colorize every image in *loader*, writing outputs to save_path/save_folder.

        Raises:
            Exception: if the model has not been initialized yet.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be saved!')
        index = 1
        if save_path is None:
            save_path = self.base_path
        for gray, dummy in loader:
            if self.device is not None:
                gray = gray.cuda()
            filename = '{}/{}/{}.{}'.format(save_path, save_folder, index, self.image_format)
            index += 1
            self.gen.eval()
            out = self.gen(gray)
            out = out[0].detach().cpu()
            out = (out + 1) / 2
            save_image(out, filename)

    def save_checkpoint(self, filename, model_type='unet'):
        """Persist weights plus every hyper-parameter needed to restore the model.

        Raises:
            Exception: if the model has not been initialized yet.
        """
        if self.gen is None or self.dis is None:
            raise Exception('The model has not been initialized and hence cannot be saved !')
        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        save_dict = {'model_type': model_type,
                     'dis_dict': self.dis.state_dict(),
                     'gen_dict': self.gen.state_dict(),
                     'lr': self.lr,
                     'epochs': self.epochs,
                     'betas': self.betas,
                     'image_size': self.image_size,
                     'leaky_relu_thresh': self.leaky_relu_threshold,
                     'lamda': self.lamda,
                     'base_path': self.base_path,
                     'count': self.count,
                     'image_format': self.image_format,
                     'device': self.device,
                     'residual_blocks': self.residual_blocks,
                     'layer_size': self.layer_size,
                     'lr_policy': self.lr_policy}
        torch.save(save_dict, filename)
        print('The model checkpoint has been saved !')

    def load_checkpoint(self, filename):
        """Restore hyper-parameters, rebuild the networks, and load their weights.

        Raises:
            Exception: if the checkpoint file does not exist, or if the saved
                device class (CPU/GPU) does not match the current machine.
        """
        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        if not pathlib.Path(filename).exists():
            raise Exception('This checkpoint does not exist!')
        self.gen = None
        self.dis = None
        save_dict = torch.load(filename)
        self.betas = save_dict['betas']
        self.image_size = save_dict['image_size']
        self.epochs = save_dict['epochs']
        self.leaky_relu_threshold = save_dict['leaky_relu_thresh']
        self.lamda = save_dict['lamda']
        self.lr = save_dict['lr']
        self.base_path = save_dict['base_path']
        self.count = save_dict['count']
        self.image_format = save_dict['image_format']
        self.device = save_dict['device']
        self.residual_blocks = save_dict['residual_blocks']
        self.layer_size = save_dict['layer_size']
        self.lr_policy = save_dict['lr_policy']
        device = self.get_device()
        if device != self.device:
            # Refuse cross-device restores: CPU-trained continues on CPU only,
            # GPU-trained cannot be loaded on a CPU machine.
            if self.device is None:
                error_msg = 'The model was trained on CPU and will therefore be continued on CPU only!'
            else:
                error_msg = 'The model was trained on GPU and cannot be loaded on a CPU machine!'
            raise Exception(error_msg)
        self.initialize_model(model_type=save_dict['model_type'],
                              residual_blocks=self.residual_blocks,
                              layer_size=self.layer_size,
                              lr_schedular_options=self.lr_policy)
        self.gen.load_state_dict(save_dict['gen_dict'])
        self.dis.load_state_dict(save_dict['dis_dict'])
        print('The model checkpoint has been restored!')