def _run_batch(self, data, eval=False, ds=None):
    """Run one mini-batch through the network and record its statistics.

    The batch is wrapped in ``Batch``, passed through ``self.net`` and scored
    with binary cross-entropy against ``batch.masks``. In training mode the
    loss is back-propagated and ``self.optimizer`` takes a step; in eval mode
    the forward pass runs without autograd and the weights are not updated.

    Args:
        data: Raw mini-batch from the data loader (input to ``Batch``).
        eval: If True, only forward pass and statistics are computed.
        ds: Dataset handle forwarded to the visualization routine.

    Side effects:
        Appends the per-iteration stats dict to ``self.epoch_stats`` and
        resets ``self.iter_starttime``.
    """
    time_dataloading = time.time() - self.iter_starttime
    time_proc_start = time.time()
    iter_stats = {'time_dataloading': time_dataloading}

    batch = Batch(data, eval=eval, gpu=self.args.gpu)
    targets = batch.masks.float()
    images = batch.images

    self.net.zero_grad()

    with torch.set_grad_enabled(not eval):
        X_vessels = self.net(images)
        loss = F.binary_cross_entropy(X_vessels, targets)

    # With eval=True the forward pass ran without autograd, so the loss has
    # no graph and backward() would raise; the weights must also stay fixed.
    if not eval:
        loss.backward()
        self.optimizer.step()

    # statistics
    iter_stats.update({
        'loss': loss.item(),
        'epoch': self.epoch,
        'timestamp': time.time(),
        'iter_time': time.time() - self.iter_starttime,
        'time_processing': time.time() - time_proc_start,
        'iter': self.iter_in_epoch,
        'total_iter': self.total_iter,
        'batch_size': len(batch)
    })
    self.iter_starttime = time.time()
    self.epoch_stats.append(iter_stats)

    # print stats every N mini-batches
    if self._is_printout_iter(eval):
        nimgs = 5
        avg_PR = eval_vessels.calculate_metrics(X_vessels, targets)['PR']
        PRs = get_image_PRs(X_vessels[:nimgs], targets[:nimgs])
        # iter_stats is the same dict object that was appended above, so the
        # metric added here is visible in self.epoch_stats as well.
        iter_stats.update({'avg_PR': avg_PR})
        self._print_iter_stats(
            self.epoch_stats[-self._print_interval(eval):])

        # Batch visualization (needs the per-image scores computed above)
        if self.args.show:
            retina_vis.visualize_vessels(images, images,
                                         vessel_hm=targets,
                                         scores=PRs,
                                         pred_vessel_hm=X_vessels,
                                         ds=ds,
                                         wait=self.args.wait,
                                         f=1.0,
                                         overlay_heatmaps_recon=True,
                                         nimgs=1,
                                         horizontal=True)
from csl_common.utils.common import init_random
from csl_common.utils.ds_utils import build_transform
from csl_common.vis import vis
import config

# Demo: load the WFLW split selected below, keep only samples matching the
# label filter, and display each batch with its landmarks drawn on top.
init_random(3)

path = config.get_dataset_paths('wflw')[0]
eval_transform = build_transform(deterministic=False, daug=0)
ds = WFLW(
    root=path,
    train=False,
    deterministic=True,
    use_cache=False,
    daug=0,
    image_size=256,
    transform=eval_transform,
)
ds.filter_labels({'pose': 1, 'occlusion': 0, 'make-up': 1})
dl = td.DataLoader(ds, batch_size=10, shuffle=False, num_workers=0)
print(ds)

for data in dl:
    batch = Batch(data, gpu=False)
    images = vis.to_disp_images(batch.images, denorm=True)
    # Disabled alternative (kept for reference):
    # lms = lmutils.convert_landmarks(to_numpy(batch.landmarks), lmutils.LM98_TO_LM68)
    lms = batch.landmarks
    images = vis.add_landmarks_to_images(
        images,
        lms,
        draw_wireframe=False,
        color=(0, 255, 0),
        radius=3,
    )
    # One row of ten tiles, matching the loader's batch size.
    vis.vis_square(images, nCols=10, fx=1., fy=1., normalize=False)
def get_fixed_samples(ds, num):
    """Return the first batch of ``num`` samples from ``ds`` as a ``Batch``.

    The loader is created without shuffling and with ``batch_size=num``, so
    repeated calls on the same dataset read the same leading samples.

    Args:
        ds: Dataset to draw from.
        num: Batch size requested from the loader; also passed to ``Batch``
            as ``n``.

    Returns:
        ``Batch`` built from the first batch the loader yields.
    """
    loader = td.DataLoader(ds, batch_size=num, shuffle=False, num_workers=0)
    first_batch = next(iter(loader))
    return Batch(first_batch, n=num)
def evaluate(self):
    """Evaluate ``self.net`` on ``self.fixed_val_data`` and summarize the epoch.

    Each validation batch is run without gradients and scored with binary
    cross-entropy against ``batch.masks``. Per-iteration stats are collected
    locally (not in ``self.epoch_stats``), printed on printout iterations and
    summarized once the loop ends.

    NOTE(review): the block as visible here ends right after computing the
    mean loss / PR; the code that consumes them (the "update scheduler" step)
    is not shown. ``means['avg_PR']`` is only present if at least one
    iteration was a printout iteration.
    """
    log.info("")
    log.info("Evaluating '{}'...".format(self.session_name))

    self.iters_per_epoch = len(self.fixed_val_data)
    self.iter_in_epoch = 0
    self.iter_starttime = time.time()
    eval_started = time.time()
    collected_stats = []

    self.net.eval()

    for data in self.fixed_val_data:
        batch = Batch(data, eval=True)
        targets = batch.masks.float()

        proc_started = time.time()
        load_secs = time.time() - self.iter_starttime

        with torch.no_grad():
            X_vessels = self.net(batch.images)
        loss = F.binary_cross_entropy(X_vessels, targets)

        # Keyword order below fixes the column order of the stats table.
        iter_stats = dict(
            loss=loss.item(),
            epoch=self.epoch,
            timestamp=time.time(),
            time_dataloading=load_secs,
            time_processing=time.time() - proc_started,
            iter_time=time.time() - self.iter_starttime,
            iter=self.iter_in_epoch,
            total_iter=self.total_iter,
            batch_size=len(batch),
        )
        collected_stats.append(iter_stats)

        if self._is_printout_iter(eval=True):
            nimgs = 1
            metrics = eval_vessels.calculate_metrics(X_vessels, targets)
            avg_PR = metrics['PR']
            PRs = get_image_PRs(X_vessels[:nimgs], targets[:nimgs])
            # Same dict object as the one already stored in collected_stats.
            iter_stats['avg_PR'] = avg_PR
            recent = collected_stats[-self._print_interval(True):]
            self._print_iter_stats(recent)

            # Batch visualization
            if self.args.show:
                retina_vis.visualize_vessels(
                    batch.images,
                    batch.images,
                    vessel_hm=targets,
                    scores=PRs,
                    pred_vessel_hm=X_vessels,
                    wait=self.args.wait,
                    f=1.0,
                    overlay_heatmaps_recon=True,
                    nimgs=nimgs,
                    horizontal=True,
                )

        self.iter_starttime = time.time()
        self.iter_in_epoch += 1

    # print average loss and accuracy over epoch
    self._print_epoch_summary(collected_stats, eval_started)

    # update scheduler
    stats_frame = pd.DataFrame(collected_stats)
    means = stats_frame.mean().to_dict()
    val_loss = means['loss']
    val_PR = means['avg_PR']
def _run_batch(self, data, eval=False, ds=None):
    """Process one mini-batch of the autoencoder: encode, decode, score, update.

    Steps, in order: encode ``batch.images`` with ``self.saae.Q``, optionally
    regularize / update the encoding, decode with ``self.saae.P``, compute
    the reconstruction, structural (SSIM) and adversarial losses, call
    backward, step the optimizers when not in eval mode, and record stats.

    Args:
        data: Raw mini-batch from the data loader (input to ``Batch``).
        eval: If True, no optimizer step is taken and SSIM error maps are
            always computed.
        ds: Dataset handle forwarded to the batch visualization.

    Side effects:
        Appends the stats dict to ``self.epoch_stats`` and resets
        ``self.iter_starttime``.
    """
    load_secs = time.time() - self.iter_starttime
    proc_started = time.time()
    iter_stats = {'time_dataloading': load_secs}

    batch = Batch(data, eval=eval)
    # Reconstruct towards the dedicated target images when the batch has them.
    if batch.target_images is not None:
        X_target = batch.target_images
    else:
        X_target = batch.images

    self.saae.zero_grad()
    # NOTE(review): requires a CUDA device.
    loss = torch.zeros(1, requires_grad=True).cuda()

    #######################
    # Encoding
    #######################
    with torch.set_grad_enabled(self.args.train_encoder):
        z_sample = self.saae.Q(batch.images)

        ###########################
        # Encoding regularization
        ###########################
        if not eval or self._is_printout_iter(eval):
            if self.args.with_zgan and self.args.train_encoder:
                if WITH_LOSS_ZREG:
                    loss_zreg = torch.abs(z_sample).mean()
                    # NOTE(review): `loss` is rebound in the reconstruction
                    # step below, so this contribution never reaches
                    # backward() -- confirm whether that is intended.
                    loss += loss_zreg
                    iter_stats['loss_zreg'] = loss_zreg.item()
                iter_stats.update(self.update_encoding(z_sample))

    iter_stats['z_recon_mean'] = z_sample.mean().item()
    iter_stats['z_recon_std'] = z_sample.std().item()

    #######################
    # Decoding
    #######################
    if not self.args.train_encoder:
        # Cut the graph so decoder gradients cannot flow into the encoder.
        z_sample = z_sample.detach()

    with torch.set_grad_enabled(self.args.train_decoder):
        # reconstruct images
        X_recon = self.saae.P(z_sample)

        #######################
        # Reconstruction loss
        #######################
        loss_recon = aae_training.loss_recon(X_target, X_recon)
        loss = loss_recon * self.args.w_rec
        iter_stats['loss_recon'] = loss_recon.item()

        #######################
        # Structural loss
        #######################
        cs_error_maps = None
        if self.args.with_ssim_loss or eval:
            # error maps are only needed when something gets visualized
            store_cs_maps = self._is_printout_iter(eval) or eval
            loss_ssim, cs_error_maps = aae_training.loss_struct(
                X_target,
                X_recon,
                self.ssim,
                calc_error_maps=store_cs_maps,
            )
            loss_ssim *= self.args.w_ssim
            # equal-weight blend of reconstruction and structural terms
            loss = 0.5 * loss + 0.5 * loss_ssim
            iter_stats['ssim_torch'] = loss_ssim.item()

        #######################
        # Adversarial loss
        #######################
        # `% 1 == 0` holds for every integer iteration count, i.e. the GAN
        # update currently runs on every iteration.
        if (self.args.with_gan
                and self.args.train_decoder
                and self.iter_in_epoch % 1 == 0):
            gan_stats, loss_G = self.update_gan(
                X_target,
                X_recon,
                z_sample,
                train=not eval,
                with_gen_loss=self.args.with_gen_loss,
            )
            loss += loss_G
            iter_stats.update(gan_stats)

        iter_stats['loss'] = loss.item()

        if self.args.train_decoder:
            loss.backward()

        # Update auto-encoder
        if not eval:
            if self.args.train_encoder:
                self.optimizer_E.step()
            if self.args.train_decoder:
                self.optimizer_G.step()

        if eval or self._is_printout_iter(eval):
            iter_stats['ssim'] = aae_training.calc_ssim(X_target, X_recon)

    # statistics
    iter_stats.update(
        epoch=self.epoch,
        timestamp=time.time(),
        iter_time=time.time() - self.iter_starttime,
        time_processing=time.time() - proc_started,
        iter=self.iter_in_epoch,
        total_iter=self.total_iter,
        batch_size=len(batch),
    )
    self.iter_starttime = time.time()
    self.epoch_stats.append(iter_stats)

    # print stats every N mini-batches
    if self._is_printout_iter(eval):
        recent = self.epoch_stats[-self._print_interval(eval):]
        self._print_iter_stats(recent)

        # Batch visualization
        if self.args.show:
            # number of images to display, keyed by network input size
            samples_by_size = {128: 8, 256: 7, 512: 2, 1024: 1}
            nimgs = samples_by_size[self.args.input_size]
            self.visualize_random_images(nimgs, z_real=z_sample)
            self.visualize_interpolations(z_sample, nimgs=2)
            self.visualize_batch(
                batch,
                X_recon,
                nimgs=nimgs,
                ssim_maps=cs_error_maps,
                ds=ds,
                wait=self.wait,
            )
def _run_batch(self, data, eval=False, ds=None):
    """Process one mini-batch for landmark-head training or evaluation.

    The images are encoded with ``self.saae.Q`` and decoded with
    ``self.saae.P``; the landmark head ``self.saae.LMH`` then predicts
    heatmaps, which are compared to ``batch.lm_heatmaps`` with a scaled MSE
    loss. When not in eval mode the loss is back-propagated and the
    landmark-head optimizer (plus the encoder optimizer if
    ``self.args.train_encoder``) takes a step.

    Args:
        data: Raw mini-batch from the data loader (input to ``Batch``).
        eval: If True, no weights are updated and landmark errors (NME) are
            computed for every batch instead of only on printout iterations.
        ds: Dataset handle forwarded to the batch visualization.

    Side effects:
        Appends the stats dict to ``self.epoch_stats`` and resets
        ``self.iter_starttime``.
    """
    time_dataloading = time.time() - self.iter_starttime
    time_proc_start = time.time()
    iter_stats = {'time_dataloading': time_dataloading}

    batch = Batch(data, eval=eval)

    self.saae.zero_grad()
    self.saae.eval()

    input_images = batch.target_images if batch.target_images is not None else batch.images

    with torch.set_grad_enabled(self.args.train_encoder):
        z_sample = self.saae.Q(input_images)

    iter_stats.update({'z_recon_mean': z_sample.mean().item()})

    #######################
    # Reconstruction phase
    #######################
    with torch.set_grad_enabled(self.args.train_encoder and not eval):
        X_recon = self.saae.P(z_sample)

    # calculate reconstruction error for debugging and reporting
    with torch.no_grad():
        # Store a plain number like every other stat, not a tensor.
        iter_stats['loss_recon'] = aae_training.loss_recon(
            batch.images, X_recon).item()

    #######################
    # Landmark predictions
    #######################
    train_lmhead = not eval
    lm_preds_max = None
    # Stays None when the batch has no target heatmaps.
    loss_lms = None
    with torch.set_grad_enabled(train_lmhead):
        self.saae.LMH.train(train_lmhead)
        # The head receives the decoder module itself, not X_recon --
        # presumably it reads activations cached by the forward pass above;
        # verify against LMH.forward.
        X_lm_hm = self.saae.LMH(self.saae.P)
        if batch.lm_heatmaps is not None:
            loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
            iter_stats.update({'loss_lms': loss_lms.item()})

        if eval or self._is_printout_iter(eval):
            # expensive, so only calculated every N iterations
            X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
            lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

            lm_gt = to_numpy(batch.landmarks)
            nmes = lmutils.calc_landmark_nme(
                lm_gt, lm_preds_max,
                ocular_norm=self.args.ocular_norm,
                image_size=self.args.input_size)
            iter_stats.update({'nmes': nmes})

    # Without target heatmaps there is no loss to optimize; skip the update
    # instead of failing on an unbound loss variable.
    if train_lmhead and loss_lms is not None:
        loss_lms.backward()
        self.optimizer_lm_head.step()
        if self.args.train_encoder:
            self.optimizer_E.step()

    # statistics
    iter_stats.update({
        'epoch': self.epoch,
        'timestamp': time.time(),
        'iter_time': time.time() - self.iter_starttime,
        'time_processing': time.time() - time_proc_start,
        'iter': self.iter_in_epoch,
        'total_iter': self.total_iter,
        'batch_size': len(batch)
    })
    self.iter_starttime = time.time()
    self.epoch_stats.append(iter_stats)

    # print stats every N mini-batches
    if self._is_printout_iter(eval):
        self._print_iter_stats(
            self.epoch_stats[-self._print_interval(eval):])

        # lm_preds_max is always set here: the printout condition implies
        # the prediction branch above was taken.
        lmvis.visualize_batch(
            batch.images, batch.landmarks, X_recon, X_lm_hm, lm_preds_max,
            lm_heatmaps=batch.lm_heatmaps,
            target_images=batch.target_images,
            ds=ds,
            ocular_norm=self.args.ocular_norm,
            clean=False,
            overlay_heatmaps_input=False,
            overlay_heatmaps_recon=False,
            landmarks_only_outline=self.landmarks_only_outline,
            landmarks_no_outline=self.landmarks_no_outline,
            f=1.0,
            wait=self.wait)