import os.path as osp

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

# Project-local modules. The exact import paths below are assumptions
# inferred from usage in this file; adjust them to the repository layout.
import loss
import utils
from feature_extraction import FeatureExtraction
from part_basis_generator import PartBasisGenerator
from tps import RandTPS
from train_utils import IMG_MEAN, adjust_learning_rate
from visualizer import Visualizer


class SCOPSTrainer(object):
    def __init__(self, args, model):
        self.args = args
        self.model = model

        # Random thin-plate-spline (TPS) spatial transform for the
        # equivariance loss.
        self.tps = RandTPS(args.input_size[1], args.input_size[0],
                           batch_size=args.batch_size,
                           sigma=args.tps_sigma,
                           border_padding=args.eqv_border_padding,
                           random_mirror=args.eqv_random_mirror,
                           random_scale=(args.random_scale_low,
                                         args.random_scale_high),
                           mode=args.tps_mode).cuda(args.gpu)

        # Color jitter transform (appearance perturbation for equivariance).
        self.cj_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                   saturation=0.2, hue=0.2),
            transforms.ToTensor(),
        ])

        # KL-divergence loss for the equivariance term.
        self.kl = nn.KLDivLoss().cuda(args.gpu)

        # Bilinear upsampling of low-resolution predictions to input size.
        self.interp = nn.Upsample(
            size=(args.input_size[1], args.input_size[0]),
            mode='bilinear', align_corners=True)

        # Frozen reference feature extractor and learnable part basis for
        # the semantic consistency loss.
        self.zoo_feat_net = FeatureExtraction(
            feature_extraction_cnn=args.ref_net,
            normalization=args.ref_norm,
            last_layer=args.ref_layer)
        self.zoo_feat_net.eval()

        self.part_basis_generator = PartBasisGenerator(
            self.zoo_feat_net.out_dim, args.num_parts,
            normalize=args.ref_norm)
        self.part_basis_generator.cuda(args.gpu)
        self.part_basis_generator.train()

        if args.restore_part_basis != '':
            self.part_basis_generator.load_state_dict(
                {'w': torch.load(args.restore_part_basis)})

        # Optimizers: one for the segmentation network, one for the part basis.
        self.optimizer_seg = optim.SGD(
            self.model.optim_parameters(args),
            lr=args.learning_rate, momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.optimizer_seg.zero_grad()

        self.optimizer_sc = optim.SGD(
            self.part_basis_generator.parameters(),
            lr=args.learning_rate_w, momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.optimizer_sc.zero_grad()

        # Visualizer.
        self.viz = Visualizer(args)

    def step(self, batch, current_step):
        loss_con_value = 0
        loss_eqv_value = 0
        loss_lmeqv_value = 0
        loss_sc_value = 0
        loss_orthonormal_value = 0

        self.optimizer_seg.zero_grad()
        self.optimizer_sc.zero_grad()
        adjust_learning_rate(self.optimizer_seg, current_step, self.args)

        images_cpu = batch['img']
        labels = batch['saliency'] if 'saliency' in batch.keys() else None
        edges = batch['edge'] if 'edge' in batch.keys() else None
        gts = batch['gt'] if 'gt' in batch.keys() else None
        landmarks = batch['landmarks'] if 'landmarks' in batch.keys() else None
        bbox = batch['bbox'] if 'bbox' in batch.keys() else None

        images = images_cpu.cuda(self.args.gpu)
        feature_instance, feature_part, pred_low = self.model(images)
        pred = self.interp(pred_low)

        # Re-normalize images for the torch model-zoo feature extractor:
        # undo the dataset mean subtraction, then apply ImageNet statistics.
        zoo_mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
        zoo_var = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
        images_zoo_cpu = (images_cpu.numpy() +
                          IMG_MEAN.reshape((1, 3, 1, 1))) / 255.0
        images_zoo_cpu -= zoo_mean
        images_zoo_cpu /= zoo_var
        images_zoo_cpu = torch.from_numpy(images_zoo_cpu)
        images_zoo = images_zoo_cpu.cuda(self.args.gpu)

        with torch.no_grad():
            zoo_feats = self.zoo_feat_net(images_zoo)
            zoo_feat = torch.cat([self.interp(zoo_feat)
                                  for zoo_feat in zoo_feats], dim=1)
            # Saliency masking: suppress background features.
            if not self.args.no_sal_masking and labels is not None:
                zoo_feat = zoo_feat * labels.unsqueeze(dim=1).expand_as(
                    zoo_feat).cuda(self.args.gpu)

        # Semantic consistency loss.
        loss_sc = loss.semantic_consistency_loss(
            features=zoo_feat, pred=pred, basis=self.part_basis_generator())
        loss_sc_value += self.args.lambda_sc * loss_sc.data.cpu().numpy()

        # Orthonormality loss on the part basis
        # (function name spelled as defined in the project's loss module).
        loss_orthonormal = loss.orthonomal_loss(self.part_basis_generator())
        loss_orthonormal_value += self.args.lambda_orthonormal * \
            loss_orthonormal.data.cpu().numpy()

        # Concentration loss.
        loss_con = loss.concentration_loss(pred)
        loss_con_value += self.args.lambda_con * loss_con.data.cpu().numpy()

        # Equivariance loss: apply color jitter, then a random TPS warp.
        images_cj = torch.from_numpy(
            ((images_cpu.numpy() + IMG_MEAN.reshape((1, 3, 1, 1))) /
             255.0).clip(0, 1.0))
        for b in range(images_cj.shape[0]):
            images_cj[b] = torch.from_numpy(
                self.cj_transform(images_cj[b]).numpy() * 255.0 -
                IMG_MEAN.reshape((3, 1, 1)))
        images_cj = images_cj.cuda()

        self.tps.reset_control_points()
        images_tps = self.tps(images_cj)
        feature_instance_tps, feature_part_tps, pred_low_tps = self.model(
            images_tps)
        pred_tps = self.interp(pred_low_tps)

        # Warp the detached prediction of the original image; zero padding
        # in prediction space.
        pred_d = pred.detach()
        pred_tps_org = self.tps(pred_d, padding_mode='zeros')

        loss_eqv = self.kl(F.log_softmax(pred_tps, dim=1),
                           F.softmax(pred_tps_org, dim=1))
        loss_eqv_value += self.args.lambda_eqv * loss_eqv.data.cpu().numpy()

        # Landmark equivariance: part centers of the prediction on the warped
        # image should match part centers of the warped original prediction
        # (here warped differentiably, without detaching).
        centers_tps = utils.batch_get_centers(
            nn.Softmax(dim=1)(pred_tps)[:, 1:, :, :])
        pred_tps_org_dif = self.tps(pred, padding_mode='zeros')
        centers_tps_org = utils.batch_get_centers(
            nn.Softmax(dim=1)(pred_tps_org_dif)[:, 1:, :, :])
        loss_lmeqv = F.mse_loss(centers_tps, centers_tps_org)
        loss_lmeqv_value += self.args.lambda_lmeqv * \
            loss_lmeqv.data.cpu().numpy()

        # Visualization and part-basis snapshot.
        if current_step % self.args.vis_interval == 0:
            with torch.no_grad():
                pred_softmax = nn.Softmax(dim=1)(pred)
                part_softmax = pred_softmax[:, 1:, :, :]
                # Normalize each part response map to [0, 1].
                part_softmax /= part_softmax.max(dim=3, keepdim=True)[0].max(
                    dim=2, keepdim=True)[0]
                self.viz.vis_images(current_step, images_cpu,
                                    images_tps.cpu(), labels, edges,
                                    IMG_MEAN, pred.float())
                self.viz.vis_part_heatmaps(current_step, part_softmax,
                                           threshold=0.1, prefix='pred')
                if landmarks is not None:
                    self.viz.vis_landmarks(current_step, images_cpu,
                                           IMG_MEAN, pred, landmarks)
                if bbox is not None:
                    self.viz.vis_bboxes(current_step, bbox)

                print('saving part basis')
                torch.save(
                    {'W': self.part_basis_generator().detach().cpu(),
                     'W_state_dict': self.part_basis_generator.state_dict()},
                    osp.join(self.args.snapshot_dir, self.args.exp_name,
                             'BASIS_' + str(current_step) + '.pth'))

                self.viz.vis_losses(
                    current_step,
                    [self.part_basis_generator.w.mean(),
                     self.part_basis_generator.w.std()],
                    ['part_basis_mean', 'part_basis_std'])

        # Sum all loss terms.
        total_loss = self.args.lambda_con * loss_con \
            + self.args.lambda_eqv * loss_eqv \
            + self.args.lambda_lmeqv * loss_lmeqv \
            + self.args.lambda_sc * loss_sc \
            + self.args.lambda_orthonormal * loss_orthonormal
        total_loss.backward()

        # Visualize loss curves.
        self.viz.vis_losses(
            current_step,
            [loss_con_value, loss_eqv_value, loss_lmeqv_value,
             loss_sc_value, loss_orthonormal_value],
            ['loss_con', 'loss_eqv', 'loss_lmeqv',
             'loss_sc', 'loss_orthonormal'])

        # Clip gradients, then update both the network and the part basis.
        nn.utils.clip_grad_norm_(self.model.parameters(),
                                 self.args.clip_gradients)
        self.optimizer_seg.step()
        nn.utils.clip_grad_norm_(self.part_basis_generator.parameters(),
                                 self.args.clip_gradients)
        self.optimizer_sc.step()

        print('exp = {}'.format(osp.join(self.args.snapshot_dir,
                                         self.args.exp_name)))
        print(('iter = {:8d}/{:8d}, '
               'loss_con = {:.3f}, '
               'loss_eqv = {:.3f}, '
               'loss_lmeqv = {:.3f}, '
               'loss_sc = {:.3f}, '
               'loss_orthonormal = {:.3f}')
              .format(current_step, self.args.num_steps,
                      loss_con_value, loss_eqv_value, loss_lmeqv_value,
                      loss_sc_value, loss_orthonormal_value))
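# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It shows how
# SCOPSTrainer is typically driven; `get_arguments`, `build_model`, and
# `get_loader` are hypothetical placeholders for the project's argument
# parser, segmentation network, and data loader, so it is kept as comments.
# ---------------------------------------------------------------------------
# args = get_arguments()                        # hypothetical CLI parser
# model = build_model(args).cuda(args.gpu)      # hypothetical model factory
# model.train()
# trainer = SCOPSTrainer(args, model)
#
# current_step = 0
# train_loader = get_loader(args)               # yields dicts with 'img', ...
# while current_step < args.num_steps:
#     for batch in train_loader:
#         trainer.step(batch, current_step)
#         current_step += 1
#         if current_step >= args.num_steps:
#             break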
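# ---------------------------------------------------------------------------
# For reference, a hedged sketch of what `utils.batch_get_centers` is assumed
# to compute, based on its usage in `step` above: the soft-argmax (spatial
# expectation) of each part response map, i.e. one (x, y) "landmark" per
# part. This is an illustrative reimplementation, not the project's code.
# ---------------------------------------------------------------------------
def batch_get_centers_sketch(part_maps):
    """part_maps: (B, K, H, W) non-negative part responses.
    Returns (B, K, 2) part centers in normalized [-1, 1] coordinates."""
    B, K, H, W = part_maps.shape
    # Normalize each map into a spatial probability distribution.
    probs = part_maps / (part_maps.sum(dim=(2, 3), keepdim=True) + 1e-8)
    ys = torch.linspace(-1.0, 1.0, H, device=part_maps.device)
    xs = torch.linspace(-1.0, 1.0, W, device=part_maps.device)
    cy = (probs.sum(dim=3) * ys).sum(dim=2)  # expectation over rows, (B, K)
    cx = (probs.sum(dim=2) * xs).sum(dim=2)  # expectation over cols, (B, K)
    return torch.stack([cx, cy], dim=2)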
viz.vis_inputs(epoch, inputs_tf, prefix='_tf')
viz.vis_inputs(epoch, saliency, prefix='saliency')
viz.vis_DFF_heatmaps(epoch, pams_viz, threshold=0.1, prefix='pams')
viz.vis_DFF_heatmaps(epoch, pams_tf_viz, threshold=0.1, prefix='pams_tf')

# Log the loss terms under phase-specific names.
suffix = '_train' if phase == 'train' else '_val'
viz.vis_losses(epoch,
               [loss, loss_cls, loss_geo, loss_sem, loss_ort, loss_eqv],
               ['loss' + suffix, 'loss_cls' + suffix, 'loss_geo' + suffix,
                'loss_sem' + suffix, 'loss_ort' + suffix, 'loss_eqv' + suffix])

# Backward pass only in the training phase.
if phase == 'train':
    loss.backward()
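# ---------------------------------------------------------------------------
# Context sketch (assumption): the fragment above typically sits inside a
# standard PyTorch epoch loop with train/val phases, with the optimizer step
# following the backward pass. `dataloaders`, `optimizer`, and `num_epochs`
# are hypothetical names, so the sketch is kept as comments.
# ---------------------------------------------------------------------------
# for epoch in range(num_epochs):
#     for phase in ('train', 'val'):
#         if phase == 'train':
#             model.train()
#         else:
#             model.eval()
#         for inputs in dataloaders[phase]:
#             optimizer.zero_grad()
#             # forward pass + loss terms + visualization (fragment above);
#             # `loss.backward()` runs only when phase == 'train'
#             if phase == 'train':
#                 optimizer.step()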