class Ganomaly(object):
    """GANomaly anomaly-detection model.

    Trains an encoder-decoder-encoder generator (``netg``) against a
    discriminator (``netd``) and scores anomalies by the squared distance
    between the two latent encodings of an image.
    """

    @staticmethod
    def name():
        """Return the name of the model class."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader=None):
        """Build networks, losses, input buffers and optimizers.

        Args:
            opt: parsed option namespace (batchsize, isize, lr, ...).
            dataloader: dict holding 'train' and 'test' DataLoaders.
        """
        super(Ganomaly, self).__init__()

        # Bookkeeping and output directories.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # Discriminator-side state, filled in by update_netd().
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None

        # Generator-side state, filled in by update_netg().
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # Misc counters.
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        # Create and initialize the networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        # Optionally resume both networks from checkpoints.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        # Loss functions.
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = l2_loss

        # Pre-allocated input buffers; resized per batch in set_input().
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32, device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32, device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = 1
        self.fake_label = 0

        # Optimizers, only needed in training mode.
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    def set_input(self, input):
        """Copy a batch into the pre-allocated buffers.

        Args:
            input (tuple): ``(images, ground_truth)`` for the batch.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])

            # Keep the very first batch around as a fixed visualization input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    def update_netd(self):
        """Update netD with feature matching: Ladv = |f(real) - f(fake)|_2."""
        self.netd.zero_grad()

        # Forward the real batch.
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)

        # Forward the fake batch (detached so G gets no gradient here).
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())

        # Feature-matching loss between real and fake discriminator features.
        # NOTE(review): err_d_real / err_d_fake are aliases of the combined
        # loss in this variant, so the re-init check in optimize() is
        # effectively a single threshold test.
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        self.err_d_real = self.err_d
        self.err_d_fake = self.err_d

        self.err_d.backward()
        self.optimizer_d.step()

    def reinitialize_netd(self):
        """Re-initialize the weights of netD."""
        self.netd.apply(weights_init)
        print('Reloading d net')

    def update_netg(self):
        """Update netG: adversarial BCE + reconstruction L1 + encoder L2."""
        self.netg.zero_grad()
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_g, _ = self.netd(self.fake)

        self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        # Constrain the reconstruction x' to look like x.
        self.err_g_l1l = self.l1l_criterion(self.fake, self.input)
        self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        self.err_g = self.err_g_bce * self.opt.w_bce + \
            self.err_g_l1l * self.opt.w_rec + \
            self.err_g_enc * self.opt.w_enc
        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    def optimize(self):
        """Run one optimization step for netD followed by netG."""
        self.update_netd()
        self.update_netg()

        # If D's loss has collapsed to ~zero, restart it.
        if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
            self.reinitialize_netd()

    def get_errors(self):
        """Collect the current netD and netG losses.

        Returns:
            OrderedDict: loss name -> Python float.
        """
        return OrderedDict([('err_d', self.err_d.item()),
                            ('err_g', self.err_g.item()),
                            ('err_d_real', self.err_d_real.item()),
                            ('err_d_fake', self.err_d_fake.item()),
                            ('err_g_bce', self.err_g_bce.item()),
                            ('err_g_l1l', self.err_g_l1l.item()),
                            ('err_g_enc', self.err_g_enc.item())])

    def get_current_images(self):
        """Return the current (reals, fakes, fixed) image tensors."""
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        return reals, fakes, fixed

    def save_weights(self, epoch):
        """Save netG and netD checkpoints for the current epoch.

        Args:
            epoch (int): current epoch number (stored as epoch + 1).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({'epoch': epoch + 1,
                    'state_dict': self.netg.state_dict()},
                   '%s/netG.pth' % (weight_dir))
        torch.save({'epoch': epoch + 1,
                    'state_dict': self.netd.state_dict()},
                   '%s/netD.pth' % (weight_dir))

    def train_epoch(self):
        """Train the model for one epoch over the training loader."""
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'], leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes,
                                                           fixed)

        print(">> Training model %s. Epoch %d/%d"
              % (self.name(), self.epoch + 1, self.opt.niter))
        # self.visualizer.print_current_errors(self.epoch, errors)

    def train(self):
        """Train the model for opt.niter epochs, keeping the best-AUC weights."""
        self.total_steps = 0
        best_auc = 0

        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    def test(self):
        """Evaluate the model on the test set.

        Raises:
            IOError: if pre-trained netG weights cannot be loaded.

        Returns:
            OrderedDict: average runtime per batch and the AUC.
        """
        with torch.no_grad():
            # Optionally load netG weights from disk.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']
                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netG weights not found")
                print(' Loaded weights.')

            self.opt.phase = 'test'

            # Pre-allocate score / label / latent tensors for the whole set.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_test, ), dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            # print(" Testing model %s." % self.name())
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize

                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                # Anomaly score: mean squared latent-encoding distance.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                lo = i * self.opt.batchsize
                hi = lo + error.size(0)
                self.an_scores[lo:hi] = error.reshape(error.size(0))
                self.gt_labels[lo:hi] = self.gt.reshape(error.size(0))
                self.latent_i[lo:hi, :] = latent_i.reshape(
                    error.size(0), self.opt.nz)
                self.latent_o[lo:hi, :] = latent_o.reshape(
                    error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Optionally save real/fake test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Average inference time over the first 100 batches (ms/batch).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Min-max scale the anomaly scores into [0, 1].
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))
            # auc, eer = roc(self.gt_labels, self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)',
                                        self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
class Ganomaly(BaseModel):
    """GANomaly model extended with optional latent-space classifiers.

    Besides the generator/discriminator pair, this variant can build two
    classifier networks (``netc_i`` / ``netc_o``) that operate on the
    input-side and output-side latent vectors respectively
    (enabled via ``opt.classifier``).
    """

    @property
    def name(self):
        """Model name."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, input buffers and optimizers.

        Args:
            opt: parsed option namespace.
            dataloader: dataloader(s) handed through to BaseModel.
        """
        super(Ganomaly, self).__init__(opt, dataloader)

        # Misc counters.
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)
        if self.opt.classifier:
            self.netc_i = NetC(self.opt).to(self.device)
            self.netc_o = NetC(self.opt).to(self.device)
            self.netc_i.apply(weights_init)
            self.netc_o.apply(weights_init)

        # Optionally resume G/D from checkpoints.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")
        # Optionally resume the latent classifiers.
        if self.opt.z_resume != '':
            print("\nLoading pre-trained z_networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.z_resume, 'netC_i.pth'))['epoch']
            self.netc_i.load_state_dict(
                torch.load(os.path.join(self.opt.z_resume,
                                        'netC_i.pth'))['state_dict'])
            self.netc_o.load_state_dict(
                torch.load(os.path.join(self.opt.z_resume,
                                        'netC_o.pth'))['state_dict'])
            print("\tDone.\n")

        # Loss functions.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        # Pre-allocated input buffers and BCE targets.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32, device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32, device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32, device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32, device=self.device)

        # Input buffers for the latent classifiers.
        # NOTE(review): i_input uses self.sqrtnz (presumably set on BaseModel)
        # while o_input recomputes int(nz ** 0.5) inline — confirm they agree.
        self.i_input = torch.empty(size=(self.opt.batchsize, 1, self.sqrtnz,
                                         self.sqrtnz),
                                   dtype=torch.float32, device=self.device)
        self.o_input = torch.empty(size=(self.opt.batchsize, 1,
                                         int(self.opt.nz**0.5),
                                         int(self.opt.nz**0.5)),
                                   dtype=torch.float32, device=self.device)
        self.i_gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                                device=self.device)
        self.o_gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                                device=self.device)
        self.i_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32, device=self.device)
        self.o_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32, device=self.device)
        # NOTE(review): despite the "real" name, both classifier targets are
        # all-zeros — confirm the intended label convention.
        self.i_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)
        self.o_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)

        # Optimizers, only needed in training mode.
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            if self.opt.classifier:
                self.netc_i.train()
                self.netc_o.train()
                self.optimizer_i = optim.Adam(self.netc_i.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))
                self.optimizer_o = optim.Adam(self.netc_o.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))

    def forward_g(self):
        """Forward propagate through netG."""
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    def forward_d(self):
        """Forward propagate through netD (fake batch detached)."""
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    def backward_g(self):
        """Backpropagate through netG (adv + con + enc losses)."""
        self.err_g_adv = self.l_adv(
            self.netd(self.input)[1], self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
            self.err_g_con * self.opt.w_con + \
            self.err_g_enc * self.opt.w_enc
        self.err_g.backward(retain_graph=True)

    def backward_d(self):
        """Backpropagate through netD (mean of real/fake BCE)."""
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    def reinit_d(self):
        """Re-initialize the weights of netD."""
        self.netd.apply(weights_init)
        if (self.opt.strengthen != 1):
            print(' Reloading net d')

    def optimize_params(self):
        """Forward pass, loss computation and backward pass for G and D."""
        self.forward_g()
        self.forward_d()

        # netg
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # netd
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()

        # Restart D if its loss has collapsed.
        if self.err_d.item() < 1e-5:
            self.reinit_d()

    def save_weights_z(self, epoch):
        """Save netc_i and netc_o weights for the current epoch.

        Args:
            epoch (int): current epoch number (stored as epoch + 1).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        # BUG FIX: the original referenced self.netC_i / self.netC_o, which
        # never exist — __init__ creates self.netc_i / self.netc_o — so every
        # call raised AttributeError.
        torch.save({'epoch': epoch + 1,
                    'state_dict': self.netc_i.state_dict()},
                   '%s/netC_i.pth' % (weight_dir))
        torch.save({'epoch': epoch + 1,
                    'state_dict': self.netc_o.state_dict()},
                   '%s/netC_o.pth' % (weight_dir))

    def forward_i(self):
        """Forward propagate through netC_i."""
        self.pred_abn_i = self.netc_i(self.i_input)

    def forward_o(self):
        """Forward propagate through netC_o."""
        self.pred_abn_o = self.netc_o(self.o_input)

    def backward_i(self):
        """Backpropagate the BCE loss through netC_i."""
        self.err_i = self.l_bce(self.pred_abn_i, self.i_real_label)
        self.err_i.backward()

    def backward_o(self):
        """Backpropagate the BCE loss through netC_o."""
        self.err_o = self.l_bce(self.pred_abn_o, self.o_real_label)
        self.err_o.backward()

    def reinit_i(self):
        """Re-initialize the weights of netC_i."""
        self.netc_i.apply(weights_init)
        if (self.opt.strengthen != 1):
            print(' Reloading net i')

    def reinit_o(self):
        """Re-initialize the weights of netC_o."""
        self.netc_o.apply(weights_init)
        if (self.opt.strengthen != 1):
            print(' Reloading net o')

    def z_optimize_params(self, net):
        """Forward/backward pass for one latent classifier.

        Args:
            net (str): 'i' to update netc_i, 'o' to update netc_o.
        """
        if net == 'i':
            self.forward_i()
            self.optimizer_i.zero_grad()
            self.backward_i()
            self.optimizer_i.step()
            if self.err_i.item() < 1e-5:
                self.reinit_i()
        if net == 'o':
            self.forward_o()
            self.optimizer_o.zero_grad()
            self.backward_o()
            self.optimizer_o.step()
            if self.err_o.item() < 1e-5:
                self.reinit_o()
class Ganomaly(BaseModel):
    """GANomaly model (BaseModel variant, no classifier heads).

    ``netg`` reconstructs the input and emits two latent codes; ``netd`` is
    a real/fake discriminator trained with BCE.
    """

    @property
    def name(self):
        """Model name."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, input buffers and optimizers."""
        super(Ganomaly, self).__init__(opt, dataloader)

        # Misc counters.
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        # Create and initialize the two networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        # Optionally resume from checkpoints.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        # Loss functions.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        # Pre-allocated input buffers and BCE targets.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32, device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32, device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32, device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32, device=self.device)

        # Optimizers, only needed in training mode.
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    def forward_g(self):
        """Run the generator: reconstruction plus both latent codes."""
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    def forward_d(self):
        """Run the discriminator on the real batch and the detached fake."""
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    def backward_g(self):
        """Accumulate generator gradients (adv + con + enc losses)."""
        self.err_g_adv = self.l_adv(
            self.netd(self.input)[1], self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
            self.err_g_con * self.opt.w_con + \
            self.err_g_enc * self.opt.w_enc
        self.err_g.backward(retain_graph=True)

    def backward_d(self):
        """Accumulate discriminator gradients (mean of real/fake BCE)."""
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    def reinit_d(self):
        """Re-initialize netD's weights (used when its loss collapses)."""
        self.netd.apply(weights_init)
        print(' Reloading net d')

    def optimize_params(self):
        """One optimization step: forward both nets, then update G and D."""
        self.forward_g()
        self.forward_d()

        # Generator update.
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # Discriminator update.
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()

        # Restart D if it has effectively stopped providing a signal.
        if self.err_d.item() < 1e-5:
            self.reinit_d()
class MNIST_UNET(nn.Module):
    """GANomaly-style trainer that computes per-step intrinsic rewards.

    For every step, the model measures an intrinsic (novelty) reward on the
    next unseen batch, then re-trains netG and netD on all batches seen so
    far. Rewards and stitched real/fake/diff images are written under
    ``opt.resume``.
    """

    def __init__(self, opt, dataloader):
        """Build networks, losses and input buffers.

        Args:
            opt: parsed option namespace (batchsize, nc, isize, lr, ...).
            dataloader: iterable of ``(images, ground_truth)`` batches.
        """
        super(MNIST_UNET, self).__init__()
        self.opt = opt
        self.dataloader = dataloader
        self.total_steps = len(dataloader)
        self.device = torch.device(
            'cuda:0' if self.opt.device != 'cpu' else 'cpu')

        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        weights_init(self.netg)
        weights_init(self.netd)

        # Loss functions.
        self.l_adv = self.l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = self.l2_loss
        self.l_bce = nn.BCELoss()

        # Pre-allocated input buffer and BCE targets.
        self.input_imgs = torch.empty(size=(self.opt.batchsize, self.opt.nc,
                                            self.opt.isize, self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32, device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)

    def train(self):
        """Train the model step by step over the dataloader."""
        self.netd.train()
        self.netg.train()
        optimizer_g = optim.Adam(self.netg.parameters(), lr=self.opt.lr,
                                 betas=(self.opt.beta1, 0.999))
        optimizer_d = optim.Adam(self.netd.parameters(), lr=self.opt.lr,
                                 betas=(self.opt.beta1, 0.999))

        print(">> Train model %s steps." % self.total_steps)

        self.step_reward = []
        for step in tqdm(range(self.total_steps)):
            # Collect the batches seen so far plus the next unseen batch.
            # (Re-iterating the dataloader each step preserves the original
            # behavior.)
            current_dataloader = []
            next_batch = None
            for count, (input_imgs, gt) in enumerate(self.dataloader):
                if count < step:
                    current_dataloader.append(input_imgs)
                elif count == step:
                    next_batch = input_imgs

            # Intrinsic reward on the yet-unseen batch.
            self.set_input(next_batch)
            intrinsic_loss = self.calculate_intrinsic_loss()
            self.save_images(self.input_imgs, self.fake_imgs, step)
            print('step: %s, reward: %s.' % (step, intrinsic_loss))
            self.step_reward.append(intrinsic_loss)
            current_dataloader.append(next_batch)

            # Warm-start both nets from the latest checkpoint if available.
            netG_weights_path = os.path.join(self.opt.resume, 'netG.pth')
            netD_weights_path = os.path.join(self.opt.resume, 'netD.pth')
            if os.path.exists(netG_weights_path) and os.path.exists(
                    netD_weights_path):
                self.netg.load_state_dict(
                    torch.load(netG_weights_path)['state_dict'])
                self.netd.load_state_dict(
                    torch.load(netD_weights_path)['state_dict'])

            for epoch in range(self.opt.niter):
                for input_imgs in current_dataloader:
                    self.set_input(input_imgs)
                    self.fake_imgs, self.latent_i, self.latent_o = self.netg(
                        self.input_imgs)
                    self.pred_real, self.feat_real = self.netd(
                        self.input_imgs)
                    self.pred_fake, self.feat_fake = self.netd(
                        self.fake_imgs.detach())

                    # Update generator.
                    # NOTE(review): l_adv compares netd(...)[0] (predictions);
                    # the sibling Ganomaly classes use [1] (features) —
                    # confirm which output is intended here.
                    optimizer_g.zero_grad()
                    self.err_g_adv = self.l_adv(
                        self.netd(self.input_imgs)[0],
                        self.netd(self.fake_imgs)[0])
                    self.err_g_con = self.l_con(self.input_imgs,
                                                self.fake_imgs)
                    self.err_g_enc = self.l_enc(self.latent_i, self.latent_o)
                    self.err_g = self.err_g_adv * self.opt.w_adv + \
                        self.err_g_con * self.opt.w_con + \
                        self.err_g_enc * self.opt.w_enc
                    self.err_g.backward()
                    optimizer_g.step()

                    # Update discriminator.
                    optimizer_d.zero_grad()
                    self.err_d_real = self.l_bce(self.pred_real,
                                                 self.real_label)
                    self.err_d_fake = self.l_bce(self.pred_fake,
                                                 self.fake_label)
                    self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
                    self.err_d.backward()
                    optimizer_d.step()

                    # Re-initialize D if its loss collapses to ~zero.
                    if self.err_d.item() < 1e-5:
                        weights_init(self.netd)
                        print('Reloading netd')

            self.save_weights()
        self.draw_reward()

    def calculate_intrinsic_loss(self):
        """Return the intrinsic (novelty) reward for the current input batch.

        Loads the latest generator checkpoint when one exists (otherwise
        re-initializes netg) and evaluates the encoding error — plus the
        reconstruction error when ``opt.use_con_reward`` is set — without
        tracking gradients.

        Returns:
            float: the total intrinsic reward.
        """
        with torch.no_grad():
            print(">> Getting current intrinsic reward.")
            netG_weights_path = os.path.join(self.opt.resume, 'netG.pth')
            if os.path.exists(netG_weights_path):
                pretrained_dict = torch.load(netG_weights_path)['state_dict']
                self.netg.load_state_dict(pretrained_dict)
            else:
                weights_init(self.netg)

            self.fake_imgs, self.latent_i, self.latent_o = self.netg(
                self.input_imgs)
            if self.opt.use_con_reward:
                con_reward = self.l_con(self.fake_imgs, self.input_imgs)
            else:
                con_reward = 0
            enc_reward = self.l_enc(self.latent_i, self.latent_o)
            total_reward = enc_reward + con_reward
            return total_reward.to('cpu').numpy().item()

    def l2_loss(self, input, target, size_average=True):
        """Squared error between input and target.

        Args:
            input: predicted tensor.
            target: reference tensor (broadcastable against input).
            size_average (bool): return the mean when True, otherwise the
                element-wise squared differences.
        """
        if size_average:
            return torch.mean(torch.pow((input - target), 2))
        return torch.pow((input - target), 2)

    def set_input(self, input_imgs):
        """Copy a batch of images into the pre-allocated input buffer."""
        with torch.no_grad():
            self.input_imgs.resize_(input_imgs.size()).copy_(input_imgs)

    def save_weights(self):
        """Save netG and netD checkpoints under <resume>/weights."""
        weight_dir = os.path.join(self.opt.resume, 'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)
        torch.save({'state_dict': self.netg.state_dict()},
                   os.path.join(weight_dir, 'netG.pth'))
        torch.save({'state_dict': self.netd.state_dict()},
                   os.path.join(weight_dir, 'netD.pth'))

    def save_images(self, real, fake, step):
        """Save a [real | fake | abs-diff] strip image for this step.

        Args:
            real: (N, C, W, H) batch of real images, assumed in [0, 1].
            fake: matching batch of reconstructions.
            step (int): current step, used in the output filename.
        """
        N, C, W, H = real.shape
        # NOTE(review): the strip layout mixes W and H and the final
        # squeeze(0) requires C == 1 — only correct for square grayscale
        # images; confirm if other inputs are possible.
        stitch_images = np.zeros((C, W * N, 3 * H))
        image_dir = os.path.join(self.opt.resume, 'images')
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        for i in range(N):
            # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in
            # 1.24; it was an alias for the builtin int, so astype(int)
            # is behavior-identical and works on current NumPy.
            real_img = (real[i, :, :, :] * 255).to('cpu').numpy().astype(int)
            fake_img = (fake[i, :, :, :] * 255).to('cpu').numpy().astype(int)
            mask_img = np.abs(real_img - fake_img).astype(np.uint8)
            stitch_images[:, W * i:W * i + W, :H] = real_img.astype(np.uint8)
            stitch_images[:, W * i:W * i + W,
                          H:2 * H] = fake_img.astype(np.uint8)
            stitch_images[:, W * i:W * i + W, 2 * H:] = mask_img
        stitch_images = stitch_images.squeeze(0)
        plt.imsave(os.path.join(image_dir, '%s.png' % (step + 1)),
                   stitch_images, cmap='gray')

    def draw_reward(self):
        """Plot and save the per-step intrinsic reward curve."""
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, self.total_steps + 1), self.step_reward)
        plt.xlabel("steps")
        plt.ylabel("reward")
        plt.savefig(os.path.join(self.opt.resume, 'images', 'rewards.png'))
        plt.show()