Example #1
def custom_2head_dataloader(config):
  '''My custom dataloader for loading custom images/data for unsupervised clustering.'''

  greyscale = True  # for my custom data
  train_data_path = os.path.join(config.dataset_root, "train")
  test_val_data_path = os.path.join(config.dataset_root, "none")
  test_data_path = os.path.join(config.dataset_root, "none")
  assert (config.batchnorm_track)  # recommended (for test-time invariance to batch size)

  # Transforms:
  if greyscale:
    tf1, tf2, tf3 = greyscale_make_transforms(config)
  else:
    tf1, tf2, tf3 = sobel_make_transforms(config)
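
For reference, a minimal sketch of how the three returned transforms are typically consumed (assumptions: tf1 is the base training transform, tf2 its randomly augmented partner, tf3 the deterministic evaluation transform; the pipelines below are placeholders, not the ones the *_make_transforms() helpers actually build):

from PIL import Image
import torchvision.transforms as T

# Placeholder pipelines standing in for tf1/tf2/tf3.
tf1 = T.Compose([T.Resize(64), T.CenterCrop(64), T.ToTensor()])
tf2 = T.Compose([T.Resize(70), T.RandomCrop(64),
                 T.ColorJitter(0.4, 0.4, 0.4), T.ToTensor()])
tf3 = T.Compose([T.Resize(64), T.CenterCrop(64), T.ToTensor()])

img = Image.new("RGB", (96, 96))  # dummy image
x1, x2 = tf1(img), tf2(img)       # paired views fed to the IID objective
x_eval = tf3(img)                 # deterministic view used at test time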
Example #2
def cluster_twohead_create_YT_BB_dataloaders(config):
  assert (config.mode == "IID")
  assert (config.twohead)

  if config.dataset == "YT_BB":
    config.train_partitions_head_A = config.train_partition
    config.train_partitions_head_B = config.train_partitions_head_A

    config.mapping_assignment_partitions = config.assignment_partition
    config.mapping_test_partitions = config.test_partition

    dataset_class = YT_BB  # TODO: custom YT_BB class (see the skeleton after this example)

    # datasets produce either 2 or 5 channel images based on config.include_rgb
    tf1, tf2, tf3 = sobel_make_transforms(config)
  else:
    assert (False)

  print("Making datasets with YT_BB")
  sys.stdout.flush()

  dataloaders_head_A = \
    _create_dataloaders(config, dataset_class, tf1, tf2,
                        partition=config.train_partitions_head_A)

  dataloaders_head_B = \
    _create_dataloaders(config, dataset_class, tf1, tf2,
                        partition=config.train_partitions_head_B)

  mapping_assignment_dataloader = \
    _create_mapping_loader(config, dataset_class, tf3,
                           partition=config.mapping_assignment_partitions)

  mapping_test_dataloader = \
    _create_mapping_loader(config, dataset_class, tf3,
                           partition=config.mapping_test_partitions)

  return dataloaders_head_A, dataloaders_head_B, \
         mapping_assignment_dataloader, mapping_test_dataloader
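
The YT_BB class above is left as a TODO in the source. Below is a minimal skeleton of what such a torch.utils.data.Dataset subclass could look like; this is hypothetical, since the real class, its constructor signature, and its frame layout are not shown in this snippet:

from PIL import Image
from torch.utils.data import Dataset

class YT_BB(Dataset):
  # Hypothetical skeleton: serves (frame_path, label) pairs for one partition.
  def __init__(self, root, partition, transform=None, target_transform=None):
    self.transform = transform
    self.target_transform = target_transform
    self.samples = []  # to be filled by scanning root for the given partition

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx):
    path, target = self.samples[idx]
    img = Image.open(path).convert("RGB")
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      target = self.target_transform(target)
    return img, target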
Example #3
def cluster_twohead_create_dataloaders(config):
    assert (config.mode == "IID")
    assert (config.twohead)

    target_transform = None

    if "CIFAR" in config.dataset:
        config.train_partitions_head_A = [True, False]
        config.train_partitions_head_B = config.train_partitions_head_A

        config.mapping_assignment_partitions = [True, False]
        config.mapping_test_partitions = [True, False]

        if config.dataset == "CIFAR10":
            dataset_class = torchvision.datasets.CIFAR10
        elif config.dataset == "CIFAR100":
            dataset_class = torchvision.datasets.CIFAR100
        elif config.dataset == "CIFAR20":
            dataset_class = torchvision.datasets.CIFAR100
            target_transform = _cifar100_to_cifar20
        else:
            assert (False)

        # datasets produce either 2 or 5 channel images based on config.include_rgb
        tf1, tf2, tf3 = sobel_make_transforms(config)

    elif config.dataset == "STL10":
        assert (config.mix_train)
        if not config.stl_leave_out_unlabelled:
            print("adding unlabelled data for STL10")
            config.train_partitions_head_A = ["train+unlabeled", "test"]
        else:
            print("not using unlabelled data for STL10")
            config.train_partitions_head_A = ["train", "test"]

        config.train_partitions_head_B = ["train", "test"]

        config.mapping_assignment_partitions = ["train", "test"]
        config.mapping_test_partitions = ["train", "test"]

        dataset_class = torchvision.datasets.STL10

        # datasets produce either 2 or 5 channel images based on config.include_rgb
        tf1, tf2, tf3 = sobel_make_transforms(config)

    elif config.dataset == "MNIST":
        config.train_partitions_head_A = [True, False]
        config.train_partitions_head_B = config.train_partitions_head_A

        config.mapping_assignment_partitions = [True, False]
        config.mapping_test_partitions = [True, False]

        dataset_class = torchvision.datasets.MNIST

        tf1, tf2, tf3 = greyscale_make_transforms(config)

    else:
        assert (False)

    print("Making datasets with %s and %s" % (dataset_class, target_transform))
    sys.stdout.flush()

    dataloaders_head_A = \
      _create_dataloaders(config, dataset_class, tf1, tf2,
                          partitions=config.train_partitions_head_A,
                          target_transform=target_transform)

    dataloaders_head_B = \
      _create_dataloaders(config, dataset_class, tf1, tf2,
                          partitions=config.train_partitions_head_B,
                          target_transform=target_transform)

    mapping_assignment_dataloader = \
      _create_mapping_loader(config, dataset_class, tf3,
                             partitions=config.mapping_assignment_partitions,
                             target_transform=target_transform)

    mapping_test_dataloader = \
      _create_mapping_loader(config, dataset_class, tf3,
                             partitions=config.mapping_test_partitions,
                             target_transform=target_transform)

    return dataloaders_head_A, dataloaders_head_B, \
           mapping_assignment_dataloader, mapping_test_dataloader
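
For context, _cifar100_to_cifar20 maps the 100 fine CIFAR-100 labels onto their 20 coarse superclasses, which is what makes the "CIFAR20" setting work with the CIFAR100 dataset class. An illustrative stub follows; the entries are placeholders, not the true mapping, which carries a full 100-entry lookup:

# Illustrative stub only; values below are placeholders.
_FINE_TO_COARSE = {0: 4, 1: 1, 2: 14}  # the real table covers all 100 fine labels

def cifar100_to_cifar20(fine_label):
    return _FINE_TO_COARSE[fine_label]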
Example #4
def cluster_create_dataloaders(config):
    assert (config.mode == "IID+")
    assert (not config.twohead)

    target_transform = None

    # separate train/test sets
    if "CIFAR" in config.dataset:
        config.train_partitions = [True]
        config.mapping_assignment_partitions = [True]
        config.mapping_test_partitions = [False]

        if config.dataset == "CIFAR10":
            dataset_class = torchvision.datasets.CIFAR10
        elif config.dataset == "CIFAR100":
            dataset_class = torchvision.datasets.CIFAR100
        elif config.dataset == "CIFAR20":
            dataset_class = torchvision.datasets.CIFAR100
            target_transform = _cifar100_to_cifar20
        else:
            assert (False)

        # datasets produce either 2 or 5 channel images based on config.include_rgb
        tf1, tf2, tf3 = sobel_make_transforms(config)

    elif config.dataset == "STL10":
        config.train_partitions = ["train+unlabeled"]
        config.mapping_assignment_partitions = ["train"]
        config.mapping_test_partitions = ["test"]

        dataset_class = torchvision.datasets.STL10

        # datasets produce either 2 or 5 channel images based on config.include_rgb
        tf1, tf2, tf3 = sobel_make_transforms(config)

    elif config.dataset == "MNIST":
        config.train_partitions = [True]
        config.mapping_assignment_partitions = [True]
        config.mapping_test_partitions = [False]

        dataset_class = torchvision.datasets.MNIST

        tf1, tf2, tf3 = greyscale_make_transforms(config)

    else:
        assert (False)

    print("Making datasets with %s and %s" % (dataset_class, target_transform))
    sys.stdout.flush()

    dataloaders = \
      _create_dataloaders(config, dataset_class, tf1, tf2,
                          partitions=config.train_partitions,
                          target_transform=target_transform)

    mapping_assignment_dataloader = \
      _create_mapping_loader(config, dataset_class, tf3,
                             partitions=config.mapping_assignment_partitions,
                             target_transform=target_transform)

    mapping_test_dataloader = \
      _create_mapping_loader(config, dataset_class, tf3,
                             partitions=config.mapping_test_partitions,
                             target_transform=target_transform)

    return dataloaders, mapping_assignment_dataloader, mapping_test_dataloader
Example #5
assert (config.train_partitions == ["train+unlabeled"])
assert (config.mapping_assignment_partitions == ["train"])
assert (config.mapping_test_partitions == ["test"])

# append to old results
if not hasattr(config, "assign_set_szs_pc_acc") or given_config.rewrite:
    print("resetting config.assign_set_szs_pc_acc to empty")
    config.assign_set_szs_pc_acc = {}

for pc in new_assign_set_szs_pc:
    print("doing %f" % pc)
    sys.stdout.flush()

    # datasets produce either 2 or 5 channel images based on config.include_rgb
    tf1, tf2, tf3 = sobel_make_transforms(config,
                                          cutout=config.cutout,
                                          cutout_p=config.cutout_p,
                                          cutout_max_box=config.cutout_max_box)

    if config.dataset == "STL10":
        dataloaders, mapping_assignment_dataloader, mapping_test_dataloader = \
          make_STL_data(config, tf1, tf2, tf3, truncate_assign=True, truncate_pc=pc)
    elif (config.dataset == "CIFAR10") or (config.dataset == "CIFAR100") or \
      (config.dataset == "CIFAR20"):
        dataloaders, mapping_assignment_dataloader, mapping_test_dataloader = \
          make_CIFAR_data(config, tf1, tf2, tf3, truncate_assign=True,
                          truncate_pc=pc)
    else:
        print(config.dataset)
        assert (False)

    num_train_batches = len(dataloaders[0])
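
The truncate_pc argument above shrinks the assignment set to a fraction pc of its original size. A sketch of one way such a truncation could be applied (the actual make_STL_data/make_CIFAR_data helpers are not shown in this snippet, so this is an assumption about their behaviour):

from torch.utils.data import Subset

def truncate_dataset(dataset, pc):
    # keep only the first pc fraction of the samples
    n = int(len(dataset) * pc)
    return Subset(dataset, list(range(n)))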
Example #6
File: data.py Project: xshirade/IIC
def create_basic_clustering_dataloaders(config):
    """
  My original data loading code is complex to cover all my experiments. Here is a simple version.
  Use it to replace cluster_twohead_create_dataloaders() in the scripts.
  
  This uses ImageFolder but you could use your own subclass of torch.utils.data.Dataset.
  (ImageFolder data is not shuffled so an ideally deterministic random sampler is needed.)
  
  :param config: Requires num_dataloaders and values used by *make_transforms(), e.g. crop size, 
  input size etc.
  :return: Training and testing dataloaders
  """

    # Change these according to your data:
    greyscale = False
    train_data_path = os.path.join(config.dataset_root, "train")
    test_val_data_path = os.path.join(config.dataset_root, "none")
    test_data_path = os.path.join(config.dataset_root, "none")
    assert (config.batchnorm_track)  # recommended (for test-time invariance to batch size)

    # Transforms:
    if greyscale:
        tf1, tf2, tf3 = greyscale_make_transforms(config)
    else:
        tf1, tf2, tf3 = sobel_make_transforms(config)

    # Training data:
    # main output head (B), auxiliary overclustering head (A), same data for both
    dataloaders_head_B = \
      [torch.utils.data.DataLoader(
         torchvision.datasets.ImageFolder(root=train_data_path, transform=tf1),
         batch_size=config.dataloader_batch_sz,
         shuffle=False,
         sampler=DeterministicRandomSampler(),
         num_workers=0,
         drop_last=False)] + \
      [torch.utils.data.DataLoader(
         torchvision.datasets.ImageFolder(root=train_data_path, transform=tf2),
         batch_size=config.dataloader_batch_sz,
         shuffle=False,
         sampler=DeterministicRandomSampler(),
         num_workers=0,
         drop_last=False) for _ in range(config.num_dataloaders)]

    dataloaders_head_A = \
      [torch.utils.data.DataLoader(
         torchvision.datasets.ImageFolder(root=train_data_path, transform=tf1),
         batch_size=config.dataloader_batch_sz,
         shuffle=False,
         sampler=DeterministicRandomSampler(),
         num_workers=0,
         drop_last=False)] + \
      [torch.utils.data.DataLoader(
         torchvision.datasets.ImageFolder(root=train_data_path, transform=tf2),
         batch_size=config.dataloader_batch_sz,
         shuffle=False,
         sampler=DeterministicRandomSampler(),
         num_workers=0,
         drop_last=False) for _ in range(config.num_dataloaders)]

    # Testing data (labelled):
    mapping_assignment_dataloader, mapping_test_dataloader = None, None
    if os.path.exists(test_data_path):
        mapping_assignment_dataloader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(test_val_data_path,
                                             transform=tf3),
            batch_size=config.batch_sz,
            shuffle=False,
            sampler=DeterministicRandomSampler(),
            num_workers=0,
            drop_last=False)

        mapping_test_dataloader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(test_data_path, transform=tf3),
            batch_size=config.batch_sz,
            shuffle=False,
            sampler=DeterministicRandomSampler(),
            num_workers=0,
            drop_last=False)

    return dataloaders_head_A, dataloaders_head_B, \
           mapping_assignment_dataloader, mapping_test_dataloader
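
DeterministicRandomSampler is not defined in this snippet. A minimal sketch of what it could look like follows; note this is an assumption, since the call sites above pass no arguments and the real class must therefore bind its length some other way. The key property is that it replays the same pseudo-random permutation every epoch, so the tf1 and tf2 dataloaders traverse the images in identical order:

import torch
from torch.utils.data.sampler import Sampler

class DeterministicRandomSampler(Sampler):
    # Yields the same pseudo-random permutation on every iteration.
    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        return iter(torch.randperm(len(self.data_source), generator=g).tolist())

    def __len__(self):
        return len(self.data_source)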
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ind", type=int, required=True)

    parser.add_argument("--arch", type=str, required=True)

    parser.add_argument("--head_lr", type=float, required=True)
    parser.add_argument("--trunk_lr", type=float, required=True)

    parser.add_argument("--num_epochs", type=int, default=3200)

    parser.add_argument("--new_batch_sz", type=int, default=-1)

    parser.add_argument("--old_model_ind", type=int, required=True)

    parser.add_argument("--penultimate_features",
                        default=False,
                        action="store_true")

    parser.add_argument("--random_affine", default=False, action="store_true")
    parser.add_argument("--affine_p", type=float, default=0.5)

    parser.add_argument("--cutout", default=False, action="store_true")
    parser.add_argument("--cutout_p", type=float, default=0.5)
    parser.add_argument("--cutout_max_box", type=float, default=0.5)

    parser.add_argument("--restart", default=False, action="store_true")
    parser.add_argument("--lr_schedule", type=int, nargs="+", default=[])
    parser.add_argument("--lr_mult", type=float, default=0.5)

    parser.add_argument("--restart_new_model_ind",
                        default=False,
                        action="store_true")
    parser.add_argument("--new_model_ind", type=int, default=0)

    parser.add_argument("--out_root",
                        type=str,
                        default="/scratch/shared/slow/xuji/iid_private")
    config = parser.parse_args()  # new config

    # Setup ----------------------------------------------------------------------

    config.contiguous_sz = 10  # Tencrop
    config.out_dir = os.path.join(config.out_root, str(config.model_ind))

    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)

    if config.restart:
        given_config = config
        reloaded_config_path = os.path.join(given_config.out_dir,
                                            "config.pickle")
        print("Loading restarting config from: %s" % reloaded_config_path)
        with open(reloaded_config_path, "rb") as config_f:
            config = pickle.load(config_f)
        assert (config.model_ind == given_config.model_ind)

        config.restart = True
        config.num_epochs = given_config.num_epochs  # train for longer

        config.restart_new_model_ind = given_config.restart_new_model_ind
        config.new_model_ind = given_config.new_model_ind

        start_epoch = config.last_epoch + 1

        print("...restarting from epoch %d" % start_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:start_epoch]
        config.epoch_loss = config.epoch_loss[:start_epoch]

    else:
        config.epoch_acc = []
        config.epoch_loss = []
        start_epoch = 0

    # old config only used retrospectively for setting up model at start
    reloaded_config_path = os.path.join(
        os.path.join(config.out_root, str(config.old_model_ind)),
        "config.pickle")
    print("Loading old features config from: %s" % reloaded_config_path)
    with open(reloaded_config_path, "rb") as config_f:
        old_config = pickle.load(config_f)
        assert (old_config.model_ind == config.old_model_ind)

    if config.new_batch_sz == -1:
        config.new_batch_sz = old_config.batch_sz

    fig, axarr = plt.subplots(2, sharex=False, figsize=(20, 20))

    # Data -----------------------------------------------------------------------

    assert (old_config.dataset == "STL10")

    # make supervised data: train on train, test on test, unlabelled is unused
    tf1, tf2, tf3 = sobel_make_transforms(old_config,
                                          random_affine=config.random_affine,
                                          cutout=config.cutout,
                                          cutout_p=config.cutout_p,
                                          cutout_max_box=config.cutout_max_box,
                                          affine_p=config.affine_p)

    dataset_class = torchvision.datasets.STL10
    train_data = dataset_class(
        root=old_config.dataset_root,
        transform=tf2,  # also could use tf1
        split="train")

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.new_batch_sz,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=False)

    test_data = dataset_class(root=old_config.dataset_root,
                              transform=None,
                              split="test")
    test_data = TenCropAndFinish(test_data,
                                 input_sz=old_config.input_sz,
                                 include_rgb=old_config.include_rgb)

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config.new_batch_sz,
        # full batch
        shuffle=False,
        num_workers=0,
        drop_last=False)

    # Model ----------------------------------------------------------------------

    net_features = archs.__dict__[old_config.arch](old_config)

    if not config.restart:
        model_path = os.path.join(old_config.out_dir, "best_net.pytorch")
        net_features.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    dlen = get_dlen(net_features,
                    train_loader,
                    include_rgb=old_config.include_rgb,
                    penultimate_features=config.penultimate_features)
    print("dlen: %d" % dlen)

    assert (config.arch == "SupHead5")
    net = SupHead5(net_features, dlen=dlen, gt_k=old_config.gt_k)

    if config.restart:
        print("restarting from latest net")
        model_path = os.path.join(config.out_dir, "latest_net.pytorch")
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    net.cuda()
    net = torch.nn.DataParallel(net)

    opt_trunk = torch.optim.Adam(net.module.trunk.parameters(),
                                 lr=config.trunk_lr)
    opt_head = torch.optim.Adam(net.module.head.parameters(),
                                lr=config.head_lr)

    if config.restart:
        print("restarting from latest optimiser")
        optimiser_states = torch.load(
            os.path.join(config.out_dir, "latest_optimiser.pytorch"))
        opt_trunk.load_state_dict(optimiser_states["opt_trunk"])
        opt_head.load_state_dict(optimiser_states["opt_head"])
    else:
        print("using new optimiser state")

    criterion = nn.CrossEntropyLoss().cuda()

    if not config.restart:
        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print("pre: model %d old model %d, acc %f time %s" %
              (config.model_ind, config.old_model_ind, acc, datetime.now()))
        sys.stdout.flush()

        config.epoch_acc.append(acc)

    if config.restart_new_model_ind:
        assert (config.restart)
        config.model_ind = config.new_model_ind  # old_model_ind stays same
        config.out_dir = os.path.join(config.out_root, str(config.model_ind))
        print("restarting as model %d" % config.model_ind)

        if not os.path.exists(config.out_dir):
            os.makedirs(config.out_dir)

    # Train ----------------------------------------------------------------------

    for e_i in range(start_epoch, config.num_epochs):
        net.train()

        if e_i in config.lr_schedule:
            print("e_i %d, multiplying lr for opt trunk and head by %f" %
                  (e_i, config.lr_mult))
            opt_trunk = update_lr(opt_trunk, lr_mult=config.lr_mult)
            opt_head = update_lr(opt_head, lr_mult=config.lr_mult)
            if not hasattr(config, "lr_changes"):
                config.lr_changes = []
            config.lr_changes.append((e_i, config.lr_mult))

        avg_loss = 0.
        num_batches = len(train_loader)
        for i, (imgs, targets) in enumerate(train_loader):
            imgs = sobel_process(imgs.cuda(), old_config.include_rgb)
            targets = targets.cuda()

            x_out = net(imgs, penultimate_features=config.penultimate_features)
            loss = criterion(x_out, targets)

            avg_loss += float(loss.data)

            opt_trunk.zero_grad()
            opt_head.zero_grad()

            loss.backward()

            opt_trunk.step()
            opt_head.step()

            if (i % 100 == 0) or (e_i == start_epoch):
                print("batch %d of %d, loss %f, time %s" %
                      (i, num_batches, float(loss.data), datetime.now()))
                sys.stdout.flush()

        avg_loss /= num_batches

        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print(
            "model %d old model %d epoch %d acc %f time %s" %
            (config.model_ind, config.old_model_ind, e_i, acc, datetime.now()))
        sys.stdout.flush()

        is_best = False
        if acc > max(config.epoch_acc):
            is_best = True

        config.epoch_acc.append(acc)
        config.epoch_loss.append(avg_loss)

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("Acc")

        axarr[1].clear()
        axarr[1].plot(config.epoch_loss)
        axarr[1].set_title("Loss")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % 10 == 0):
            net.module.cpu()

            if is_best:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir, "best_optimiser.pytorch"))

            # save model sparingly for this script
            if e_i % 10 == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir,
                                    "latest_optimiser.pytorch"))

            net.module.cuda()

            config.last_epoch = e_i  # for last saved version

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'w') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)
Example #8
def make_triplets_data(config):
  target_transform = None

  if "CIFAR" in config.dataset:
    config.train_partitions_head_A = [True, False]
    config.train_partitions_head_B = config.train_partitions_head_A

    config.mapping_assignment_partitions = [True, False]
    config.mapping_test_partitions = [True, False]

    if config.dataset == "CIFAR10":
      dataset_class = torchvision.datasets.CIFAR10
    elif config.dataset == "CIFAR100":
      dataset_class = torchvision.datasets.CIFAR100
    elif config.dataset == "CIFAR20":
      dataset_class = torchvision.datasets.CIFAR100
      target_transform = _cifar100_to_cifar20
    else:
      assert (False)

    # datasets produce either 2 or 5 channel images based on config.include_rgb
    tf1, tf2, tf3 = sobel_make_transforms(config)

  elif config.dataset == "STL10":
    assert (config.mix_train)
    if not config.stl_leave_out_unlabelled:
      print("adding unlabelled data for STL10")
      config.train_partitions_head_A = ["train+unlabeled", "test"]
    else:
      print("not using unlabelled data for STL10")
      config.train_partitions_head_A = ["train", "test"]

    config.train_partitions_head_B = ["train", "test"]

    config.mapping_assignment_partitions = ["train", "test"]
    config.mapping_test_partitions = ["train", "test"]

    dataset_class = torchvision.datasets.STL10

    # datasets produce either 2 or 5 channel images based on config.include_rgb
    tf1, tf2, tf3 = sobel_make_transforms(config)

  elif config.dataset == "MNIST":
    config.train_partitions_head_A = [True, False]
    config.train_partitions_head_B = config.train_partitions_head_A

    config.mapping_assignment_partitions = [True, False]
    config.mapping_test_partitions = [True, False]

    dataset_class = torchvision.datasets.MNIST

    tf1, tf2, tf3 = greyscale_make_transforms(config)

  else:
    assert (False)

  dataloaders = \
    _create_dataloaders(config, dataset_class, tf1, tf2,
                        partitions=config.train_partitions_head_A,
                        target_transform=target_transform)

  dataloader_original = dataloaders[0]
  dataloader_positive = dataloaders[1]

  shuffled_dataloaders = \
    _create_dataloaders(config, dataset_class, tf1, tf2,
                        partitions=config.train_partitions_head_A,
                        target_transform=target_transform,
                        shuffle=True)

  dataloader_negative = shuffled_dataloaders[0]

  # since this is fully unsupervised, assign dataloader = test dataloader
  dataloader_test = \
    _create_mapping_loader(config, dataset_class, tf3,
                           partitions=config.mapping_test_partitions,
                           target_transform=target_transform)

  return dataloader_original, dataloader_positive, dataloader_negative, \
         dataloader_test
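
A sketch of how the returned loaders might be consumed with a triplet margin loss; this is illustrative only, with net standing in for a hypothetical embedding network, and it is not the IIC objective itself:

import torch.nn.functional as F

def triplet_epoch(net, loader_original, loader_positive, loader_negative,
                  margin=1.0):
  for (anc, _), (pos, _), (neg, _) in zip(loader_original,
                                          loader_positive,
                                          loader_negative):
    loss = F.triplet_margin_loss(net(anc), net(pos), net(neg), margin=margin)
    # loss.backward() and an optimiser step would follow in a real loop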
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", type=str, default="./out")

    given_config = parser.parse_args()

    given_config.out_dir = given_config.save_dir

    reloaded_config_path = os.path.join(given_config.out_dir, "config.pickle")
    print("Loading restarting config from: %s" % reloaded_config_path)
    with open(reloaded_config_path, "rb") as config_f:
        config = pickle.load(config_f)
        print(config)

    if not hasattr(config, "twohead"):
        config.twohead = ("TwoHead" in config.arch)

    config.double_eval = False  # no double eval, not training (or saving config)

    net = archs.__dict__[config.arch](config)
    model_path = os.path.join(given_config.out_dir, "latest_net.pytorch")
    net.load_state_dict(
        torch.load(model_path, map_location=lambda storage, loc: storage))
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.eval()

    # print(net)
    tf1, tf2, tf3 = sobel_make_transforms(config)
    # Pass each image in net to get cluster prediction
    # print(tf1)
    # print(tf2)
    # print(tf3)

    # dataset = torchvision.datasets.CIFAR10(config.dataset_root, train=True, transform=tf3)
    dataset = ImageNetDS(config.dataset_root, 32, train=True, transform=tf3)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=config.batch_sz,
                                             shuffle=False)

    flat_predss_all, flat_targets_all = _clustering_get_data(config,
                                                             net,
                                                             dataloader,
                                                             sobel=True,
                                                             using_IR=False,
                                                             verbose=True)
    print(len(flat_predss_all), flat_predss_all[0].shape)
    print("flat_targets_all.shape", flat_targets_all.shape)
    print(config.num_sub_heads)

    # visualize each cluster
    # view_dataset = torchvision.datasets.CIFAR10(config.dataset_root, train=True,
    #                                             transform=torchvision.transforms.ToTensor())
    view_dataset = ImageNetDS(config.dataset_root,
                              32,
                              train=True,
                              transform=torchvision.transforms.ToTensor())

    # cluster_labels = flat_predss_all[0].cpu().numpy()
    actual_labels = flat_targets_all.cpu().numpy()

    # each sub-head is a separate clustering prediction
    for head in range(config.num_sub_heads):

        cluster_labels = flat_predss_all[head].cpu().numpy()

        for c in range(config.output_k_B):
            cluster_indices = np.where(cluster_labels == c)[0]
            # gt_indices = np.where(actual_labels == c)[0]
            print("HEAD {}: cluster {} have {} images".format(
                head, c, len(cluster_indices)))
            c_dataloader = torch.utils.data.DataLoader(
                view_dataset,
                batch_size=64,
                shuffle=False,
                sampler=SubsetRandomSampler(cluster_indices))
            # gt_dataloader = torch.utils.data.DataLoader(view_dataset, batch_size=64, shuffle=False,
            #                                             sampler=SubsetRandomSampler(gt_indices))

            for (images, targets) in c_dataloader:
                print("saving cluster {}".format(c), images.shape)
                torchvision.utils.save_image(
                    images, 'head{}-cluster{}.png'.format(head, c))
                break

            # for (images, targets) in gt_dataloader:
            #     print("sanity check gt classes {}".format(c), images.shape)
            #     torchvision.utils.save_image(images, 'gt{}.png'.format(c))
            #     break

        # save the cluster labels as before
        save = {'label': cluster_labels}
        filename = 'iic-k10-head{}-cluster-model{}.pickle'.format(
            head, config.model_ind)
        with open(filename, 'wb') as f:
            pickle.dump(save, f, protocol=pickle.HIGHEST_PROTOCOL)
        print("saved iic unsupervised cluster to {}!".format(filename))