def test(self):
        r"""Run sliding-window inference over ``self.dataloader``.

        Each batch yielded by ``self.dataloader`` is a ``(pos, volume)`` pair;
        ``pos[idx]`` is read as a volume index followed by three start offsets
        of the patch inside that volume. The model output (passed through a
        ``TestAugmentor`` when ``cfg.INFERENCE.AUG_NUM != 0``) is weighted by
        ``blend_gaussian(cfg.MODEL.OUTPUT_SIZE)`` and accumulated into
        per-volume buffers. The buffers are then divided by the accumulated
        weights, scaled by 255, cast to ``uint8`` and cropped by
        ``cfg.DATASET.PAD_SIZE``.

        Returns:
            list of numpy.ndarray: the per-volume predictions when
            ``self.output_dir`` is ``None``. Otherwise the predictions are
            written with ``writeh5`` to ``self.output_dir/<output name>`` as
            datasets ``vol0``, ``vol1``, ... and ``None`` is returned.
        """
        if self.cfg.INFERENCE.DO_EVAL:
            self.model.eval()
        else:
            self.model.train()

        ww = blend_gaussian(self.cfg.MODEL.OUTPUT_SIZE)

        # A leading None in MODEL_OUTPUT_ID means "keep every output channel".
        output_ids = self.cfg.INFERENCE.MODEL_OUTPUT_ID
        select_channels = output_ids[0] is not None
        NUM_OUT = len(output_ids) if select_channels else self.cfg.MODEL.OUT_PLANES

        pad_size = self.cfg.DATASET.PAD_SIZE
        if len(pad_size) == 3:
            # One value per axis: expand to (before, after) pairs per axis.
            pad_size = [
                pad_size[0], pad_size[0],
                pad_size[1], pad_size[1],
                pad_size[2], pad_size[2]
            ]

        is_super = "super" in self.cfg.MODEL.ARCHITECTURE
        if is_super:
            # For "super" architectures the accumulators are the input
            # volume sizes multiplied by DATASET.SCALE_FACTOR.
            output_size = np.array(
                self.dataloader._dataset.input_size) * np.array(
                    self.cfg.DATASET.SCALE_FACTOR).tolist()
        else:
            output_size = self.dataloader._dataset.input_size
        result = [
            np.stack(
                [np.zeros(x, dtype=np.float32) for _ in range(NUM_OUT)])
            for x in output_size
        ]
        weight = [np.zeros(x, dtype=np.float32) for x in output_size]

        # build test-time augmentor and update output filename
        output_filename = self.cfg.INFERENCE.OUTPUT_NAME
        use_aug = self.cfg.INFERENCE.AUG_NUM != 0
        if use_aug:
            test_augmentor = TestAugmentor(self.cfg.INFERENCE.AUG_MODE,
                                           self.cfg.INFERENCE.AUG_NUM)
            output_filename = test_augmentor.update_name(output_filename)

        if is_super:
            # Patch positions are given on the input grid; rescale the three
            # spatial offsets (not the volume index) to the output grid.
            pos_scale = np.array([1] + list(self.cfg.DATASET.SCALE_FACTOR))

        start = time.time()
        sz = tuple([NUM_OUT] + list(self.cfg.MODEL.OUTPUT_SIZE))
        volume_id = 0
        with torch.no_grad():
            for pos, volume in self.dataloader:
                volume_id += self.cfg.INFERENCE.SAMPLES_PER_BATCH
                print('volume_id:', volume_id)

                # for gpu computing
                volume = torch.from_numpy(volume).to(self.device)
                if not self.cfg.INFERENCE.DO_3D:
                    volume = volume.squeeze(1)

                if use_aug:
                    output = test_augmentor(self.model, volume)
                else:
                    output = self.model(volume).cpu().detach().numpy()

                if select_channels:
                    # Axis 0 is the sample axis (paired with pos[idx] below);
                    # channels are axis 1. Indexing with the id list keeps the
                    # channel axis, so its length equals NUM_OUT.
                    output = output[:, list(output_ids)]

                for idx in range(output.shape[0]):
                    st = pos[idx]
                    if is_super:
                        st = (np.array(st) * pos_scale).tolist()
                    patch = output[idx]
                    if result[st[0]].ndim - patch.ndim == 1:
                        # Patch has one axis fewer than the accumulator (e.g.
                        # 2D output): insert a singleton axis after channels.
                        patch = patch[:, None, :]
                    region = (slice(st[1], st[1] + sz[1]),
                              slice(st[2], st[2] + sz[2]),
                              slice(st[3], st[3] + sz[3]))
                    result[st[0]][(slice(None),) + region] += patch * ww[None, :]
                    weight[st[0]][region] += ww

        end = time.time()
        print("Prediction time:", (end - start))

        for vol_id in range(len(result)):
            if result[vol_id].ndim > weight[vol_id].ndim:
                weight[vol_id] = np.expand_dims(weight[vol_id], axis=0)
            # Weighted average of overlapping patches; voxels never covered
            # (zero weight) become 0 instead of NaN before the integer cast.
            normalized = np.divide(result[vol_id], weight[vol_id],
                                   out=np.zeros_like(result[vol_id]),
                                   where=weight[vol_id] > 0)
            # Stored as uint8 scaled by 255; assumes predictions lie in [0, 1].
            result[vol_id] = (normalized * 255).astype(np.uint8)
            sz = result[vol_id].shape
            result[vol_id] = result[vol_id][:, pad_size[0]:sz[1] - pad_size[1],
                                            pad_size[2]:sz[2] - pad_size[3],
                                            pad_size[4]:sz[3] - pad_size[5]]

        if self.output_dir is None:
            return result
        print('save h5')
        writeh5(os.path.join(self.output_dir, output_filename), result,
                ['vol%d' % (x) for x in range(len(result))])
    def test(self):
        r"""Run sliding-window inference over ``self.dataloader``.

        Each batch yielded by ``self.dataloader`` is a ``(pos, volume)`` pair;
        ``pos[idx]`` is read as a volume index followed by three start offsets
        of the patch inside that volume. The model is applied through a
        ``TestAugmentor``. The selected output channels are weighted by the
        matrix from ``build_blending_matrix`` and accumulated into per-volume
        buffers. The buffers are then divided by the accumulated weights,
        scaled by 255, cast to ``uint8`` and cropped by
        ``cfg.DATASET.PAD_SIZE``.

        Side effect: ``self.inference_output_name`` is updated by the
        augmentor's ``update_name``.

        Returns:
            list of numpy.ndarray: the per-volume predictions when
            ``self.output_dir`` is ``None``. Otherwise the predictions are
            written with ``writeh5`` to
            ``self.output_dir/self.inference_output_name`` as datasets
            ``vol0``, ``vol1``, ... and ``None`` is returned.
        """
        if self.cfg.INFERENCE.DO_EVAL:
            self.model.eval()
        else:
            self.model.train()

        ww = build_blending_matrix(self.cfg.MODEL.OUTPUT_SIZE,
                                   self.cfg.INFERENCE.BLENDING)

        # A leading None in MODEL_OUTPUT_ID means "keep every output channel".
        output_ids = self.cfg.INFERENCE.MODEL_OUTPUT_ID
        select_channels = output_ids[0] is not None
        if select_channels:
            NUM_OUT = len(output_ids)
        else:
            NUM_OUT = self.cfg.MODEL.OUT_PLANES

        pad_size = self.cfg.DATASET.PAD_SIZE
        if len(pad_size) == 3:
            # One value per axis: expand to (before, after) pairs per axis.
            pad_size = [
                pad_size[0], pad_size[0],
                pad_size[1], pad_size[1],
                pad_size[2], pad_size[2]
            ]

        is_super = "super" in self.cfg.MODEL.ARCHITECTURE
        if is_super:
            # For "super" architectures the accumulators are the input
            # volume sizes multiplied by DATASET.SCALE_FACTOR.
            output_size = np.array(
                self.dataloader._dataset.volume_size) * np.array(
                    self.cfg.DATASET.SCALE_FACTOR).tolist()
        else:
            output_size = self.dataloader._dataset.volume_size
        result = [
            np.stack(
                [np.zeros(x, dtype=np.float32) for _ in range(NUM_OUT)])
            for x in output_size
        ]
        weight = [np.zeros(x, dtype=np.float32) for x in output_size]

        # build test-time augmentor and update output filename
        test_augmentor = TestAugmentor(mode=self.cfg.INFERENCE.AUG_MODE,
                                       do_2d=self.cfg.DATASET.DO_2D,
                                       num_aug=self.cfg.INFERENCE.AUG_NUM)
        self.inference_output_name = test_augmentor.update_name(
            self.inference_output_name)

        if is_super:
            # Patch positions are given on the input grid; rescale the three
            # spatial offsets (not the volume index) to the output grid.
            pos_scale = np.array([1] + list(self.cfg.DATASET.SCALE_FACTOR))

        start = time.time()
        sz = tuple([NUM_OUT] + list(self.cfg.MODEL.OUTPUT_SIZE))
        num_batches = len(self.dataloader)
        print("Total number of batches: ", num_batches)

        with torch.no_grad():
            for batch_id, (pos, volume) in enumerate(self.dataloader):
                # Report progress in batches so it matches the total above.
                print('progress: %d/%d' % (batch_id + 1, num_batches))

                # for gpu computing
                volume = torch.from_numpy(volume).to(self.device)
                if not self.cfg.INFERENCE.DO_3D:
                    volume = volume.squeeze(1)

                # forward pass
                output = test_augmentor(self.model, volume)

                if select_channels:
                    # Axis 0 is the sample axis (paired with pos[idx] below);
                    # channels are axis 1. Indexing with the whole id list
                    # keeps the channel axis, so its length equals NUM_OUT.
                    output = output[:, list(output_ids)]

                for idx in range(output.shape[0]):
                    st = pos[idx]
                    if is_super:
                        st = (np.array(st) * pos_scale).tolist()
                    patch = output[idx]
                    if result[st[0]].ndim - patch.ndim == 1:
                        # Patch has one axis fewer than the accumulator (e.g.
                        # 2D output): insert a singleton axis after channels.
                        patch = patch[:, None, :]
                    region = (slice(st[1], st[1] + sz[1]),
                              slice(st[2], st[2] + sz[2]),
                              slice(st[3], st[3] + sz[3]))
                    result[st[0]][(slice(None),) + region] += patch * ww[None, :]
                    weight[st[0]][region] += ww

        end = time.time()
        print("Prediction time:", (end - start))

        for vol_id in range(len(result)):
            if result[vol_id].ndim > weight[vol_id].ndim:
                weight[vol_id] = np.expand_dims(weight[vol_id], axis=0)
            # Weighted average of overlapping patches; voxels never covered
            # (zero weight) become 0 instead of NaN before the integer cast.
            normalized = np.divide(result[vol_id], weight[vol_id],
                                   out=np.zeros_like(result[vol_id]),
                                   where=weight[vol_id] > 0)
            # Stored as uint8 scaled by 255; assumes predictions lie in [0, 1].
            result[vol_id] = (normalized * 255).astype(np.uint8)
            sz = result[vol_id].shape
            result[vol_id] = result[vol_id][:, pad_size[0]:sz[1] - pad_size[1],
                                            pad_size[2]:sz[2] - pad_size[3],
                                            pad_size[4]:sz[3] - pad_size[5]]

        if self.output_dir is None:
            return result
        print('Saving as h5...')
        writeh5(os.path.join(self.output_dir, self.inference_output_name),
                result, ['vol%d' % (x) for x in range(len(result))])
        print('Inference is done!')
Example #3
0
    # 3D segmentation
    out = waterz.waterz(affs=aff, thresholds=T_thres2,
                        fragments=out, merge_function=mf)[0]

    et = time.time()

    sn = '%s_%f_%f_%d_%f_%d_%f_%d_%f_%s.h5' % (
        args.seg_mode, T_aff[0], T_aff[1], T_thres[0], T_aff[2], T_dust, T_merge, T_aff_rel, T_thres2[0], mf)

else:
    print('The segmentation method is not implemented yet!')
    raise NotImplementedError

# print time profile
print('time: %.1f s' % ((et-st)))

# ARAND evaluation of the segmentation `out` against the ground truth `seg`
score = adapted_rand(out.astype(np.uint32), seg)
print('Adaptive rand: ', score)
# Scores noted by the author (the runs they refer to are not recorded here):
# 0: 0.22
# 1: 0.098
# 2: 0.137

# save segmentation next to the prediction file args.pd
if args.save:
    # os.path.dirname() returns '' for a bare filename. Joining with '' keeps
    # the output in the current directory instead of the filesystem root
    # (which the previous `dirname + '/'` produced).
    result_dir = os.path.dirname(args.pd)
    if result_dir:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(result_dir, exist_ok=True)
    writeh5(os.path.join(result_dir, sn), out, 'main')

# Example usage:
# python demo.py -p ../pytorch_connectomics/outputs/cerebellum_P0/test/seg2.h5 -gt ../pytorch_connectomics/outputs/cerebellum_P0/train/seg2_gt.h5 -ph ../pytorch_connectomics/outputs/cerebellum_P0/test/aff2.h5