Example #1
0
    def __getitem__(self, i):
        """Build one training sample: a cropped (src, tgt) pair plus masks.

        The crop window comes from a single call to ``self._get_crop_coords()``.
        Both images are read from ``self.img_dset[self.start_index + i]``, with
        channel 0 as src and channel 1 as tgt, and moved to ``self.device``.
        Pixels below -4 are zeroed in both images. If ``self.field_dset`` is
        set, the matching field entry is cropped over its last two axes and
        returned under ``"src_field"``.

        Args:
            i (int): sample index, relative to ``self.start_index``.

        Returns:
            dict: ``src``, ``tgt``, boolean ``src_zeros``/``tgt_zeros`` masks
            (``== 0``), ``id`` (``i`` as a tensor) and, when a field dataset is
            present, ``src_field``.
        """
        x_bot, x_top, y_bot, y_top = self._get_crop_coords()
        rows = slice(x_bot, x_top)
        cols = slice(y_bot, y_top)
        index = self.start_index + i
        device = self.device

        pair = self.img_dset[index]
        src, tgt = (helpers.to_tensor(pair[channel, rows, cols], device=device)
                    for channel in (0, 1))
        # Values below -4 are treated as background and zeroed in place.
        for image in (src, tgt):
            image[image < -4] = 0

        field = None
        if self.field_dset is not None:
            cropped_field = self.field_dset[index][..., x_bot:x_top, y_bot:y_top]
            field = helpers.to_tensor(cropped_field, device=device)

        bundle = dict(
            src=src,
            tgt=tgt,
            src_zeros=src == 0,
            tgt_zeros=tgt == 0,
            id=helpers.to_tensor(i, device=device),
        )
        if field is not None:
            bundle["src_field"] = field
        return bundle
Example #2
0
    def _upsample_field_direct(self, name, stage, in_mip, out_mip):
        """Upsample a field dataset from ``in_mip`` directly to ``out_mip``.

        Each field entry is multiplied by ``2**in_mip`` and upsampled by
        ``in_mip - out_mip`` MIP levels. It is then divided by ``2**out_mip``
        and center-cropped to the spatial size of the ``out_mip`` image
        dataset before being written to a newly created field dataset.
        (The multiply/divide presumably converts displacements to MIP-0 units
        and back; confirm.)

        Args:
            name (str): filename of dataset (input & output)
            stage (int): module which created dataset (input & output)
            in_mip (int): MIP level of the input dataset
            out_mip (int): MIP level of the output dataset
        """
        assert in_mip > out_mip
        in_field_dset = self.get_field_dset(name=name, stage=stage, mip=in_mip)
        dst_img_dset = self.get_img_dset(name=name, mip=out_mip)
        out_field_dset = self.load_field_dset(name=name,
                                              stage=stage,
                                              mip=out_mip,
                                              shape=dst_img_dset.shape,
                                              create=True)
        dst_h, dst_w = dst_img_dset.shape[-2:]

        with torch.no_grad():
            for n in range(in_field_dset.shape[0]):
                in_field = helpers.to_tensor(in_field_dset[n:n + 1]).field()
                out_field = (in_field * (2**in_mip)).up(mips=in_mip - out_mip)
                out_field = out_field / (2**out_mip)
                # Center-crop using the size of the *upsampled* field, with
                # explicit end indices: ``hsz:-hsz`` is an empty slice when
                # hsz == 0, and a margin computed from the pre-upsampling size
                # never yields a crop of the destination size.
                top = (out_field.shape[-2] - dst_h) // 2
                left = (out_field.shape[-1] - dst_w) // 2
                out_field_cropped = out_field[0, :, top:top + dst_h,
                                              left:left + dst_w]
                out_field_dset[n] = helpers.get_np(out_field_cropped)
Example #3
0
    def _upsample_field(self, name, stage, mip_start, mip_end):
        """Cascade-upsample a field dataset one MIP level at a time.

        For every source MIP ``m`` from ``mip_start`` down to ``mip_end``
        (inclusive), the field dataset at ``m`` is bilinearly upsampled 2x. Its
        values are doubled, and the result is written to a newly created field
        dataset at ``m - 1``. That dataset's shape matches the image dataset at
        ``m - 1``. The last dataset written is therefore at ``mip_end - 1``.

        Args:
            name (str): filename of dataset (input & output)
            stage (int): module which created dataset (input & output)
            mip_start (int): first (coarsest) source MIP level
            mip_end (int): last source MIP level to upsample from
        """
        src_mip = mip_start
        while src_mip >= mip_end:
            dst_mip = src_mip - 1

            coarse_dset = self.get_field_dset(name=name,
                                              stage=stage,
                                              mip=src_mip)
            fine_img_dset = self.get_img_dset(name=name, mip=dst_mip)
            fine_dset = self.load_field_dset(name=name,
                                             stage=stage,
                                             mip=dst_mip,
                                             shape=fine_img_dset.shape,
                                             create=True)

            with torch.no_grad():
                # NOTE(review): the crop uses the last image dim for both
                # spatial axes, i.e. it assumes square images.
                size = fine_img_dset.shape[-1]
                for idx in range(coarse_dset.shape[0]):
                    coarse = helpers.to_tensor(coarse_dset[idx:idx + 1])
                    upsampled = torch.nn.functional.interpolate(
                        coarse,
                        mode='bilinear',
                        scale_factor=2.0,
                        align_corners=False,
                        recompute_scale_factor=False)
                    # Values are doubled along with the resolution, presumably
                    # to keep pixel-unit displacements consistent.
                    upsampled = upsampled * 2.0
                    # Keep the top-left size x size window.
                    fine_dset[idx] = helpers.get_np(
                        upsampled[:, :, :size, :size])

            src_mip -= 1
Example #4
0
    def _generate_field_dataset(self, model, img_dset, field_dset,
                                prev_field_dset):
        """Run ``model`` on every (src, tgt) image pair and store the fields.

        Args:
            model: callable taking ``src_img``, ``tgt_img``, ``src_agg_field``,
                ``train`` and ``return_state`` keyword arguments.
            img_dset: indexable dataset; ``img_dset[b, 0]`` is the src image
                and ``img_dset[b, 1]`` the tgt image of pair ``b``.
            field_dset: writable dataset receiving the field for pair ``b``.
            prev_field_dset: optional dataset of previous fields passed to the
                model as ``src_agg_field``; ``None`` to pass no field.
        """
        has_prev = prev_field_dset is not None
        for idx in range(img_dset.shape[0]):
            src_img = helpers.to_tensor(img_dset[idx, 0])
            tgt_img = helpers.to_tensor(img_dset[idx, 1])
            prev = helpers.to_tensor(prev_field_dset[idx]) if has_prev else None

            predicted = model(src_img=src_img,
                              tgt_img=tgt_img,
                              src_agg_field=prev,
                              train=False,
                              return_state=False)
            field_dset[idx] = helpers.get_np(predicted)
Example #5
0
    def _generate_field_dataset(self, model, img_dset, field_dset,
                                prev_field_dset):
        """Generate the predicted source field for each (src, tgt) pair.

        Only the src -> tgt direction is predicted: one field per pair is
        written to ``field_dset[b]``. The reverse (tgt, src) permutation is not
        computed here.

        Args:
            model: callable taking ``src_img``, ``tgt_img``, ``src_agg_field``,
                ``train`` and ``return_state`` keyword arguments.
            img_dset: indexable dataset; ``img_dset[b, 0]`` is the src image
                and ``img_dset[b, 1]`` the tgt image of pair ``b``.
            field_dset: writable dataset receiving the field for pair ``b``.
            prev_field_dset: optional dataset of previous fields passed to the
                model as ``src_agg_field``; ``None`` to pass no field.
        """
        for b in range(img_dset.shape[0]):
            src = helpers.to_tensor(img_dset[b, 0])
            tgt = helpers.to_tensor(img_dset[b, 1])
            prev_field = (helpers.to_tensor(prev_field_dset[b])
                          if prev_field_dset is not None else None)

            field = model(
                src_img=src,
                tgt_img=tgt,
                src_agg_field=prev_field,
                train=False,
                return_state=False,
            )
            field_dset[b] = helpers.get_np(field)
Example #6
0
    def __getitem__(self, i):
        """Build one sample: a cropped (src, tgt) pair, masks, optional field.

        The crop window comes from a single call to ``self._get_crop_coords()``
        and is applied to the last two axes. Channel 0 (just before those axes)
        is src and channel 1 is tgt. Pixels below -4 are zeroed. If
        ``self.field_dset`` is set, the matching field entry is cropped the
        same way and returned under ``"src_field"``.

        Args:
            i (int): sample index, relative to ``self.start_index``.

        Returns:
            dict: ``src``, ``tgt``, boolean ``src_zeros``/``tgt_zeros`` masks
            and, when a field dataset is present, ``src_field``.
        """
        x_bot, x_top, y_bot, y_top = self._get_crop_coords()
        window = (slice(x_bot, x_top), slice(y_bot, y_top))
        index = self.start_index + i

        img = self.img_dset[index]
        src = helpers.to_tensor(img[(Ellipsis, 0) + window])
        tgt = helpers.to_tensor(img[(Ellipsis, 1) + window])
        # Values below -4 are treated as background and zeroed in place.
        for image in (src, tgt):
            image[image < -4] = 0

        field = None
        if self.field_dset is not None:
            field = helpers.to_tensor(
                self.field_dset[index][(Ellipsis,) + window])

        bundle = dict(
            src=src,
            tgt=tgt,
            src_zeros=src == 0,
            tgt_zeros=tgt == 0,
        )
        if field is not None:
            bundle["src_field"] = field
        return bundle
Example #7
0
def generate_shard(rank, world_size, module_path, checkpoint_name, img_path,
                   prev_field_path, dst_dir, src_mip, dst_mip):
    """Generate fields for the subset of image pairs associated with ``rank``.

    Pairs are split into ``world_size`` contiguous chunks of
    ``n_pairs // world_size``; the last rank also takes the remainder. For each
    pair, the model's field is written to the ``src_mip`` volume. It is then
    upsampled to ``dst_mip``, center-cropped, and written to the ``dst_mip``
    volume. The distributed process group is always torn down via
    ``cleanup()``, even on error.

    Args:
        rank (int): process order
        world_size (int): total no. of processes
        module_path (str): path to modelhouse directory
        checkpoint_name (str): checkpoint for weights
        img_path (str): path to image pairs h5 (dataset ``main``)
        prev_field_path (str or None): path to previous fields h5 (dataset
            ``main``); ``None`` runs the model without a previous field
        dst_dir (str): directory containing the output CloudVolumes at
            ``<dst_dir>/<src_mip>`` and ``<dst_dir>/<dst_mip>``
        src_mip (int): MIP level of the model's output field
        dst_mip (int): MIP level the field is upsampled to

    Raises:
        ValueError: if there are fewer image pairs than processes.
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    img_file = None
    prev_field_file = None
    try:
        torch.cuda.set_device(rank)
        model = modelhouse.load_model_simple(module_path,
                                             finetune=True,
                                             finetune_lr=3e-1,
                                             finetune_sm=300e0,
                                             finetune_iter=200,
                                             pass_field=True,
                                             checkpoint_name=checkpoint_name)
        model.aligner.net = model.aligner.net.to(rank)
        model = DDP(model, device_ids=[rank])

        img_file = h5py.File(img_path, 'r')
        img_dset = img_file['main']
        prev_field_dset = None
        if prev_field_path is not None:
            prev_field_file = h5py.File(prev_field_path, 'r')
            prev_field_dset = prev_field_file['main']

        n_total = img_dset.shape[0]
        if n_total < world_size:
            raise ValueError(
                f"{n_total} image pairs cannot be split over "
                f"{world_size} processes")
        n_per_rank = n_total // world_size
        n_start = rank * n_per_rank
        # The last rank also processes the remainder of the division.
        n_stop = n_total if rank + 1 == world_size else n_start + n_per_rank

        src_field_dset = CloudVolume(os.path.join(dst_dir, str(src_mip)),
                                     mip=0)
        dst_field_dset = CloudVolume(os.path.join(dst_dir, str(dst_mip)),
                                     mip=0)
        # CloudVolume axes are indexed as (x, y, z, channel) below.
        dst_x, dst_y = dst_field_dset.shape[0], dst_field_dset.shape[1]

        for b in range(n_start, n_stop):
            print(f'{b} / {n_total}')
            src = helpers.to_tensor(img_dset[b, 0])
            tgt = helpers.to_tensor(img_dset[b, 1])
            prev_field = (helpers.to_tensor(prev_field_dset[b])
                          if prev_field_dset is not None else None)

            field = model(src_img=src,
                          tgt_img=tgt,
                          src_agg_field=prev_field,
                          train=False,
                          return_state=False)
            # permute(2, 3, 0, 1) moves the last two (spatial) tensor dims to
            # the front to match the volume's (x, y, z, c) layout; assumes the
            # field is (N, C, H, W) -- TODO confirm.
            # NOTE(review): the z index is relative to n_start; if dst_dir is
            # shared across ranks, shards overwrite each other -- confirm
            # dst_dir is per-rank.
            src_field_dset[:, :, b - n_start, :] = helpers.get_np(
                field.permute(2, 3, 0, 1))

            # Upsample: scale by 2**src_mip, go up (src_mip - dst_mip) levels,
            # then rescale by 2**dst_mip (presumably MIP-0 units and back).
            field = field * (2**src_mip)
            field = field.up(mips=src_mip - dst_mip)
            field = field / (2**dst_mip)

            # Center-crop to the destination size using explicit end indices;
            # ``hsz:-hsz`` would yield an empty slice when hsz == 0.
            x0 = (field.shape[2] - dst_x) // 2
            y0 = (field.shape[3] - dst_y) // 2
            field_cropped = field[:, :, x0:x0 + dst_x, y0:y0 + dst_y]
            dst_field_dset[:, :, b - n_start, :] = helpers.get_np(
                field_cropped.permute(2, 3, 0, 1))
    finally:
        if prev_field_file is not None:
            prev_field_file.close()
        if img_file is not None:
            img_file.close()
        cleanup()