Example No. 1
    def forward(self, input):
        # Split the 4-channel input into RGB and depth and run each through
        # its modality-specific expert encoder.
        A, A_llf1, A_llf2, A_aspp = self.expert_A.forward_SSMA(input[:, :3, :, :])
        B, B_llf1, B_llf2, B_aspp = self.expert_B.forward_SSMA(input[:, 3:, :, :])

        # Fuse the two modalities at the ASPP output and at both skip levels,
        # then decode the fused features into segmentation logits.
        fused_ASPP = self.SSMA_ASPP(A_aspp, B_aspp)
        fused_skip1 = self.SSMA_skip1(A_llf1, B_llf1)
        fused_skip2 = self.SSMA_skip2(A_llf2, B_llf2)
        x = self.decoder(fused_ASPP, fused_skip1, fused_skip2)

        prob = self.softmaxMCDO(x.unsqueeze(-1))  # [batch, classes, 512, 512, 1]
        prob = prob.masked_fill(prob < 1e-9, 1e-9)  # clamp away from zero before log
        entropy, mutual_info = mutualinfo_entropy(prob)  # (batch, 512, 512)

        return x, entropy, mutual_info
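
All of these forward methods lean on a shared mutualinfo_entropy helper whose definition is not shown here. A minimal sketch of what such a helper typically computes for MC-dropout probabilities (the body below is an assumption, not this repository's code; prob is taken to have shape [batch, classes, H, W, passes]):

import torch

def mutualinfo_entropy_sketch(prob):
    # prob: [batch, classes, H, W, passes], clamped away from zero upstream
    mean_prob = prob.mean(-1)                                # average over passes
    entropy = -(mean_prob * mean_prob.log()).sum(1)          # predictive entropy, [batch, H, W]
    expected_entropy = -(prob * prob.log()).sum(1).mean(-1)  # mean per-pass entropy
    mutual_info = entropy - expected_entropy                 # BALD-style mutual information
    return entropy, mutual_info

With a single pass (as in Example No. 1, where x is unsqueezed to a trailing dimension of size 1) the two entropy terms coincide and the mutual information is zero.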
Example No. 2
    def forward(self, inputs):

        # Freeze batchnorm statistics by putting the network in eval mode
        self.segnet.eval()

        # Compute per-pass logits and uncertainty measures
        up1 = self.segnet.module.forwardMCDO_logits(inputs)  # (batch, 11, 512, 512, passes)

        # Small encoder-decoder that predicts a per-pixel temperature map
        tdown1, tindices_1, tunpool_shape1 = self.temp_down1(inputs)
        tdown2, tindices_2, tunpool_shape2 = self.temp_down2(tdown1)
        tup2 = self.temp_up2(tdown2, tindices_2, tunpool_shape2)
        tup1 = self.temp_up1(tup2, tindices_1, tunpool_shape1)  # [batch, 1, 512, 512]
        temp = tup1.mean((2, 3)).unsqueeze(-1).unsqueeze(-1)  # (batch, 1, 1, 1) mean temperature

        # Scale every pass's logits by the predicted temperature map
        x = up1 * tup1.unsqueeze(-1)
        mean = x.mean(-1)  # [batch, classes, 512, 512]
        variance = x.std(-1)  # note: standard deviation, despite the name
        prob = self.softmaxMCDO(x)  # [batch, classes, 512, 512, passes]
        prob = prob.masked_fill(prob < 1e-9, 1e-9)
        entropy, mutual_info = mutualinfo_entropy(prob)  # (batch, 512, 512)
        if self.scale_logits is not None:
            mean = self.scale_logits(mean, variance, mutual_info, entropy)
        return (mean, variance, entropy, mutual_info, tup1.squeeze(1),
                temp.view(-1), entropy.mean((1, 2)), mutual_info.mean((1, 2)))
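
The temp_down*/temp_up* pair acts as a small encoder-decoder predicting a per-pixel map that multiplies the logits before the softmax, i.e. a learned, spatially varying inverse temperature. A minimal sketch of the scalar version of the same idea (classic temperature scaling; the function below is illustrative, not from this repository):

import torch

def temperature_scale(logits, temperature):
    # logits: [batch, classes, H, W]; temperature: positive scalar or [batch, 1, 1, 1].
    # Dividing by T > 1 softens the distribution; T < 1 sharpens it.
    return torch.nn.Softmax(dim=1)(logits / temperature)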
Example No. 3
    def forward(self, inputs, mcdo=True):
        # Run mcdo_passes stochastic forward passes and stack them on a
        # trailing "passes" dimension.
        for i in range(self.mcdo_passes):
            if i == 0:
                x = self._forward(inputs, mcdo=mcdo).unsqueeze(-1)
            else:
                x = torch.cat((x, self._forward(inputs, mcdo=mcdo).unsqueeze(-1)), -1)

        mean = x.mean(-1)
        prob = self.softmaxMCDO(x)
        entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)
        return mean, entropy, mutual_info
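
Growing x with torch.cat inside the loop reallocates the accumulated tensor on every pass. Collecting the passes in a list and stacking once is the usual pattern; a sketch under the assumption that _forward and mcdo_passes behave as above:

import torch

def mcdo_stack(model, inputs, mcdo=True):
    # Run every stochastic pass first, then create the pass dimension once.
    passes = [model._forward(inputs, mcdo=mcdo) for _ in range(model.mcdo_passes)]
    return torch.stack(passes, dim=-1)  # [batch, classes, H, W, passes]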
Example No. 4
    def forward(self, input, scaling_metrics="SoftEn"):
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x,
                          size=input.size()[2:],
                          mode='bilinear',
                          align_corners=True)
        x = x.unsqueeze(-1)  # [batch, classes, 760, 1280, 1]: single deterministic pass
        mean = x.mean(-1)  # [batch, classes, 760, 1280]
        prob = torch.nn.Softmax(dim=1)(x)
        prob = prob.masked_fill(prob < 1e-9, 1e-9)
        entropy, _ = mutualinfo_entropy(prob)  # (batch, 760, 1280); mutual info is zero for one pass
        return mean, entropy
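
Since this DeepLab-style variant makes a single deterministic pass, the returned entropy is just the softmax (predictive) entropy. A hypothetical usage sketch (the wrapper name and the thresholding rule are illustrative, not from this repository):

import torch

def predict_with_uncertainty(model, images):
    # Hypothetical wrapper around the forward above.
    model.eval()
    with torch.no_grad():
        mean, entropy = model(images)        # mean: [B, C, H, W]; entropy: [B, H, W]
    pred = mean.argmax(1)                    # hard per-pixel segmentation
    uncertain = entropy > entropy.mean()     # crude per-pixel uncertainty mask
    return pred, uncertain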
Example No. 5
    def forwardMCDO(self, inputs, mcdo=True):
        with torch.no_grad():
            for i in range(self.mcdo_passes):
                if i == 0:
                    x = self._forward(inputs, mcdo=mcdo).unsqueeze(-1)
                else:
                    x = torch.cat((x, self._forward(inputs, mcdo=mcdo).unsqueeze(-1)), -1)

        mean = x.mean(-1)
        variance = x.var(-1)

        prob = self.softmaxMCDO(x)
        entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)

        return mean, variance, entropy, mutual_info
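
forwardMCDO only produces distinct passes if dropout stays active at inference time, which is presumably what the mcdo flag arranges inside _forward. A generic way to re-enable dropout after model.eval(), shown as an assumption about the mechanism rather than this codebase's implementation:

import torch.nn as nn

def enable_mc_dropout(model):
    # Switch only the dropout layers back to train mode so that
    # repeated no_grad passes remain stochastic.
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d)):
            module.train()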
Example No. 6
    def forwardMCDO_junjiao(self, inputs, mcdo=True):
        with torch.no_grad():
            for i in range(self.mcdo_passes):
                if i == 0:
                    x = self.forward(inputs, mcdo=mcdo).unsqueeze(-1)
                else:
                    x = torch.cat((x, self.forward(inputs, mcdo=mcdo).unsqueeze(-1)), -1)

        mean = x.mean(-1)
        variance = x.var(-1)

        prob = self.softmaxMCDO(x)
        entropy, mutual_info = mutualinfo_entropy(prob)  # (batch,512,512)
        if self.scale_logits is not None:
            mean = self.scale_logits(mean, variance, mutual_info, entropy)
        return (mean, variance, entropy, mutual_info,
                entropy.mean((1, 2)), mutual_info.mean((1, 2)))
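
Both MCDO variants optionally pass the predictive mean through self.scale_logits(mean, variance, mutual_info, entropy); that module's definition is not shown. One plausible form, purely as an illustrative assumption, is an entropy-based down-weighting of the logits:

import torch
import torch.nn as nn

class EntropyLogitScaler(nn.Module):
    # Hypothetical scaler: damp logits where predictive entropy is high.
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, mean, variance, mutual_info, entropy):
        # entropy: [batch, H, W] -> broadcast over the class dimension
        weight = torch.exp(-self.alpha * entropy).unsqueeze(1)  # [batch, 1, H, W]
        return mean * weight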
Example No. 7
def validate(cfg, writer, logger, logdir):
    # log git commit
    import subprocess
    label = subprocess.check_output(["git", "describe", "--always"]).strip().decode("utf-8")
    logger.info("Using commit {}".format(label))

    # Setup seeds
    random_seed(1337, True)
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Setup Dataloader
    loaders, n_classes = get_loaders(cfg["data"]["dataset"], cfg)
    # Setup Metrics
    running_metrics_val = {env: runningScore(n_classes) for env in loaders['val'].keys()}
    models = {}
    # Setup Model
    for model, attr in cfg["models"].items():
        attr = defaultdict(lambda: None, attr)
        models[model] = get_model(name=attr['arch'],
                                  n_classes=n_classes,
                                  input_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
                                  in_channels=attr['in_channels'],
                                  mcdo_passes=attr['mcdo_passes'],
                                  dropoutP=attr['dropoutP'],
                                  full_mcdo=attr['full_mcdo'],
                                  backbone=attr['backbone'],
                                  device=device).to(device)
        models[model] = torch.nn.DataParallel(models[model], device_ids=range(torch.cuda.device_count()))
        # Load pretrained weights
        model_dict = models[model].state_dict()
        model_pkl = attr['resume']
        if os.path.isfile(model_pkl):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(model_pkl)
            )
            checkpoint = torch.load(model_pkl)
            pretrained_dict = checkpoint['model_state']
            # Filter out unnecessary keys
            # Keep only keys that exist in this model; resize_ forcibly reshapes
            # each checkpoint tensor in place to the current model's shape.
            pretrained_dict = {
                k: v.resize_(model_dict[k].shape)
                for k, v in pretrained_dict.items() if k in model_dict
            }
            print("Model has {} parameters, loaded {} parameters".format(
                len(model_dict), len(pretrained_dict)))
            model_dict.update(pretrained_dict)
            models[model].load_state_dict(model_dict, strict=False)
            logger.info("Loaded checkpoint '{}' (iter {})".format(model_pkl, checkpoint["epoch"]))
            print("Loaded checkpoint '{}' (iter {})".format(model_pkl, checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(model_pkl))
            print("No checkpoint found at '{}'".format(model_pkl))

    # Load training stats
    stats_dir = os.path.dirname(logdir)
    prior = torch.load(os.path.join(stats_dir, 'stats', 'prior.pkl'))
    prior = torch.tensor(prior).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(device).float()  # (1, n_class, 1, 1)
    entropy_stats = torch.load(os.path.join(stats_dir, 'stats', 'entropy.pkl'))
    for m in models.values():
        m.eval()
    #################################################################################
    # Validation
    #################################################################################
    print("=" * 10, "VALIDATING", "=" * 10)

    with torch.no_grad():
        for k, valloader in loaders['val'].items():
            for i_val, (input_list, labels_list) in tqdm(enumerate(valloader)):
                inputs_display, _ = parseEightCameras(input_list['rgb_display'], labels_list, input_list['d_display'], device)
                images_val = {m: input_list[m][0] for m in cfg["models"].keys()}
                labels_val = labels_list[0]
                if labels_val.shape[0] <= 1:
                    continue
                mean = {}
                entropy = {}
                val_loss = {}
                # Inference
                for m in cfg["models"].keys():
                    mean[m], entropy[m] = models[m](images_val[m])
                    mean[m] = likelihood_flattening(mean[m], cfg, entropy[m], entropy_stats, modality=m)
                mean = prior_recbalancing(mean, cfg, prior=prior)
                outputs = fusion(mean, cfg)
        
                prob, pred = outputs.max(1)
                gt = labels_val
                outputs = outputs.masked_fill(outputs < 1e-9, 1e-9)
                e, _ = mutualinfo_entropy(outputs.unsqueeze(-1))
                if i_val % cfg["training"]["png_frames"] == 0:
                    plotPrediction(logdir, cfg, n_classes, 0, i_val, k + "/fused", inputs_display, pred, gt)
                    labels = ['entropy', 'probability']
                    values = [e, prob]
                    plotEverything(logdir, 0, i_val, k + "/fused", values, labels)

                    for m in cfg["models"].keys():
                        prob, pred_m = torch.nn.Softmax(dim=1)(mean[m]).max(1)
                        labels = ['entropy', 'probability']
                        values = [entropy[m], prob]
                        plotPrediction(logdir, cfg, n_classes, 0, i_val, k + "/" + m, inputs_display, pred_m, gt)
                        plotEverything(logdir, 0, i_val, k + "/" + m, values, labels)
                    
                running_metrics_val[k].update(gt.data.cpu().numpy(), pred.cpu().numpy())
    for env, valloader in loaders['val'].items():
        score, class_iou, class_acc, count = running_metrics_val[env].get_scores()
        for k, v in score.items():
            logger.info('{}: {}'.format(k, v))
            writer.add_scalar('val_metrics/{}/{}'.format(env, k), v, 1)

        for k, v in class_iou.items():
            logger.info('cls_iou_{}: {}'.format(k, v))
            writer.add_scalar('val_metrics/{}/cls_iou_{}'.format(env, k), v, 1)

        for k, v in class_acc.items():
            logger.info('cls_acc_{}: {}'.format(k, v))
            writer.add_scalar('val_metrics/{}/cls_acc_{}'.format(env, k), v, 1)
        running_metrics_val[env].reset()
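
likelihood_flattening, prior_recbalancing, and fusion are project-specific helpers that are not shown here. As a rough sketch of what a probability-level fusion step can look like (a hypothetical product-of-experts reimplementation, not this repository's code):

import torch

def fuse_modalities(mean_logits):
    # mean_logits: dict of modality name -> [batch, classes, H, W] logits (assumed).
    probs = [torch.nn.Softmax(dim=1)(m) for m in mean_logits.values()]
    fused = probs[0]
    for p in probs[1:]:
        fused = fused * p                      # independence-style product of experts
    return fused / fused.sum(1, keepdim=True)  # renormalize over classes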