def val_func_process(self, input_data, label, device=None):
    """Run the evaluation network on one image and return ``exp`` of its score.

    A leading batch axis is added to both ``input_data`` and ``label``.
    The label is downsampled to 1/8 resolution with nearest-neighbour
    interpolation and turned into a pairwise similarity target, which is
    passed to ``self.val_func`` as ``aux_label``.

    Args:
        input_data: a single image array; a batch axis is prepended here
            (presumably 3-D, e.g. channels x height x width -- TODO confirm).
        label: the matching 2-D label array (after the batch axis is added,
            it is unpacked as ``b, h, w``).
        device: passed to ``.cuda()``; ``None`` selects the current device.

    Returns:
        ``torch.exp`` of the first batch element of the network output. When
        ``self.is_flip`` is set, the prediction for the width-flipped input
        (flipped back) is added before exponentiation.
    """
    input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
    input_data = torch.FloatTensor(input_data).cuda(device)

    label = np.ascontiguousarray(label[None, :, :], dtype=np.int16)
    label = torch.LongTensor(label).cuda(device)

    # Downsample the labels to 1/8 resolution; nearest-neighbour keeps class ids intact.
    b, h, w = label.size()
    scaled_gts = F.interpolate(label.view(b, 1, h, w).float(),
                               scale_factor=0.125,
                               mode="nearest")
    # Remove only the channel axis so the shape is always (b, h', w'), even for b == 1.
    scaled_gts = scaled_gts.squeeze(1)
    C = config.num_classes + 1

    def _similarity(gts):
        # One-hot (b, C, N) -> pairwise product (b, N, C) x (b, C, N) = (b, N, N).
        one_hot_gts = one_hot(gts, C).view(b, C, -1)
        return torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)

    similarity_gts = _similarity(scaled_gts)

    with torch.cuda.device(input_data.get_device()):
        self.val_func.eval()
        self.val_func.to(input_data.get_device())
        with torch.no_grad():
            score = self.val_func(input_data, aux_label=similarity_gts)
            score = score[0]

            if self.is_flip:
                input_data = input_data.flip(-1)
                # The auxiliary target must describe the flipped image, so flip the
                # downsampled labels along width before rebuilding the similarity map.
                flipped_similarity_gts = _similarity(scaled_gts.flip(-1))
                score_flip = self.val_func(input_data, aux_label=flipped_similarity_gts)
                score_flip = score_flip[0]
                score += score_flip.flip(-1)
            score = torch.exp(score)

    return score
def get_similarity_gt(gts, scale_factor, num_classes=None):
    """Build the pairwise similarity target from a batch of label maps.

    The labels are downsampled with nearest-neighbour interpolation, one-hot
    encoded over ``num_classes + 1`` classes, and multiplied with their own
    transpose via ``torch.bmm``.

    Args:
        gts: label tensor of shape ``(b, h, w)``.
        scale_factor: spatial scale passed to ``F.interpolate``.
        num_classes: number of classes; defaults to ``config.num_classes``.

    Returns:
        Tensor of shape ``(b, N, N)``, where ``N`` is the per-class size of
        the one-hot encoding (presumably ``h' * w'`` of the downsampled map --
        depends on the project ``one_hot`` helper).
    """
    if num_classes is None:
        num_classes = config.num_classes
    b, h, w = gts.size()
    scaled_gts = F.interpolate(gts.view(b, 1, h, w).float(),
                               scale_factor=scale_factor,
                               mode="nearest")
    # Drop only the channel axis: a bare squeeze_() would also remove the batch
    # axis when b == 1 (and any spatial axis of size 1).
    scaled_gts = scaled_gts.squeeze(1)
    C = num_classes + 1
    one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
    return torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)
# Training-step fragment: fetch a batch, build the 1/8-resolution similarity
# target, run the model and reduce the loss. The enclosing loop and the
# backward/step that presumably follow are outside this view.
# NOTE(review): ``dataloader.next()`` needs a ``.next`` method; recent PyTorch
# DataLoader iterators only implement ``__next__`` (use ``next(dataloader)``)
# -- confirm against the torch version in use.
minibatch = dataloader.next()
imgs = minibatch['data']
gts = minibatch['label']
imgs = imgs.cuda(non_blocking=True)
gts = gts.cuda(non_blocking=True)

# Downsample labels to 1/8 resolution (nearest neighbour) for the similarity target.
b, h, w = gts.size()
scaled_gts = F.interpolate((gts.view(b, 1, h, w)).float(),
                           scale_factor=0.125,
                           mode="nearest")
b, c, h, w = scaled_gts.size()
# NOTE(review): squeeze_() removes every size-1 axis, including the batch axis
# when b == 1; squeeze(1) would always keep (b, h, w).
scaled_gts = scaled_gts.squeeze_()
C = config.num_classes + 1
one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
# Pairwise product of one-hot encodings: (b, N, C) x (b, C, N) -> (b, N, N).
similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)

# Labels are shifted down by one only after the similarity target is built
# (presumably so that label 0 maps to an ignore value of -1 -- TODO confirm).
gts = gts - 1

loss = model(imgs, gts, similarity_gts)

# reduce the whole loss over multi-gpu: sum across processes, then divide by
# the number of processes to get the mean.
if engine.distributed:
    dist.all_reduce(loss, dist.ReduceOp.SUM)
    loss = loss / engine.world_size

optimizer.zero_grad()
def forward(self, pred, target):
    """Compute the affinity loss between a predicted similarity map and labels.

    The labels are downsampled by ``self.scale`` (nearest neighbour), pixels
    equal to ``self.ignore_index`` are routed to the extra class
    ``self.num_class``, and the one-hot encodings (``self.num_class + 1``
    classes) are multiplied into a pairwise similarity target. The result is
    the sum of four masked terms, each computed with ``self.criterion``:

    * ``bce_loss``: criterion between ``pred`` and the similarity target,
      averaged over valid pixel pairs;
    * ``precision_loss``, ``recall_loss``, ``spec_loss``: per-row ratios
      compared against an all-ones target, averaged over valid pixels.

    ``self.criterion`` is presumably element-wise (``reduction='none'``),
    since its output is masked and summed here -- TODO confirm.

    Args:
        pred: predicted similarity map; presumably ``(b, N, N)`` with values in
            ``[0, 1]``, ``N`` being the number of downsampled pixels -- TODO confirm.
        target: label tensor of shape ``(b, h, w)``.

    Returns:
        The summed loss tensor.
    """
    b, h, w = target.size()
    scaled_gts = F.interpolate(target.view(b, 1, h, w).float(),
                               scale_factor=self.scale,
                               mode="nearest")

    # Per-pixel validity: 0 where the label equals ignore_index, 1 otherwise ...
    valid_mask = torch.ones_like(scaled_gts)
    valid_mask[scaled_gts == self.ignore_index] = 0
    valid_vector = valid_mask.view(b, -1, 1)
    # ... and per-pair validity: entry (i, j) is 1 only when both pixels are valid.
    valid_mask = torch.bmm(valid_vector, valid_vector.permute(0, 2, 1))

    # Route ignored pixels to the extra class so the one-hot encoding stays in range.
    scaled_gts[scaled_gts == self.ignore_index] = self.num_class
    # Drop only the channel axis: squeeze_() would also drop a batch axis of size 1.
    scaled_gts = scaled_gts.squeeze(1)
    C = self.num_class + 1
    one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
    similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)

    def _at_least_one(count):
        # Replace a zero count with 1 so an all-ignored batch does not divide by zero.
        return torch.where(count > 0, count, torch.ones(1, device=count.device))

    def _safe_ratio(numerator, denominator):
        # ``~(d > 0)`` replaces the original ``1 - (d > 0)``: subtracting from a
        # bool tensor raises in current PyTorch. Non-positive denominators become 1.
        return numerator / denominator.masked_fill(~(denominator > 0), 1)

    def _ones_target_loss(part, pixel_valid, num_valid):
        # Criterion of each per-row ratio against 1, averaged over valid pixels.
        loss = self.criterion(part, torch.ones_like(part))
        return (pixel_valid * loss).sum() / num_valid

    bce_loss = self.criterion(pred, similarity_gts)
    bce_loss = (valid_mask * bce_loss).sum() / _at_least_one(valid_mask.sum())

    pixel_valid = valid_vector.view(b, -1)
    num_valid = _at_least_one(pixel_valid.sum())

    # Same-class pairs restricted to valid pixels; the numerator is shared by
    # the precision and recall terms.
    vtarget = similarity_gts * valid_mask
    intra_class = torch.sum(pred * vtarget, dim=2)
    precision_part = _safe_ratio(intra_class, torch.sum(pred, dim=2))
    precision_loss = _ones_target_loss(precision_part, pixel_valid, num_valid)
    recall_part = _safe_ratio(intra_class, torch.sum(vtarget, dim=2))
    recall_loss = _ones_target_loss(recall_part, pixel_valid, num_valid)

    # Different-class pairs restricted to valid pixels.
    vtarget = (1 - similarity_gts) * valid_mask
    spec_part = _safe_ratio(torch.sum((1 - pred) * vtarget, dim=2),
                            torch.sum(vtarget, dim=2))
    spec_loss = _ones_target_loss(spec_part, pixel_valid, num_valid)

    return bce_loss + recall_loss + spec_loss + precision_loss
# Training-step fragment: fetch a batch, build a full-resolution similarity
# target, run the model, reduce the loss and take an optimizer step. The
# enclosing loop (which defines ``epoch`` and ``idx``) is outside this view.
# NOTE(review): ``dataloader.next()`` needs a ``.next`` method; recent PyTorch
# DataLoader iterators only implement ``__next__`` (use ``next(dataloader)``)
# -- confirm against the torch version in use.
minibatch = dataloader.next()
imgs = minibatch['data']
gts = minibatch['label']
imgs = imgs.cuda(non_blocking=True)
gts = gts.cuda(non_blocking=True)
b, h, w = gts.size()
# Downsampling of the labels is deliberately disabled here: the similarity
# target is built from the full-resolution labels.
C = config.num_classes + 1
one_hot_gts = one_hot(gts, C).view(b, C, -1)
# NOTE(review): at full resolution this is presumably (b, h*w, h*w), so memory
# grows with (h*w)^2 -- confirm this is intended.
similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)
# Labels are shifted down by one only after the similarity target is built
# (presumably mapping label 0 to an ignore value of -1 -- TODO confirm).
gts = gts - 1
loss = model(imgs, gts, similarity_gts)

# reduce the whole loss over multi-gpu
# NOTE(review): this sums the loss across processes without dividing by the
# world size, and calls torch.distributed unconditionally (an initialised
# process group is required) -- confirm intended.
dist.all_reduce(loss, dist.ReduceOp.SUM)

optimizer.zero_grad()
loss.backward()
optimizer.step()

current_idx = epoch * config.niters_per_epoch + idx