def triplets_eval(config, net, dataloader_test, sobel):
    net.eval()

    if not config.kmeans_on_features:
        flat_preds_all, flat_targets_all = triplets_get_data(
            config, net, dataloader_test, sobel)
        assert (config.output_k == config.gt_k)
    else:
        flat_preds_all, flat_targets_all = triplets_get_data_kmeans_on_features(
            config, net, dataloader_test, sobel)

    num_samples = flat_preds_all.shape[0]
    assert (num_samples == flat_targets_all.shape[0])

    net.train()

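    # _hungarian_match returns a one-to-one assignment between predicted
    # clusters and ground-truth classes as (pred_cluster, gt_class) pairs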
    match = _hungarian_match(flat_preds_all,
                             flat_targets_all,
                             preds_k=config.gt_k,
                             targets_k=config.gt_k)

    found = torch.zeros(config.gt_k)  # sanity
    reordered_preds = torch.zeros(num_samples,
                                  dtype=flat_preds_all.dtype).cuda()

    for pred_i, target_i in match:
        reordered_preds[flat_preds_all == pred_i] = target_i
        found[pred_i] = 1

    assert (found.sum() == config.gt_k)  # each class must get mapped

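    # track per-class pixel mass and correct-count for this evaluation; one row
    # per call is appended to config.masses / config.per_class_acc below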
    mass = np.zeros((1, config.gt_k))
    per_class_acc = np.zeros((1, config.gt_k))
    for c in range(config.gt_k):
        flags = (reordered_preds == c)
        actual = (flat_targets_all == c)
        mass[0, c] = flags.sum().item()
        per_class_acc[0, c] = (flags * actual).sum().item()

    acc = _acc(reordered_preds, flat_targets_all, config.gt_k)

    is_best = (len(config.epoch_acc) > 0) and (acc > max(config.epoch_acc))
    config.epoch_acc.append(acc)

    if config.masses is None:
        assert (config.per_class_acc is None)
        config.masses = mass
        config.per_class_acc = per_class_acc
    else:
        config.masses = np.concatenate((config.masses, mass), axis=0)
        config.per_class_acc = np.concatenate(
            (config.per_class_acc, per_class_acc), axis=0)

    return is_best
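
# Illustrative sketch (not part of the original code): the reordering step
# above, factored into a standalone helper. The name `_remap_preds` is
# hypothetical; it assumes `match` is an iterable of (pred_cluster, gt_class)
# pairs as returned by _hungarian_match.
def _remap_preds(flat_preds, match):
    remapped = torch.zeros_like(flat_preds)
    for pred_i, target_i in match:
        # relabel every sample predicted as pred_i with its matched gt class
        remapped[flat_preds == pred_i] = target_i
    return remapped
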
def apply_trained_kmeans(config, net, test_dataloader, kmeans):
    if config.verbose:
        print("starting inference")
        sys.stdout.flush()

    # on the entire test dataset
    num_imgs = len(test_dataloader.dataset)
    max_num_samples = num_imgs * config.input_sz * config.input_sz
    preds_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()
    targets_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()
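    # buffers are preallocated for the maximum possible size (every pixel
    # unmasked) and trimmed to actual_num_unmasked after the loop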

    actual_num_unmasked = 0

    # iterate over the test data; labels are kept only for evaluation
    for i, tup in enumerate(test_dataloader):
        if (config.verbose and i < 10) or (i % max(1, len(test_dataloader) // 10) == 0):
            print("(apply_trained_kmeans) batch %d time %s" %
                  (i, datetime.now()))
            sys.stdout.flush()

        imgs, targets, mask = tup  # test dataloader, cpu tensors
        imgs, mask_cuda, targets = imgs.cuda(), mask.cuda(), targets.cuda()
        mask_np = mask.numpy().astype(bool)
        num_unmasked = mask_cuda.sum().item()

        if not config.no_sobel:
            imgs = sobel_process(imgs, config.include_rgb,
                                 using_IR=config.using_IR)
            # now rgb(ir) and/or sobel

        with torch.no_grad():
            # penultimate = features
            x_out = net(imgs, penultimate=True).cpu().numpy()

        x_out = x_out.transpose((0, 2, 3, 1))  # features last
        x_out = x_out[mask_np, :]
        targets = targets.masked_select(mask_cuda)  # can do because flat

        assert (x_out.shape == (num_unmasked, net.module.features_sz))
        preds = torch.from_numpy(kmeans.predict(x_out)).cuda()

        preds_all[actual_num_unmasked: actual_num_unmasked +
                  num_unmasked] = preds
        targets_all[
            actual_num_unmasked: actual_num_unmasked + num_unmasked] = targets

        actual_num_unmasked += num_unmasked

    preds_all = preds_all[:actual_num_unmasked]
    targets_all = targets_all[:actual_num_unmasked]

    torch.cuda.empty_cache()

    # permutation, not many-to-one
    match = _hungarian_match(preds_all, targets_all, preds_k=config.gt_k,
                             targets_k=config.gt_k)
    torch.cuda.empty_cache()

    # do in cpu because of RAM
    reordered_preds = torch.zeros(actual_num_unmasked, dtype=preds_all.dtype)
    for pred_i, target_i in match:
        selected = (preds_all == pred_i).cpu()
        reordered_preds[selected] = target_i

    reordered_preds = reordered_preds.cuda()

    # this checks values
    acc = _acc(reordered_preds, targets_all,
               config.gt_k, verbose=config.verbose)

    if GET_NMI_ARI:
        nmi, ari = _nmi(reordered_preds, targets_all), \
            _ari(reordered_preds, targets_all)
    else:
        nmi, ari = -1., -1.

    reordered_masses = np.zeros(config.gt_k)
    for c in range(config.gt_k):
        reordered_masses[c] = float(
            (reordered_preds == c).sum()) / actual_num_unmasked

    return acc, nmi, ari, reordered_masses
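
# Illustrative usage sketch (not part of the original code): fit MiniBatchKMeans
# on precomputed penultimate-layer features, then score it with
# apply_trained_kmeans. `train_features` (an (N, features_sz) array) and the
# helper name `_example_features_kmeans_eval` are assumptions for illustration.
def _example_features_kmeans_eval(config, net, train_features, test_dataloader):
    from sklearn.cluster import MiniBatchKMeans
    kmeans = MiniBatchKMeans(n_clusters=config.gt_k, verbose=1).fit(train_features)
    # returns (acc, nmi, ari, reordered_masses)
    return apply_trained_kmeans(config, net, test_dataloader, kmeans)
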
def main():
    # based on segmentation_multioutput_twohead - we pass in the config of the
    # IID run we are comparing against, so the settings can be copied

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ind", type=int, required=True)
    parser.add_argument("--out_root",
                        type=str,
                        default="/scratch/shared/slow/xuji/iid_private")
    parser.add_argument("--IID_model_ind", type=int, required=True)
    parser.add_argument("--max_num_train", type=int, required=True)
    parser.add_argument("--test_code", default=False, action="store_true")
    parser.add_argument("--do_sift", default=False, action="store_true")

    config = parser.parse_args()
    config.out_dir = os.path.join(config.out_root, str(config.model_ind))
    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)

    archetype_config_path = os.path.join(config.out_root,
                                         str(config.IID_model_ind),
                                         "config.pickle")
    print("Loading archetype config from: %s" % archetype_config_path)
    with open(archetype_config_path, "rb") as config_f:
        archetype_config = pickle.load(config_f)
    assert (config.IID_model_ind == archetype_config.model_ind)
    assert (archetype_config.mode == "IID")  # compare against fully unsup

    sample_fn = _get_vectorised_colour_samples
    if config.do_sift:
        sample_fn = _get_vectorised_sift_samples

    # use RGB only (plus IR if necessary) with no Sobel - we are clustering
    # single-pixel colours
    archetype_config.include_rgb = True
    archetype_config.no_sobel = True
    if "Coco" in archetype_config.dataset:
        assert (not archetype_config.using_IR)
        archetype_config.in_channels = 3
    elif archetype_config.dataset == "Potsdam":  # IR
        assert (archetype_config.using_IR)
        archetype_config.in_channels = 4

    # Data
    # -------------------------------------------------------------------------
    if "Coco" in archetype_config.dataset:
        dataloaders_head_A, mapping_assignment_dataloader, \
        mapping_test_dataloader = \
            make_Coco_dataloaders(archetype_config)

    elif archetype_config.dataset == "Potsdam":
        dataloaders_head_A, mapping_assignment_dataloader, \
        mapping_test_dataloader = \
            make_Potsdam_dataloaders(archetype_config)
    else:
        raise NotImplementedError

    # unlike the STL clustering script, there is no data from unknown classes
    dataloaders_head_B = dataloaders_head_A

    # networks and optimisers
    # ------------------------------------------------------
    assert (archetype_config.num_dataloaders == 1)
    dataloader = dataloaders_head_B[0]

    samples = sample_fn(archetype_config, dataloader)
    print("got training samples")
    sys.stdout.flush()

    if config.test_code:
        print("testing src, taking 10000 samples only")
        samples = samples[:10000, :]
    else:
        num_samples_train = min(samples.shape[0], config.max_num_train)
        print("taking %d samples" % num_samples_train)
        chosen_inds = np.random.choice(samples.shape[0],
                                       size=num_samples_train,
                                       replace=False)
        samples = samples[chosen_inds, :]
        print(samples.shape)
    sys.stdout.flush()

    kmeans = MiniBatchKMeans(n_clusters=archetype_config.gt_k,
                             verbose=1).fit(samples)
    print("trained kmeans")
    sys.stdout.flush()

    # use the mapping-assignment data to map the output_k (= gt_k) clusters to
    # the gt_k classes, and also evaluate on its predictions, since it is
    # identical to mapping_test_dataloader
    assign_samples, assign_labels = sample_fn(archetype_config,
                                              mapping_assignment_dataloader)
    num_samples = assign_samples.shape[0]
    assign_preds = kmeans.predict(assign_samples)
    print("finished prediction for mapping assign/test data")
    sys.stdout.flush()

    assign_preds = torch.from_numpy(assign_preds).cuda()
    assign_labels = torch.from_numpy(assign_labels).cuda()

    if archetype_config.eval_mode == "hung":
        match = _hungarian_match(assign_preds,
                                 assign_labels,
                                 preds_k=archetype_config.gt_k,
                                 targets_k=archetype_config.gt_k)
    elif archetype_config.eval_mode == "orig":  # flat!
        match = _original_match(assign_preds,
                                assign_labels,
                                preds_k=archetype_config.gt_k,
                                targets_k=archetype_config.gt_k)
    elif archetype_config.eval_mode == "orig_soft":
        assert (False)  # not used

    # reorder predictions so cluster indices match the ground-truth class indices
    found = torch.zeros(archetype_config.gt_k)
    reordered_preds = torch.zeros(num_samples, dtype=torch.int32).cuda()
    for pred_i, target_i in match:
        reordered_preds[assign_preds == pred_i] = target_i
        found[pred_i] = 1
    assert (found.sum() == archetype_config.gt_k)  # each output_k must get mapped

    acc = _acc(reordered_preds, assign_labels, archetype_config.gt_k)

    print("got acc %f" % acc)
    config.epoch_acc = [acc]
    config.centroids = kmeans.cluster_centers_
    config.match = match

    # write results and centroids to model_ind output file
    with open(os.path.join(config.out_dir, "config.pickle"), "w") as outfile:
        pickle.dump(config, outfile)

    with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
        text_file.write("%s" % config)
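
# Sketch of an entry point and invocation (not in the original excerpt; the
# script name and argument values are placeholders only):
#
#   if __name__ == "__main__":
#       main()
#
#   python baseline_pixel_kmeans.py --model_ind 700 --IID_model_ind 555 \
#       --max_num_train 500000 --out_root /path/to/out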