Exemplo n.º 1
0
    def _forward(self, imgs, gts):
        """Compute the joint patch-segmentation + per-seed FOV loss and metrics.

        Shape notes from the original docstring (names there were
        cur_img / cur_ins):
            imgs: B * 1 * h * w * d
            ins: B * h * w * d
            labels_pad: max_ins_length * B

        Args:
            imgs: Input batch; ``imgs.size(0)`` is used as the batch size.
            gts: Tuple ``(ins, seeds, labels_pad, infos)``. ``ins`` is the
                instance-label tensor (voxels ``> 0`` are foreground);
                ``seeds`` and ``labels_pad`` are iterated in lockstep, one
                (seed, label) pair per FOV step; ``infos`` is not used here.

        Returns:
            Tuple ``(loss, [dices, precs, recalls], arrays)``:
                loss: ``0.3 * patch_loss + 0.7 * mean(fov_losses)``.
                dices / precs / recalls: ``[patch_metric, instance_metric]``
                    objects after this batch's ``update`` calls.
                arrays: ``[full_seg_pred_array, full_seg_array,
                    fov_seg_array, fov_gt_array]``; the two FOV arrays come
                    from the last (seed, label) pair.

        Raises:
            ValueError: If ``seeds``/``labels_pad`` yield no pairs, or a label
                row cannot be broadcast against its FOV crop.
        """
        batch_size = imgs.size(0)
        ins, seeds, labels_pad, _infos = gts

        full_seg_gt = (ins > 0).long()

        rrb = torch.cat(self.featureHead(imgs), dim=1)
        full_seg_output = self.segHead(rrb)

        _, full_seg_pred = torch.max(full_seg_output, dim=1)

        full_seg_array = full_seg_gt.detach().cpu().numpy()
        full_seg_pred_array = full_seg_pred.detach().cpu().numpy()
        metric_full_seg = self.metric_func(full_seg_pred_array, full_seg_array)

        loss = self.criterion(full_seg_output.float(), full_seg_gt)

        # Direction field computed from the ground-truth foreground mask.
        direct_field = batch_direct_field3D(full_seg_array)

        # Loop-invariant host copies, computed once instead of once per seed.
        ins_array = ins.detach().cpu().numpy()
        seg_pred_channel = np.expand_dims(full_seg_pred_array, 1)

        loss_ins = 0.
        metrics_fov = []
        for seed, label in zip(seeds, labels_pad):
            seg_map = get_batchCropData(seed, seg_pred_channel, self.fov_shape)
            seed_map = np.expand_dims(
                get_batchSeedMap(seed, label, ins_array, self.fov_shape), 1)
            df = get_batchCropData(seed, direct_field, self.fov_shape)
            fov_ins = get_batchCropData(seed, ins_array, self.fov_shape)

            # Label entries equal to 0 are matched as instance 1 (behavior kept
            # from the original; presumably 0 is padding in labels_pad -- TODO
            # confirm). The seed map above still receives the raw label.
            target_labels = np.asarray(label)
            target_labels = np.where(target_labels == 0, 1, target_labels)
            # np.equal raises ValueError on a shape mismatch instead of silently
            # returning a scalar like ``==`` does on older NumPy.
            fov_mask = np.equal(fov_ins, target_labels[:, None, None, None])
            # Explicit int64: ``astype(int)`` is int32 on some platforms, and
            # the criterion receives ``.long()`` targets for the patch loss.
            fov_gt = torch.from_numpy(fov_mask.astype(np.int64)).to(self.device)

            seg_map = torch.from_numpy(seg_map).to(device=self.device,
                                                   dtype=torch.float)
            seed_map = torch.from_numpy(seed_map).to(device=self.device,
                                                     dtype=torch.float)
            df = torch.from_numpy(df).to(device=self.device, dtype=torch.float)

            fov_data = torch.cat([seg_map, df, seed_map], dim=1)
            fov_output = self.ffn(fov_data)

            loss_ins += self.criterion(fov_output, fov_gt)

            _, fov_seg = torch.max(fov_output, dim=1)
            fov_seg_array = fov_seg.detach().cpu().numpy()
            fov_gt_array = fov_gt.detach().cpu().numpy()
            metrics_fov.append(self.metric_func(fov_seg_array, fov_gt_array))

        if not metrics_fov:
            raise ValueError(
                "_forward: seeds/labels_pad produced no (seed, label) pairs")

        # Average over the pairs actually processed.
        loss_ins /= len(metrics_fov)
        loss = 0.3 * loss + 0.7 * loss_ins

        metric_fov = np.mean(metrics_fov, axis=0)

        self.dice_ins.update(metric_fov[0], batch_size)
        self.prec_ins.update(metric_fov[1], batch_size)
        self.recall_ins.update(metric_fov[2], batch_size)

        self.dice_patch.update(metric_full_seg[0], batch_size)
        self.prec_patch.update(metric_full_seg[1], batch_size)
        self.recall_patch.update(metric_full_seg[2], batch_size)

        dices = [self.dice_patch, self.dice_ins]
        precs = [self.prec_patch, self.prec_ins]
        recalls = [self.recall_patch, self.recall_ins]

        return loss, [dices, precs, recalls], [
            full_seg_pred_array, full_seg_array, fov_seg_array, fov_gt_array
        ]
Exemplo n.º 2
0
    def forward_seedfromfullseg(self, imgs, gts):
        """Patch segmentation plus FOV losses on seeds sampled from the prediction.

        Shape notes from the original docstring (names there were
        cur_img / cur_ins):
            imgs: B * 1 * h * w * d
            ins: B * h * w * d
            labels_pad: max_ins_length * B

        Args:
            imgs: Input batch; ``imgs.size(0)`` is used as the batch size.
            gts: Tuple ``(ins, seeds, labels_pad)``. ``ins`` is the
                instance-label tensor (voxels ``> 0`` are foreground).
                ``seeds`` is not used: FOV seeds are re-sampled from the
                predicted mask by ``sample_seedfromfullseg``.

        Returns:
            Tuple ``(loss, [dices, precs, recalls], arrays)``:
                loss: ``patch_loss + mean(fov_losses)``, where the mean is over
                    the sampled seeds actually processed.
                dices / precs / recalls: ``[patch_metric, instance_metric]``
                    objects after this batch's ``update`` calls.
                arrays: ``[full_seg_pred_array, full_seg_array,
                    fov_seg_array, fov_gt_array]``; the two FOV arrays come
                    from the last sampled seed.

        Raises:
            ValueError: If no seeds are sampled, or a label row cannot be
                broadcast against its FOV crop.
        """
        batch_size = imgs.size(0)
        ins, _seeds, labels_pad = gts

        full_seg_gt = (ins > 0).long()

        rrb = torch.cat(self.featureHead(imgs), dim=1)
        full_seg_output = self.segHead(rrb)
        feature = self.fusenet(rrb)

        _, full_seg_pred = torch.max(full_seg_output, dim=1)

        full_seg_array = full_seg_gt.detach().cpu().numpy()
        full_seg_pred_array = full_seg_pred.detach().cpu().numpy()
        metric_full_seg = self.metric_func(full_seg_pred_array, full_seg_array)

        loss = self.criterion(full_seg_output, full_seg_gt)

        # Loop-invariant host copies, computed once instead of once per seed
        # (``feature`` in particular can be a large device-to-host transfer).
        ins_array = ins.detach().cpu().numpy()
        feature_array = feature.detach().cpu().numpy()
        seg_pred_channel = np.expand_dims(full_seg_pred_array, 1)

        seeds_center, seeds_coords, labels = sample_seedfromfullseg(
            full_seg_pred_array, ins_array, labels_pad, self.fov_shape)

        loss_ins = 0.
        metrics_fov = []
        # seeds_coords is not used, but it stays in the zip so the iteration
        # count matches the original exactly.
        for seed, label, _seed_coord in zip(seeds_center, labels, seeds_coords):
            fov_feature = get_batchCropData(seed, feature_array, self.fov_shape)
            seg_map = get_batchCropData(seed, seg_pred_channel, self.fov_shape)
            seed_map = np.expand_dims(
                get_batchSeedMap(seed, label, ins_array, self.fov_shape), 1)
            fov_ins = get_batchCropData(seed, ins_array, self.fov_shape)

            # Label entries equal to 0 are matched as instance 1 (behavior kept
            # from the original -- TODO confirm intent). The seed map above
            # still receives the raw label.
            target_labels = np.asarray(label)
            target_labels = np.where(target_labels == 0, 1, target_labels)
            # np.equal raises ValueError on a shape mismatch instead of silently
            # returning a scalar like ``==`` does on older NumPy.
            fov_mask = np.equal(fov_ins, target_labels[:, None, None, None])
            # Explicit int64: ``astype(int)`` is int32 on some platforms.
            fov_gt = torch.from_numpy(fov_mask.astype(np.int64)).to(self.device)

            fov_feature = torch.from_numpy(fov_feature).to(self.device).float()
            seg_map = torch.from_numpy(seg_map).to(self.device).float()
            seed_map = torch.from_numpy(seed_map).to(self.device).float()

            # A shape mismatch here raises RuntimeError to the caller; the
            # previous pdb trap would hang unattended training.
            fov_data = torch.cat([fov_feature, seg_map, seed_map], dim=1)
            fov_output = self.ffn(fov_data)

            loss_ins += self.criterion(fov_output, fov_gt)

            _, fov_seg = torch.max(fov_output, dim=1)
            fov_seg_array = fov_seg.detach().cpu().numpy()
            fov_gt_array = fov_gt.detach().cpu().numpy()
            metrics_fov.append(self.metric_func(fov_seg_array, fov_gt_array))

        if not metrics_fov:
            raise ValueError(
                "forward_seedfromfullseg: sample_seedfromfullseg returned no seeds")

        # Normalize by the number of sampled seeds actually iterated, not by the
        # ground-truth seed count from ``gts``.
        loss_ins /= len(metrics_fov)
        loss += loss_ins

        metric_fov = np.mean(metrics_fov, axis=0)

        self.dice_ins.update(metric_fov[0], batch_size)
        self.prec_ins.update(metric_fov[1], batch_size)
        self.recall_ins.update(metric_fov[2], batch_size)

        self.dice_patch.update(metric_full_seg[0], batch_size)
        self.prec_patch.update(metric_full_seg[1], batch_size)
        self.recall_patch.update(metric_full_seg[2], batch_size)

        dices = [self.dice_patch, self.dice_ins]
        precs = [self.prec_patch, self.prec_ins]
        recalls = [self.recall_patch, self.recall_ins]

        return loss, [dices, precs, recalls], [
            full_seg_pred_array, full_seg_array, fov_seg_array, fov_gt_array
        ]
Exemplo n.º 3
0
    def _forward(self, imgs, gts):
        """Patch segmentation + embedding loss, then anchor-matched FOV losses.

        Args:
            imgs: Input batch; ``imgs.size(0)`` is used as the batch size.
            gts: Tuple ``(ins, seeds, labels_pad, _)``. ``ins`` is the
                instance-label tensor (voxels ``> 0`` are foreground);
                ``seeds`` and ``labels_pad`` are iterated in lockstep. Each
                ``label`` must support ``.cpu().numpy()`` (a torch tensor).

        Returns:
            Tuple ``(loss, [dices, precs, recalls], arrays)``:
                loss: ``patch_loss + dsc_loss + mean(fov_losses)``.
                dices / precs / recalls: ``[patch_metric, instance_metric]``
                    objects after this batch's ``update`` calls.
                arrays: ``[full_seg_pred_array, full_seg_array,
                    fov_seg_array, fov_gt_array]``; the two FOV arrays come
                    from the last (seed, label) pair.

        Raises:
            ValueError: If ``seeds``/``labels_pad`` yield no pairs.

        NOTE(review): ``fov_gt`` is the raw crop of ``ins`` (instance ids), not
        a binary mask of ``label`` as in sibling forward methods; and ``match``
        is computed on the full ``emb`` while ``seg_map`` is an FOV crop, so the
        ``torch.cat`` below needs matching spatial sizes. Both should be
        confirmed against the model definition.
        """
        batch_size = imgs.size(0)
        ins, seeds, labels_pad, _ = gts

        full_seg_gt = (ins > 0).long()

        rrb = torch.cat(self.featureHead(imgs), dim=1)
        full_seg_output = self.segHead(rrb)
        # The fusenet output is not used below. The call is kept so that any
        # side effects of running the module are unchanged.
        self.fusenet(rrb)

        _, full_seg_pred = torch.max(full_seg_output, dim=1)

        full_seg_array = full_seg_gt.detach().cpu().numpy()
        full_seg_pred_array = full_seg_pred.detach().cpu().numpy()
        metric_full_seg = self.metric_func(full_seg_pred_array, full_seg_array)

        emb = self.embHead(rrb)

        loss = self.criterion(full_seg_output, full_seg_gt) +\
               self.dsc_loss(emb, ins)

        # Host copy of the embedding with axis 1 moved to the last position.
        emb_array = emb.unsqueeze(-1).transpose(
            1, -1).squeeze(1).cpu().detach().numpy()
        ins_array = ins.cpu().detach().numpy()
        seg_pred_channel = np.expand_dims(full_seg_pred_array, 1)

        loss_ins = 0.
        metrics_fov = []
        # Was ``zip(seed, labels_pad)``: ``seed`` is unbound at this point.
        for seed, label in zip(seeds, labels_pad):
            # The crop already carries the channel axis added by expand_dims, so
            # it is concatenated as-is (the old ``seg_map[:, None]`` added a
            # second, spurious axis and called .float() on a NumPy array).
            seg_map = get_batchCropData(seed, seg_pred_channel, self.fov_shape)
            seg_map = torch.from_numpy(seg_map).to(device=self.device,
                                                   dtype=torch.float)

            # The criterion, ``.float()`` and ``.detach()`` below all need a
            # torch tensor, not the NumPy crop.
            fov_gt = get_batchCropData(seed, ins_array, self.fov_shape)
            fov_gt = torch.from_numpy(fov_gt).to(device=self.device,
                                                 dtype=torch.long)

            anchors = self.anchor_sel(emb_array, ins_array,
                                      label.cpu().numpy())

            # Per the original author's note: starting from the anchor points,
            # find a set of similar (representative) points.
            match = self.matchHead(emb, emb, anchors)

            fs = torch.cat([match, anchors[:, None].float(), seg_map], dim=1)

            fov_output = self.ffn(fs)

            loss_ins += self.criterion(fov_output, fov_gt) +\
                        self.l2_loss(match, fov_gt[:, None].float())

            _, fov_seg = torch.max(fov_output, dim=1)
            fov_seg_array = fov_seg.detach().cpu().numpy()
            fov_gt_array = fov_gt.detach().cpu().numpy()
            metrics_fov.append(self.metric_func(fov_seg_array, fov_gt_array))

        if not metrics_fov:
            raise ValueError(
                "_forward: seeds/labels_pad produced no (seed, label) pairs")

        loss_ins /= len(metrics_fov)
        loss += loss_ins

        metric_fov = np.mean(metrics_fov, axis=0)

        self.dice_ins.update(metric_fov[0], batch_size)
        self.prec_ins.update(metric_fov[1], batch_size)
        self.recall_ins.update(metric_fov[2], batch_size)

        self.dice_patch.update(metric_full_seg[0], batch_size)
        self.prec_patch.update(metric_full_seg[1], batch_size)
        self.recall_patch.update(metric_full_seg[2], batch_size)

        dices = [self.dice_patch, self.dice_ins]
        precs = [self.prec_patch, self.prec_ins]
        recalls = [self.recall_patch, self.recall_ins]

        return loss, [dices, precs, recalls], [
            full_seg_pred_array, full_seg_array, fov_seg_array, fov_gt_array
        ]
Exemplo n.º 4
0
    def _forward(self, imgs, gts):
        """Compute the patch-segmentation + per-seed FOV loss and metrics.

        Shape notes from the original docstring (names there were
        cur_img / cur_ins):
            imgs: B * 1 * h * w * d
            ins: B * h * w * d
            labels_pad: max_ins_length * B

        Args:
            imgs: Input batch; ``imgs.size(0)`` is used as the batch size.
            gts: Tuple ``(ins, seeds, labels_pad)``. ``ins`` is the
                instance-label tensor (voxels ``> 0`` are foreground);
                ``seeds`` and ``labels_pad`` are iterated in lockstep, one
                (seed, label) pair per FOV step.

        Returns:
            Tuple ``(loss, [dices, precs, recalls], arrays)``:
                loss: ``patch_loss + mean(fov_losses)``.
                dices / precs / recalls: ``[patch_metric, instance_metric]``
                    objects after this batch's ``update`` calls.
                arrays: ``[full_seg_pred_array, full_seg_array,
                    fov_seg_array, fov_gt_array]``; the two FOV arrays come
                    from the last (seed, label) pair.

        Raises:
            ValueError: If ``seeds``/``labels_pad`` yield no pairs, or a label
                row cannot be broadcast against its FOV crop.
        """
        batch_size = imgs.size(0)
        ins, seeds, labels_pad = gts

        full_seg_gt = (ins > 0).long()

        rrb = torch.cat(self.featureHead(imgs), dim=1)
        full_seg_output = self.segHead(rrb)

        _, full_seg_pred = torch.max(full_seg_output, dim=1)

        full_seg_array = full_seg_gt.detach().cpu().numpy()
        full_seg_pred_array = full_seg_pred.detach().cpu().numpy()
        metric_full_seg = self.metric_func(full_seg_pred_array, full_seg_array)

        loss = self.criterion(full_seg_output, full_seg_gt)

        # Direction field computed from the ground-truth foreground mask.
        direct_field = batch_direct_field3D(full_seg_array)

        # Loop-invariant host copies, computed once instead of once per seed.
        ins_array = ins.detach().cpu().numpy()
        seg_pred_channel = np.expand_dims(full_seg_pred_array, 1)

        loss_ins = 0.
        metrics_fov = []
        for seed, label in zip(seeds, labels_pad):
            seg_map = get_batchCropData(seed, seg_pred_channel, self.fov_shape)
            seed_map = np.expand_dims(
                get_batchSeedMap(seed, label, ins_array, self.fov_shape), 1)
            df = get_batchCropData(seed, direct_field, self.fov_shape)
            fov_ins = get_batchCropData(seed, ins_array, self.fov_shape)

            # np.equal raises ValueError on a shape mismatch instead of silently
            # returning a scalar like ``==`` does on older NumPy.
            fov_mask = np.equal(fov_ins, np.asarray(label)[:, None, None, None])
            # Explicit int64: ``astype(int)`` is int32 on some platforms, and
            # the criterion receives ``.long()`` targets for the patch loss.
            fov_gt = torch.from_numpy(fov_mask.astype(np.int64)).to(self.device)

            seg_map = torch.from_numpy(seg_map).to(self.device).float()
            seed_map = torch.from_numpy(seed_map).to(self.device).float()
            df = torch.from_numpy(df).to(self.device).float()

            fov_data = torch.cat([seg_map, df, seed_map], dim=1)
            fov_output = self.ffn(fov_data)

            loss_ins += self.criterion(fov_output, fov_gt)

            _, fov_seg = torch.max(fov_output, dim=1)
            fov_seg_array = fov_seg.detach().cpu().numpy()
            fov_gt_array = fov_gt.detach().cpu().numpy()
            metrics_fov.append(self.metric_func(fov_seg_array, fov_gt_array))

        if not metrics_fov:
            raise ValueError(
                "_forward: seeds/labels_pad produced no (seed, label) pairs")

        # Average over the pairs actually processed.
        loss_ins /= len(metrics_fov)
        loss += loss_ins

        metric_fov = np.mean(metrics_fov, axis=0)

        self.dice_ins.update(metric_fov[0], batch_size)
        self.prec_ins.update(metric_fov[1], batch_size)
        self.recall_ins.update(metric_fov[2], batch_size)

        self.dice_patch.update(metric_full_seg[0], batch_size)
        self.prec_patch.update(metric_full_seg[1], batch_size)
        self.recall_patch.update(metric_full_seg[2], batch_size)

        dices = [self.dice_patch, self.dice_ins]
        precs = [self.prec_patch, self.prec_ins]
        recalls = [self.recall_patch, self.recall_ins]

        return loss, [dices, precs, recalls], [
            full_seg_pred_array, full_seg_array, fov_seg_array, fov_gt_array
        ]