def __getitem__(self, image_index):
    """Load one sample and build its seed mask from precomputed maps.

    Reads the JPEG for ``self.image_ids[image_index]``, its class label
    from ``self.label_dic``, the two PSA outputs stored under
    ``prepare_labels/results/`` and the two precomputed difference maps
    stored under ``precompute/<modelid>/``.  All of them go through
    ``self.img_label_resize`` together and are then merged by
    ``ssddF.get_dd_mask``.

    Args:
        image_index: Index into ``self.image_ids``.  It is also the numeric
            suffix of the precomputed ``da``/``dk`` file names.

    Returns:
        Tuple ``(img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg,
        seed_mask)``:

        * ``img_norm`` -- first output of ``self.img_label_resize``
          (presumably the normalised network input -- confirm there).
        * ``img_org`` -- second output, resized to ``config.OUT_SHAPE``.
        * ``gt_class_mlabel`` -- tensor built from ``self.label_dic``.
        * ``gt_class_mlabel_bg`` -- the same label with a leading ``1``
          prepended (presumably the background class).
        * ``seed_mask`` -- element 0 of the fourth output of
          ``ssddF.get_dd_mask``.
    """
    image_id = self.image_ids[image_index]
    image_path = self.config.VOC_ROOT + '/JPEGImages/' + image_id + '.jpg'
    img = Image.open(image_path).convert("RGB")

    label = self.label_dic[image_id]
    gt_class_mlabel = torch.from_numpy(label)
    gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1], label)))

    # The out_aff file holds a pickled dict (hence ``.item().values()``).
    # NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True is given.
    # NOTE: unpickling executes arbitrary code; only load files produced by
    # this project's own label-preparation step.
    psa_path = 'prepare_labels/results/out_aff/' + image_id + '.npy'
    psa_dict = np.load(psa_path, allow_pickle=True).item()
    psa = np.array(list(psa_dict.values())).transpose(1, 2, 0)

    psa_crf_path = 'prepare_labels/results/out_aff_crf/' + image_id + '.npy'
    psa_crf = np.load(psa_crf_path).transpose(1, 2, 0)

    h = psa.shape[0]
    w = psa.shape[1]

    # Load both precomputed difference maps ('da' then 'dk') and bring them
    # to the PSA resolution.  cv2.resize takes its size as (width, height).
    model_id = self.config.modelid
    dd_maps = []
    for tag in ('da', 'dk'):
        dd_path = ('precompute/' + model_id + '/' + tag + '_precompute_'
                   + model_id + '_' + str(image_index) + '.npy')
        dd = np.load(dd_path).transpose(1, 2, 0)
        dd_maps.append(np.reshape(cv2.resize(dd, (w, h)), (h, w, 1)))
    dd0, dd1 = dd_maps

    # resize inputs
    # Keep the transformed difference maps.  The previous code bound them to
    # unused names and resized the pre-transform maps instead, so they did
    # not go through the same transform as psa / psa_crf.
    img_norm, img_org, psa, psa_crf, dd0, dd1 = self.img_label_resize(
        [img, np.array(img), psa, psa_crf, dd0, dd1])

    out_shape = self.config.OUT_SHAPE
    img_org = cv2.resize(img_org, out_shape)
    dd0 = cv2.resize(dd0, out_shape)
    dd1 = cv2.resize(dd1, out_shape)
    psa = cv2.resize(psa, out_shape)
    psa_crf = cv2.resize(psa_crf, out_shape)

    # Convert to channel-first, then take the per-pixel argmax as the mask.
    psa = self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2, 0, 1)
    psa_crf = self.get_prob_label(
        psa_crf, gt_class_mlabel_bg).transpose(2, 0, 1)
    psa_mask = np.argmax(psa, 0)
    psa_crf_mask = np.argmax(psa_crf, 0)

    # Add a leading dimension of size 1; the label is wrapped in a
    # one-element list to match.
    dd0 = torch.from_numpy(dd0).unsqueeze(0)
    dd1 = torch.from_numpy(dd1).unsqueeze(0)
    psa_mask = torch.from_numpy(psa_mask).unsqueeze(0)
    psa_crf_mask = torch.from_numpy(psa_crf_mask).unsqueeze(0)
    ignore_flags = torch.from_numpy(
        ssddF.get_ignore_flags(
            psa_mask, psa_crf_mask, [gt_class_mlabel])).float()

    # integration using sssdd module
    # the parameters are different from dssdd module
    (_, _, _, seed_mask) = ssddF.get_dd_mask(
        dd0, dd1, psa_mask, psa_crf_mask, ignore_flags,
        dd_bias=0.1, bg_bias=0.1)

    return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, seed_mask[0]
def forward(self, inputs):
    """Run two chained difference-detection stages.

    Stage one compares the main segmentation mask with its CRF mask and
    yields a refined mask.  Stage two compares ``seed_mask`` with that
    refined mask.

    Args:
        inputs: ``((seg_outs_main, seg_outs_sub, seg_crf_mask, feats),
            seed_mask, gt_class_mlabel)``.  ``seg_outs_main`` and
            ``seg_outs_sub`` are 4-tuples whose third and fourth entries
            are used here as the mask and the head features.  ``feats``
            holds five feature maps, of which only the first is used.

    Returns:
        ``(dd_outs0, dd_outs1)`` -- the ``ssddF.get_dd_mask`` outputs of
        the first and second stage.
    """
    seg_net_outs, seed_mask, gt_class_mlabel = inputs
    seg_outs_main, seg_outs_sub, seg_crf_mask, feats = seg_net_outs

    # Only the first of the five feature maps is used, after 2x2 pooling.
    low_feat, _, _, _, _ = feats
    low_feat = F.avg_pool2d(low_feat, 2, 2)

    def ignore_flags_for(mask_a, mask_b):
        # get_ignore_flags returns a NumPy array; move it to the GPU as float.
        flags = ssddF.get_ignore_flags(mask_a, mask_b, gt_class_mlabel)
        return torch.from_numpy(flags).cuda().float()

    # first step
    _, _, seg_mask_main, seg_head_main = seg_outs_main
    flags_stage0 = ignore_flags_for(seg_mask_main, seg_crf_mask)
    # Inputs are detached, so the head's gradients stop here.
    head_stage0 = self.dd_head0((seg_head_main.detach(), low_feat.detach()))
    dd_main = ssddF.get_dd(self.dd0, head_stage0, seg_mask_main)
    dd_crf = ssddF.get_dd(self.dd0, head_stage0, seg_crf_mask)
    dd_outs0 = ssddF.get_dd_mask(
        dd_main, dd_crf, seg_mask_main, seg_crf_mask, flags_stage0,
        dd_bias=0.4, bg_bias=0)
    # Only the fourth output (the refined mask) feeds the second stage.
    _, _, _, refine_mask0 = dd_outs0

    # second step
    _, _, _, seg_head_sub = seg_outs_sub
    head_stage1 = self.dd_head1((seg_head_sub.detach(), low_feat.detach()))
    dd_seed = ssddF.get_dd(self.dd1, head_stage1, seed_mask)
    dd_refined = ssddF.get_dd(self.dd1, head_stage1, refine_mask0)
    flags_stage1 = ignore_flags_for(seed_mask, refine_mask0)
    dd_outs1 = ssddF.get_dd_mask(
        dd_seed, dd_refined, seed_mask, refine_mask0, flags_stage1,
        dd_bias=0.4, bg_bias=0)

    return dd_outs0, dd_outs1
def forward(self, inputs):
    """Run one difference-detection stage on the PSA mask pair.

    Args:
        inputs: ``((seg_outs_main, feats), psa_mask, psa_crf_mask,
            gt_class_mlabel)``.  ``seg_outs_main`` is a 4-tuple whose last
            entry is used as the head features.  ``feats`` holds five
            feature maps, of which only the first is used.

    Returns:
        The output of ``ssddF.get_dd_mask`` for the ``psa_mask`` /
        ``psa_crf_mask`` pair.
    """
    seg_net_outs, psa_mask, psa_crf_mask, gt_class_mlabel = inputs
    seg_outs_main, feats = seg_net_outs

    # Only the first of the five feature maps is used, after 2x2 pooling.
    low_feat, _, _, _, _ = feats
    low_feat = F.avg_pool2d(low_feat, 2, 2)

    # first step
    _, _, _, seg_head_main = seg_outs_main
    # get_ignore_flags returns a NumPy array; move it to the GPU as float.
    flags_np = ssddF.get_ignore_flags(psa_mask, psa_crf_mask, gt_class_mlabel)
    ignore_flags = torch.from_numpy(flags_np).cuda().float()

    # Inputs are detached, so the head's gradients stop here.
    head_feat = self.dd_head0((seg_head_main.detach(), low_feat.detach()))
    dd_psa = ssddF.get_dd(self.dd0, head_feat, psa_mask)
    dd_psa_crf = ssddF.get_dd(self.dd0, head_feat, psa_crf_mask)

    return ssddF.get_dd_mask(
        dd_psa, dd_psa_crf, psa_mask, psa_crf_mask, ignore_flags,
        dd_bias=0.1, bg_bias=0.1)