import os

import torch
from sklearn.cluster import DBSCAN

# Assumed project-internal imports: tensors_io and the clustering helpers
# (get_center_pixels, propose_cluster4, expand_cluster4, refine_cluster4) are
# taken to live in this package, mirroring the relative import of
# evaluate_iou_volume used below; adjust the module names to the actual repo.
from . import tensors_io
from .cluster_helpers import (expand_cluster4, get_center_pixels,
                              propose_cluster4, refine_cluster4)


def show_side_results(dataset_name, dataset_version=0):
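    """Load each network's instance predictions plus the ground truth for one
    dataset/version and save a side view of the multi-offset result.

    data[0] holds the ground truth; data[1..] hold one volume per entry of
    network_strings (the commented-out block below builds comparison plots
    from them).
    """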
    network_strings = [
        'unet_2_channels__64_dims__wz_64',
        'unet_2_channels__64_dims__wz_64_offset', 'METHOD_double'
    ]

    string_of_names = ['Gt', 'U Net 12 Embs', 'Multi Offset', 'Proposed']

    data = {}
    for counter, network in enumerate(network_strings, start=1):
        directory = "results/" + dataset_name + "/" + "v_" + str(
            dataset_version) + "/" + network + "/mini_inst"
        print("reading " + directory)
        data[counter] = tensors_io.load_volume(directory)
    data[0] = tensors_io.load_volume("results/" + dataset_name + "/" + "v_" +
                                     str(dataset_version) + "/gt")
    '''
    tensors_io.save_subplots_compare(data[0], data[1], data[2], data[3], "RESULTS_COMPARISON", string_of_names=string_of_names)

    og = tensors_io.load_volume("results/" + dataset_name + "/" + "v_" + str(dataset_version) + "/original")
    og = torch.transpose(og, 3, 2)
    tensors_io.save_subvolume(torch.transpose(data[0], 3, 2), 'compare/gt')
    tensors_io.save_subvolume(og, 'compare/side_og')
    tensors_io.save_subvolume(torch.transpose(data[1], 3, 2), 'compare/side_unet')
    tensors_io.save_subvolume(torch.transpose(data[2], 3, 2), 'compare/side_unet_of')
    tensors_io.save_subvolume(torch.transpose(data[3], 3, 2), 'compare/side_proposed')
    '''
    multi = tensors_io.load_volume(
        "results/" + dataset_name + "/" + "v_" + str(dataset_version) +
        "/unet_double_multi_2_channels__64_dims__wz_64/mini_inst")
    tensors_io.save_subvolume(torch.transpose(multi, 1, 3), 'compare/multi')


def recrop(params_t, resample_masks=0, factor=1, directory='.'):
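    """Subsample a volume (or its instance masks) and save the result.

    When `directory` is given, that volume is loaded with scale=2 and written
    to `<directory>/subsampled`. Otherwise the input/output paths are derived
    from `params_t` according to `resample_masks`: 0 = original data,
    1 = ground truth, 2 = predicted instances.
    """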
    if (directory != '.'):
        path_of_interest = directory + '/subsampled'
        resample_masks = 2
        masks = tensors_io.load_volume(directory, scale=2).unsqueeze(0)
        tensors_io.save_subvolume(masks, path_of_interest)
        # Done: the code below expects data_volume loaded from params_t
        return
    else:
        params_t.cleaning = False
        params_t.cleaning_sangids = False
        params_t.scale_p = factor
        if (resample_masks == 0):
            path_of_interest = "results/" + params_t.dataset_name + "/" + "v_" + str(
                params_t.dataset_version) + "/or_subsampled"
        elif (resample_masks == 1):
            path_of_interest = "results/" + params_t.dataset_name + "/" + "v_" + str(
                params_t.dataset_version) + "/gt_subsampled"
        elif (resample_masks == 2):
            if (params_t.debug_cluster_unet_double):
                params_t.network_string = "METHOD"
            path_of_interest = "results/" + params_t.dataset_name + "/" + "v_" + str(
                params_t.dataset_version
            ) + "/" + params_t.network_string + "/mini_inst_only_subsampled"
            params_t.testing_mask = "results/" + params_t.dataset_name + "/" + "v_" + str(
                params_t.dataset_version
            ) + "/" + params_t.network_string + "/mini_inst_only"
        else:
            raise ValueError(
                "resample_masks must be 0, 1 or 2, got {}".format(
                    resample_masks))

        data_volume, masks, V_or = tensors_io.load_data_and_masks(params_t)
        path_temp = ""
        for local_path in path_of_interest.split("/"):
            path_temp = path_temp + local_path
            if not os.path.isdir(path_temp):
                os.mkdir(path_temp)
            path_temp = path_temp + "/"

    if (resample_masks == 0):
        tensors_io.save_subvolume(data_volume, path_of_interest)
    else:
        tensors_io.save_subvolume_instances(data_volume * 0, masks,
                                            path_of_interest)


def clustering_algorithm(offset_vectors,
                         final_pred,
                         mask=None,
                         data_image=None,
                         params_t=None):
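    """Cluster foreground voxels into fiber instances using predicted offsets.

    offset_vectors: [N_foreground, 3] per-voxel offsets toward fiber centers.
    final_pred:     binary segmentation whose nonzero voxels define the
                    foreground coordinates to cluster.
    Returns (labels_array, embedded_volume_marks): per-voxel instance labels
    and the volume of cluster marks.
    """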
    print("Starting Clustering Algo with parameter r: {}".format(
        params_t.eps_param))
    params_t.min_points = 10
    # params_t.eps_param = 2
    params_t.mpp_min_r = params_t.mpp_min_r
    params_t.mpp_max_r = params_t.mpp_min_r

    # Get offset magnitudes and rescale to [0, 1]
    magnitudes = torch.norm(offset_vectors, dim=1)  # [N_foreground]
    magnitudes = magnitudes / magnitudes.max()

    labels_array = torch.full(magnitudes.shape, -1,
                              device=offset_vectors.device, dtype=torch.long)
    soft_labels_array = torch.full(magnitudes.shape, -1,
                                   device=offset_vectors.device,
                                   dtype=torch.long)
    soft_distances = torch.full(magnitudes.shape, 10000.0,
                                device=offset_vectors.device,
                                dtype=torch.float)
    # Get coordinates of real pixels
    coordinates_pixels = (
        final_pred[0, ...] == 1).nonzero().float()  # [N_foreground, 3]

    connected_component_labels = DBSCAN(eps=2, min_samples=4).fit_predict(
        coordinates_pixels.cpu())
    connected_component_labels = torch.from_numpy(
        connected_component_labels).to(offset_vectors.device)

    # Predicted center coordinate for each pixel (coordinate minus offset)
    offset_pixels = (
        coordinates_pixels -
        offset_vectors).long().detach().float()  # [N_foreground, 3]

    # Get Center Proposals
    counter_image, initial_indexes = get_center_pixels(offset_pixels,
                                                       dims=params_t.cube_size)
    # Volume of integer marks claimed by grown clusters
    embedded_volume_marks = torch.zeros(params_t.cube_size, params_t.cube_size,
                                        params_t.cube_size).detach().long()

    list_of_objects = []
    label = 1

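    # Greedy seeding: visit center proposals in order; reuse an existing mark
    # when a proposal lands on one, otherwise propose and grow a new cluster.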
    for pixel_idx in initial_indexes:
        # Predicted center voxel (offset_pixels = coordinates - offsets)
        center_voxel = offset_pixels[pixel_idx, :].long()
        if (labels_array[pixel_idx] > -1):
            continue
        if (embedded_volume_marks[center_voxel[0], center_voxel[1],
                                  center_voxel[2]] > 0):
            labels_array[pixel_idx] = embedded_volume_marks[center_voxel[0],
                                                            center_voxel[1],
                                                            center_voxel[2]]
            continue

        neighbor_label, marks, absolute_fiber_indexes = propose_cluster4(
            pixel_idx, labels_array, offset_pixels, coordinates_pixels,
            embedded_volume_marks, list_of_objects, params_t)
        if (neighbor_label > -1):
            labels_array[pixel_idx] = neighbor_label
            continue

        if (marks is None):
            continue
        # label = expand_cluster3(marks, absolute_fiber_indexes, conected_component_labels, labels_offset, soft_distances, coordinates_pixels, offset_pixels, labels_array, soft_labels_array, embedded_volume_marks, label, params_t)
        label = expand_cluster4(marks, absolute_fiber_indexes,
                                connected_component_labels, soft_distances,
                                coordinates_pixels, labels_array,
                                soft_labels_array, embedded_volume_marks,
                                label, list_of_objects, params_t, data_image)

    # print("Refining")
    labels_array = refine_cluster4(list_of_objects, offset_pixels,
                                   coordinates_pixels, labels_array, params_t)
    # print("The number of non-classified pixel percentage is: {}".format(initial_indexes.shape[0] / (initial_indexes == -1).sum() * 100))

    #for el in list_of_objects:
    #    print(el.label)
    #    print(el.data_energy)
    #    print(el.prior_energy)
    #    print("")
    # labels_array = assign_unclassified_clusters(list_of_objects, initial_indexes, labels_array, coordinates_pixels)

    # for pixel_idx in initial_indexes:
    #    if(labels_array[pixel_idx] == - 1):
    #        labels_array[pixel_idx] = soft_distances[pixel_idx]
    # ---- DEBUG: visualization and evaluation dumps ----
    # labels_array = soft_labels_array
    if (True):
        from .evaluation import evaluate_iou_volume

        # Cluster the predicted center points directly with DBSCAN
        # (used as a comparison panel in the debug plots)
        labels_offset = DBSCAN(eps=1.5, min_samples=4).fit_predict(
            offset_pixels.cpu())
        labels_offset = torch.from_numpy(labels_offset).to(
            offset_pixels.device)
        coordinates_pixels = coordinates_pixels.long()
        offset_pixels = offset_pixels.long().clamp(0, params_t.cube_size - 1)

        # Create Embedded Volume
        embedded_volume = torch.zeros(params_t.cube_size, params_t.cube_size,
                                      params_t.cube_size).detach()
        for pix in offset_pixels:
            embedded_volume[pix[0], pix[1], pix[2]] += 1.0
        embedded_volume = embedded_volume / (0.01 * embedded_volume.max())
        tensors_io.save_subvolume(embedded_volume.unsqueeze(0),
                                  "debug_cluster/points")
        embedded_volume[coordinates_pixels.split(
            1, dim=1)] = magnitudes.cpu().unsqueeze(1)

        # Create volume of final cluster labels
        embedded_volume_labels = torch.zeros(
            params_t.cube_size, params_t.cube_size,
            params_t.cube_size).detach().long()
        embedded_volume_labels[coordinates_pixels.split(
            1, dim=1)] = labels_array.unsqueeze(1).cpu() + 2

        embedded_volume_DBSCAN = torch.zeros(
            params_t.cube_size, params_t.cube_size,
            params_t.cube_size).detach().long()
        embedded_volume_DBSCAN[coordinates_pixels.split(
            1, dim=1)] = connected_component_labels.unsqueeze(1).cpu() + 2

        # Create Clustered Image
        space_clusters = torch.zeros_like(final_pred[0, ...])
        if (mask is not None):
            t_mask = mask.clone()
            t_mask = t_mask[0, 0, ...].to(space_clusters.device)
        else:
            t_mask = embedded_volume_marks

        space_clusters[offset_pixels.split(
            1, dim=1)] = t_mask[offset_pixels.split(1, dim=1)]

        space_clusters_MARKS = torch.zeros_like(final_pred[0, ...]).cpu()
        space_clusters_MARKS[coordinates_pixels.split(
            1, dim=1)] = labels_offset.unsqueeze(1).cpu() + 2

        # tensors_io.save_subplots(embedded_volume.unsqueeze(0) * 0, (space_clusters).long().unsqueeze(0).unsqueeze(0), torch.max((embedded_volume_marks).unsqueeze(0).cpu(), (0 * space_clusters).unsqueeze(0).unsqueeze(0).cpu()), t_mask.unsqueeze(0), (embedded_volume_labels).unsqueeze(0), "debug_cluster/side_to_side")
        tensors_io.save_subplots_6(
            data_image.unsqueeze(0),
            (space_clusters).long().unsqueeze(0).unsqueeze(0),
            (embedded_volume_marks).unsqueeze(0).cpu(),
            space_clusters_MARKS.long().unsqueeze(0).unsqueeze(0),
            t_mask.unsqueeze(0), (embedded_volume_labels).unsqueeze(0),
            embedded_volume_DBSCAN.unsqueeze(0), "debug_cluster/side_to_side")
        V = evaluate_iou_volume(embedded_volume_labels, t_mask.cpu())
        V = torch.from_numpy(V).unsqueeze(0)
        tensors_io.save_volume_h5(t_mask.cpu(),
                                  name='mask',
                                  directory='debug_cluster/h5')
        tensors_io.save_volume_h5(space_clusters.cpu(),
                                  name='space_clusters',
                                  directory='debug_cluster/h5')
        tensors_io.save_volume_h5(embedded_volume_marks,
                                  name='marks',
                                  directory='debug_cluster/h5')
        tensors_io.save_volume_h5(embedded_volume_labels,
                                  name='labels',
                                  directory='debug_cluster/h5')
        tensors_io.save_volume_h5(final_pred[0, ...].cpu(),
                                  name='seg_only',
                                  directory='debug_cluster/h5')
        tensors_io.save_volume_h5(counter_image.cpu(),
                                  name='counter_image',
                                  directory='debug_cluster/h5')
    return labels_array, embedded_volume_marks


# NOTE: save_quick_results reads `self.dataset_version`, `self.network_string`
# etc., so it must be defined on a class. The placeholder class below only
# keeps this file valid Python; in the repository the method presumably lives
# on the parameters/trainer object.
class _ResultsSaver:
    def save_quick_results(self,
                           mini_volume,
                           mini_seg,
                           mini_inst,
                           mini_gt,
                           seg_eval,
                           inst_eval,
                           masks=None,
                           final_clusters=None,
                           results_directory="results",
                           dataset_name="mini"):
        if (self.not_so_big):
            self.dataset_version = 4
        if (self.debug_cluster_unet_double):
            if (self.network_number == 4):
                self.network_string = "METHOD"
            else:
                self.network_string = "METHOD_double"

            if (self.train_dataset_number > 0):
                self.network_string = self.network_string + "_t_dt_" + str(
                    self.train_dataset_number)
        path = results_directory + "/" + self.dataset_name + "/" + "v_" + str(
            self.dataset_version) + "/" + self.network_string
        os.makedirs(path, exist_ok=True)

        # mini_gt = 0 * mini_gt
        tensors_io.save_volume_h5(mini_inst[0, 0, ...].cpu(),
                                  name='fibers_only',
                                  directory=path + "/h5_files")
        tensors_io.save_subvolume_instances(
            mini_volume, mini_inst, path + "/" + dataset_name + "_inst")
        tensors_io.save_subvolume(
            mini_volume, results_directory + "/" + self.dataset_name + "/" +
            "v_" + str(self.dataset_version) + "/original")
        tensors_io.save_subvolume_instances(
            mini_volume, mini_gt, results_directory + "/" + self.dataset_name +
            "/" + "v_" + str(self.dataset_version) + "/gt")
        tensors_io.save_subvolume_instances(
            mini_volume * 0, mini_gt,
            results_directory + "/" + self.dataset_name + "/" + "v_" +
            str(self.dataset_version) + "/gt_only")

        if (self.save_side):
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_gt * 0,
                results_directory + "/" + self.dataset_name + "/" + "v_" +
                str(self.dataset_version) + "/original_side1",
                top=1)
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_gt * 0,
                results_directory + "/" + self.dataset_name + "/" + "v_" +
                str(self.dataset_version) + "/original_side2",
                top=2)
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_gt,
                results_directory + "/" + self.dataset_name + "/" + "v_" +
                str(self.dataset_version) + "/gt_side1",
                top=1)
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_gt,
                results_directory + "/" + self.dataset_name + "/" + "v_" +
                str(self.dataset_version) + "/gt_side2",
                top=2)

        tensors_io.save_subvolume_instances(mini_volume,
                                            mini_seg.cpu().long(),
                                            path + "/" + dataset_name + "_seg")
        tensors_io.save_subvolume_instances(
            mini_volume * 0, mini_seg.cpu(),
            path + "/" + dataset_name + "_seg_only")
        tensors_io.save_subvolume_instances(
            mini_volume * 0, mini_inst,
            path + "/" + dataset_name + "_inst_only")
        if (final_clusters is not None):
            tensors_io.save_subvolume_instances(
                mini_volume, final_clusters,
                path + "/" + dataset_name + "_clusters")
            tensors_io.save_subvolume_instances(
                mini_volume * 0, final_clusters,
                path + "/" + dataset_name + "_clusters_only")
            tensors_io.save_volume_h5((mini_volume[0, 0, ...] * 255).cpu(),
                                      name='og_im',
                                      directory=path + "/h5_files")
            tensors_io.save_volume_h5(final_clusters[0, 0, ...].cpu(),
                                      name='cluster_only',
                                      directory=path + "/h5_files")
            tensors_io.save_volume_h5(mini_gt[0, 0, ...].cpu(),
                                      name='mask_only',
                                      directory=path + "/h5_files")

        # Truncate the three evaluation metrics to 3 decimal places
        seg_results = []
        inst_results = []
        for i in range(3):
            seg_results.append(int(seg_eval[i] * 1000) / 1000.0)
            inst_results.append(int(inst_eval[i] * 1000) / 1000.0)

        with open(path + "/results.txt", "w") as results_file:
            results_file.write("{},{},{}\n".format(seg_results[0],
                                                   seg_results[1],
                                                   seg_results[2]))
            results_file.write("{},{},{}\n".format(inst_results[0],
                                                   inst_results[1],
                                                   inst_results[2]))

        if (self.save_side):
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_inst,
                path + "/" + dataset_name + "_inst_side1",
                top=1)
            tensors_io.save_subvolume_instances_side(
                mini_volume,
                mini_inst,
                path + "/" + dataset_name + "_inst_side2",
                top=2)
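

# Example usage (sketch): assuming a params-style object exposing the
# attributes read above (dataset_name, dataset_version, network_string,
# cube_size, eps_param, mpp_min_r, ...), a typical debugging run might be:
#
#   show_side_results("my_dataset", dataset_version=0)
#   recrop(params_t, resample_masks=2, factor=2)
#   labels, marks = clustering_algorithm(offsets, pred, params_t=params_t)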