def forward(self, input):
    """Fuse the two expert branches with SSMA modules and decode to logits.

    Channels ``[:3]`` of ``input`` are fed to ``self.expert_A.forward_SSMA`` and
    channels ``[3:]`` to ``self.expert_B.forward_SSMA``. Each returns four values;
    only the two low-level features and the ASPP feature are used. Matching
    pairs are fused by ``SSMA_ASPP``, ``SSMA_skip1`` and ``SSMA_skip2`` and the
    results are handed to ``self.decoder``.

    No temperature / logit scaling is applied here.

    Returns:
        (x, entropy, mutual_info): the decoder output, and the entropy and mutual
        information returned by ``mutualinfo_entropy`` for ``self.softmaxMCDO``
        applied to the output with a size-1 trailing "sample" axis
        (probabilities floored at 1e-9).
    """
    # The first value of each expert's forward_SSMA is not needed here.
    _, a_low1, a_low2, a_aspp = self.expert_A.forward_SSMA(input[:, :3, :, :])
    _, b_low1, b_low2, b_aspp = self.expert_B.forward_SSMA(input[:, 3:, :, :])

    # Arguments are evaluated left to right: ASPP, skip1, skip2.
    logits = self.decoder(
        self.SSMA_ASPP(a_aspp, b_aspp),
        self.SSMA_skip1(a_low1, b_low1),
        self.SSMA_skip2(a_low2, b_low2),
    )  # original annotation: [batch, classes, 512, 512] — TODO confirm

    # Treat the single deterministic output as one MC sample.
    probs = self.softmaxMCDO(logits.unsqueeze(-1))
    probs = probs.masked_fill(probs < 1e-9, 1e-9)
    entropy, mutual_info = mutualinfo_entropy(probs)
    return logits, entropy, mutual_info
def forward(self, inputs):
    """Scale MC-dropout logits by a learned per-pixel map and compute uncertainty.

    ``self.segnet`` is switched to eval mode on every call. Then
    ``self.segnet.module.forwardMCDO_logits`` produces per-pass logits, with the
    passes on the last axis (per the original shape annotation).

    A second encoder/decoder (``temp_down1``, ``temp_down2``, ``temp_up2``,
    ``temp_up1``, using the unpooling indices/shapes from the down path) maps
    ``inputs`` to a map that multiplies the logits of every pass.

    Returns:
        8-tuple:
        ``(mean, variance, entropy, mutual_info, temp_map, temp, mean_entropy, mean_mi)``

        - ``mean`` and ``variance``: computed over the pass axis. ``mean`` is
          replaced by ``self.scale_logits(...)`` when that is set.
        - ``entropy`` and ``mutual_info``: from ``mutualinfo_entropy`` on the
          floored softmax.
        - ``temp_map``: the map with dim 1 squeezed.
        - ``temp``: the map's spatial mean, flattened.
        - ``mean_entropy`` and ``mean_mi``: means of entropy / mutual info over
          dims (1, 2).
    """
    # Freeze batchnorm
    self.segnet.eval()
    # Per-pass logits; original annotation: (batch, 11, 512, 512, passes).
    logits = self.segnet.module.forwardMCDO_logits(inputs)

    # Encoder/decoder predicting the per-pixel multiplier.
    tdown1, tindices_1, tunpool_shape1 = self.temp_down1(inputs)
    tdown2, tindices_2, tunpool_shape2 = self.temp_down2(tdown1)
    tup2 = self.temp_up2(tdown2, tindices_2, tunpool_shape2)
    temp_map = self.temp_up1(tup2, tindices_1, tunpool_shape1)  # [batch,1,512,512]
    # Spatial mean of the map, flattened (same values as the former
    # unsqueeze(-1).unsqueeze(-1).view(-1) chain).
    temp = temp_map.mean((2, 3)).view(-1)

    # Broadcast the map over the trailing pass axis.
    x = logits * temp_map.unsqueeze(-1)
    mean = x.mean(-1)  # [batch,classes,512,512]
    # NOTE(review): named "variance" but this is the standard deviation, unlike
    # forwardMCDO / forwardMCDO_junjiao which use .var(). Kept as std so
    # existing scale_logits weights see the same input — confirm intent.
    variance = x.std(-1)

    prob = self.softmaxMCDO(x)
    # Floor tiny probabilities before the entropy helper.
    prob = prob.masked_fill(prob < 1e-9, 1e-9)
    entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)

    if self.scale_logits is not None:
        mean = self.scale_logits(mean, variance, mutual_info, entropy)

    return (
        mean,
        variance,
        entropy,
        mutual_info,
        temp_map.squeeze(1),
        temp,
        entropy.mean((1, 2)),
        mutual_info.mean((1, 2)),
    )
def forward(self, inputs, mcdo=True):
    """Run ``self.mcdo_passes`` passes of ``self._forward`` and summarise them.

    Not wrapped in ``torch.no_grad()``: autograd tracks every pass.

    Args:
        inputs: batch passed unchanged to ``self._forward`` on every pass.
        mcdo: forwarded to ``self._forward`` on *every* pass. Previously only
            the first pass received it, and later passes used ``_forward``'s
            default.

    Returns:
        (mean, entropy, mutual_info):
        - ``mean``: the per-pass outputs averaged over the pass axis.
        - ``entropy`` and ``mutual_info``: as returned by ``mutualinfo_entropy``
          for ``self.softmaxMCDO`` of the stacked passes.

    Raises:
        ValueError: if ``self.mcdo_passes`` is less than 1.
    """
    if self.mcdo_passes < 1:
        raise ValueError(
            "mcdo_passes must be >= 1, got {}".format(self.mcdo_passes))

    # Stack the passes on a new trailing axis. This equals repeated
    # torch.cat of unsqueezed outputs, but copies the data only once.
    x = torch.stack(
        [self._forward(inputs, mcdo=mcdo) for _ in range(self.mcdo_passes)],
        dim=-1,
    )
    mean = x.mean(-1)
    prob = self.softmaxMCDO(x)
    entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)
    return mean, entropy, mutual_info
def forward(self, input, scaling_metrics="SoftEn"):
    """Deterministic segmentation forward returning logits and entropy.

    Runs ``self.backbone`` -> ``self.aspp`` -> ``self.decoder`` and bilinearly
    upsamples the result (``align_corners=True``) to ``input.size()[2:]``. The
    upsampled logits are then treated as a single sample: a trailing axis of
    size 1 is added before the softmax/entropy computation.

    Args:
        input: image batch; its spatial size is taken from ``input.size()[2:]``.
        scaling_metrics: accepted but not read by this method.

    Returns:
        (mean, entropy):
        - ``mean``: the mean over the size-1 sample axis, i.e. the upsampled
          logits.
        - ``entropy``: the first value returned by ``mutualinfo_entropy``. The
          mutual information it also returns is discarded.
    """
    features, low_level_feat = self.backbone(input)
    logits = self.decoder(self.aspp(features), low_level_feat)
    logits = F.interpolate(logits, size=input.size()[2:], mode='bilinear', align_corners=True)

    # Add a trailing sample axis of size 1: [batch, classes, H, W, 1].
    samples = logits.unsqueeze(-1)
    mean = samples.mean(-1)

    # Softmax over dim 1 (classes); prob keeps the size-1 sample axis (5-D).
    prob = F.softmax(samples, dim=1)
    prob = prob.masked_fill(prob < 1e-9, 1e-9)
    entropy, _mutual_info = mutualinfo_entropy(prob)
    return mean, entropy
def forwardMCDO(self, inputs, mcdo=True):
    """Run ``self.mcdo_passes`` gradient-free passes of ``self._forward`` and summarise them.

    Args:
        inputs: batch passed unchanged to ``self._forward`` on every pass.
        mcdo: forwarded to ``self._forward`` on every pass.

    Returns:
        (mean, variance, entropy, mutual_info):
        - ``mean`` and ``variance``: computed over the pass axis. ``Tensor.var``
          is unbiased by default, so ``variance`` is NaN when
          ``mcdo_passes == 1``.
        - ``entropy`` and ``mutual_info``: as returned by ``mutualinfo_entropy``
          for ``self.softmaxMCDO`` of the stacked passes.

    Raises:
        ValueError: if ``self.mcdo_passes`` is less than 1.
    """
    if self.mcdo_passes < 1:
        raise ValueError(
            "mcdo_passes must be >= 1, got {}".format(self.mcdo_passes))

    with torch.no_grad():
        # Every pass goes through the same single-pass ``_forward`` with the
        # same ``mcdo`` flag, stacked on a new trailing axis.
        x = torch.stack(
            [self._forward(inputs, mcdo=mcdo) for _ in range(self.mcdo_passes)],
            dim=-1,
        )
    mean = x.mean(-1)
    variance = x.var(-1)
    prob = self.softmaxMCDO(x)
    entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)
    return mean, variance, entropy, mutual_info
def forwardMCDO_junjiao(self, inputs, mcdo=True):
    """Run ``self.mcdo_passes`` gradient-free passes of ``self.forward`` and summarise them.

    Args:
        inputs: batch passed unchanged to ``self.forward`` on every pass.
        mcdo: forwarded to ``self.forward`` on *every* pass. Previously only
            the first pass received it.

    Returns:
        6-tuple ``(mean, variance, entropy, mutual_info, mean_entropy, mean_mi)``:
        - ``mean`` and ``variance``: computed over the pass axis (unbiased
          ``var``, so NaN for a single pass). ``mean`` is replaced by
          ``self.scale_logits(mean, variance, mutual_info, entropy)`` when
          ``scale_logits`` is set.
        - ``entropy`` and ``mutual_info``: from ``mutualinfo_entropy``.
        - ``mean_entropy`` and ``mean_mi``: their means over dims (1, 2).

    Raises:
        ValueError: if ``self.mcdo_passes`` is less than 1.
    """
    if self.mcdo_passes < 1:
        raise ValueError(
            "mcdo_passes must be >= 1, got {}".format(self.mcdo_passes))

    with torch.no_grad():
        # Stack the passes on a new trailing axis (copies the data only once).
        x = torch.stack(
            [self.forward(inputs, mcdo=mcdo) for _ in range(self.mcdo_passes)],
            dim=-1,
        )
    mean = x.mean(-1)
    variance = x.var(-1)
    prob = self.softmaxMCDO(x)
    entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)

    if self.scale_logits is not None:
        mean = self.scale_logits(mean, variance, mutual_info, entropy)

    return (mean, variance, entropy, mutual_info,
            entropy.mean((1, 2)), mutual_info.mean((1, 2)))
def _validate_load_checkpoint(model, model_pkl, logger, device):
    """Load matching weights from the checkpoint file ``model_pkl`` into ``model``, logging the outcome."""
    # A missing or unset path (``attr['resume']`` may be None) is logged, not raised.
    if not model_pkl or not os.path.isfile(model_pkl):
        logger.info("No checkpoint found at '{}'".format(model_pkl))
        print("No checkpoint found at '{}'".format(model_pkl))
        return

    logger.info(
        "Loading model and optimizer from checkpoint '{}'".format(model_pkl)
    )
    # map_location lets a checkpoint saved on CUDA load on the CPU fallback device.
    checkpoint = torch.load(model_pkl, map_location=device)
    model_dict = model.state_dict()
    # Filter out unnecessary keys. Matching tensors are resized in place to the
    # model's shape.
    # NOTE(review): Tensor.resize_ does not validate element counts; when they
    # differ the result holds truncated/uninitialised values. Consider
    # filtering mismatched shapes instead.
    pretrained_dict = {
        k: v.resize_(model_dict[k].shape)
        for k, v in checkpoint['model_state'].items()
        if k in model_dict
    }
    print("Model {} parameters,Loaded {} parameters".format(len(model_dict), len(pretrained_dict)))
    model.load_state_dict(pretrained_dict, strict=False)
    logger.info("Loaded checkpoint '{}' (iter {})".format(model_pkl, checkpoint["epoch"]))
    print("Loaded checkpoint '{}' (iter {})".format(model_pkl, checkpoint["epoch"]))


def _validate_log_scores(running_metrics, env, writer, logger):
    """Log the accumulated scores of one validation split, then reset its accumulator."""
    score, class_iou, class_acc, _count = running_metrics.get_scores()
    for k, v in score.items():
        logger.info('{}: {}'.format(k, v))
        writer.add_scalar('val_metrics/{}/{}'.format(env, k), v, 1)
    for k, v in class_iou.items():
        logger.info('cls_iou_{}: {}'.format(k, v))
        writer.add_scalar('val_metrics/{}/cls_iou_{}'.format(env, k), v, 1)
    for k, v in class_acc.items():
        logger.info('cls_acc_{}: {}'.format(k, v))
        # NOTE(review): this tag lacks the underscore used by the log line and by
        # cls_iou_ tags; kept unchanged so existing TensorBoard runs stay comparable.
        writer.add_scalar('val_metrics/{}/cls_acc{}'.format(env, k), v, 1)
    running_metrics.reset()


def validate(cfg, writer, logger, logdir):
    """Evaluate the configured models on every validation split and log fused metrics.

    Setup:
    - Each entry of ``cfg["models"]`` is built with ``get_model``, wrapped in
      ``torch.nn.DataParallel``, loaded from ``attr['resume']`` when that file
      exists, and put in eval mode.
    - The class prior and entropy statistics are read from ``stats/prior.pkl``
      and ``stats/entropy.pkl`` in the directory obtained by dropping the last
      path component of ``logdir``.

    Per batch (batches whose label tensor has size <= 1 in dim 0 are skipped):
    - Every model's ``(mean, entropy)`` output is passed through
      ``likelihood_flattening``.
    - The results are combined with ``prior_recbalancing`` and ``fusion``.
    - The argmax over dim 1 of the fused output updates the split's
      ``runningScore``.
    - Every ``cfg["training"]["png_frames"]`` batches, fused and per-model
      predictions, entropy and max-probability maps are plotted.

    Finally each split's scores are written to ``logger`` and
    ``writer.add_scalar`` (step 1), and the accumulators are reset.

    Raises:
        subprocess.CalledProcessError: if ``git describe --always`` fails.
    """
    import subprocess

    # log git commit
    label = subprocess.check_output(["git", "describe", "--always"]).strip()
    logger.info("Using commit {}".format(label))

    # Setup seeds
    random_seed(1337, True)

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    loaders, n_classes = get_loaders(cfg["data"]["dataset"], cfg)

    # Setup Metrics
    running_metrics_val = {env: runningScore(n_classes) for env in loaders['val'].keys()}

    # Setup Model
    models = {}
    for model_name, attr in cfg["models"].items():
        # Missing config keys read as None.
        attr = defaultdict(lambda: None, attr)
        net = get_model(name=attr['arch'],
                        n_classes=n_classes,
                        input_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
                        in_channels=attr['in_channels'],
                        mcdo_passes=attr['mcdo_passes'],
                        dropoutP=attr['dropoutP'],
                        full_mcdo=attr['full_mcdo'],
                        backbone=attr['backbone'],
                        device=device).to(device)
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        # Load pretrained weights
        _validate_load_checkpoint(net, attr['resume'], logger, device)
        models[model_name] = net

    # Load training stats
    stats_dir = '/'.join(logdir.split('/')[:-1])
    prior = torch.load(os.path.join(stats_dir, 'stats', 'prior.pkl'), map_location=device)
    prior = torch.tensor(prior).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(device).float()  # (1, n_class, 1, 1)
    entropy_stats = torch.load(os.path.join(stats_dir, 'stats', 'entropy.pkl'), map_location=device)

    for net in models.values():
        net.eval()

    #################################################################################
    # Validation
    #################################################################################
    print("=" * 10, "VALIDATING", "=" * 10)

    with torch.no_grad():
        for k, valloader in loaders['val'].items():
            for i_val, (input_list, labels_list) in tqdm(enumerate(valloader)):
                inputs_display, _ = parseEightCameras(input_list['rgb_display'], labels_list,
                                                      input_list['d_display'], device)

                images_val = {m: input_list[m][0] for m in cfg["models"].keys()}
                labels_val = labels_list[0]
                if labels_val.shape[0] <= 1:
                    continue

                # Inference
                mean = {}
                entropy = {}
                for m in cfg["models"].keys():
                    mean[m], entropy[m] = models[m](images_val[m])
                    mean[m] = likelihood_flattening(mean[m], cfg, entropy[m], entropy_stats, modality=m)

                mean = prior_recbalancing(mean, cfg, prior=prior)
                outputs = fusion(mean, cfg)
                prob, pred = outputs.max(1)
                gt = labels_val

                if i_val % cfg["training"]["png_frames"] == 0:
                    # The fused entropy map is only needed for plotting.
                    outputs = outputs.masked_fill(outputs < 1e-9, 1e-9)
                    e, _ = mutualinfo_entropy(outputs.unsqueeze(-1))

                    plotPrediction(logdir, cfg, n_classes, 0, i_val, k + "/fused", inputs_display, pred, gt)
                    plotEverything(logdir, 0, i_val, k + "/fused", [e, prob], ['entropy', 'probability'])

                    for m in cfg["models"].keys():
                        prob_m, pred_m = torch.nn.Softmax(dim=1)(mean[m]).max(1)
                        plotPrediction(logdir, cfg, n_classes, 0, i_val, k + "/" + m,
                                       inputs_display, pred_m, gt)
                        plotEverything(logdir, 0, i_val, k + "/" + m,
                                       [entropy[m], prob_m], ['entropy', 'probability'])

                running_metrics_val[k].update(gt.data.cpu().numpy(), pred.cpu().numpy())

    for env in loaders['val'].keys():
        _validate_log_scores(running_metrics_val[env], env, writer, logger)