コード例 #1
0
def moco_to_scan_state_dict(state_dict):
    """Rename MoCo v2 query-encoder weights to this model's parameter names.

    Keys under ``module.encoder_q.fc`` (the MoCo projection MLP) are mapped to
    ``module.contrastive_head.*``; every other ``module.encoder_q*`` key is
    mapped to ``module.backbone.*``. The ``module.`` prefix is kept because the
    target model is wrapped in ``torch.nn.DataParallel`` in ``main``.

    Args:
        state_dict: Mapping from MoCo parameter names to tensors.

    Returns:
        A new dict with the renamed keys and the original values.

    Raises:
        ValueError: If a key does not belong to ``module.encoder_q``.
    """
    backbone_prefix = 'module.encoder_q.'
    head_prefix = 'module.encoder_q.fc.'
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            # Copy backbone weights
            new_state_dict['module.backbone.' + k[len(backbone_prefix):]] = v
        elif k.startswith('module.encoder_q.fc'):
            # Copy mlp weights
            new_state_dict['module.contrastive_head.' + k[len(head_prefix):]] = v
        else:
            raise ValueError('Unexpected key {}'.format(k))
    return new_state_dict


def main():
    """Convert the official MoCo v2 checkpoint and mine nearest neighbours.

    Downloads the MoCo v2 (800 epoch) checkpoint, loads its query encoder into
    the model and saves the result to ``p['pretext_model']``. It then mines
    the top-50 train and top-5 validation neighbours on features taken before
    the MLP head, and saves the indices to ``p['topk_neighbors_train_path']``
    and ``p['topk_neighbors_val_path']``.

    Raises:
        RuntimeError: If the checkpoint download command exits non-zero.
    """
    # Retrieve config file
    p = create_config(args)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print(model)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset: validation transforms for both splits (features are extracted
    # without training augmentations).
    print(colored('Retrieve dataset', 'blue'))
    transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, transforms)
    val_dataset = get_val_dataset(p, transforms)
    train_dataloader = get_val_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(
        len(train_dataset), len(val_dataset)))

    # Memory Bank. NOTE(review): 2048 is assumed to be the backbone feature
    # size (e.g. ResNet-50) — confirm against get_model(p).
    print(colored('Build MemoryBank', 'blue'))
    memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'],
                                   p['temperature'])
    memory_bank_train.cuda()
    memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'],
                                 p['temperature'])
    memory_bank_val.cuda()

    # Load the official MoCoV2 checkpoint
    checkpoint_file = 'moco_v2_800ep_pretrain.pth.tar'
    print(colored('Downloading moco v2 checkpoint', 'blue'))
    exit_code = os.system(
        'wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar'
    )
    if exit_code != 0:
        raise RuntimeError(
            'Downloading the MoCo v2 checkpoint failed (exit status {})'.format(
                exit_code))
    try:
        moco_state = torch.load(checkpoint_file, map_location='cpu')
    finally:
        # Delete the download once it has been read (also when reading fails)
        # so no stale copy is left behind for a later run.
        try:
            os.remove(checkpoint_file)
        except FileNotFoundError:
            pass

    # Transfer moco weights
    print(colored('Transfer MoCo weights to model', 'blue'))
    model.load_state_dict(moco_to_scan_state_dict(moco_state['state_dict']))

    # Save final model
    print(colored('Save pretext model', 'blue'))
    torch.save(model.module.state_dict(), p['pretext_model'])
    # In this case, we mine the neighbors before the MLP.
    model.module.contrastive_head = torch.nn.Identity()

    # Mine the topk nearest neighbors (Train)
    # These will be used for training with the SCAN-Loss.
    topk = 50
    print(
        colored('Mine the nearest neighbors (Train)(Top-%d)' % (topk), 'blue'))
    fill_memory_bank(train_dataloader, model, memory_bank_train)
    indices, acc = memory_bank_train.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on train set is %.2f' %
          (topk, 100 * acc))
    np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors (Validation)
    # These will be used for validation.
    topk = 5
    print(colored('Mine the nearest neighbors (Val)(Top-%d)' % (topk), 'blue'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    print('Mine the neighbors')
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' %
          (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
コード例 #2
0
ファイル: simclr.py プロジェクト: acl21/init-pools-dal
def _simclr_mine_neighbors(dataloader, model, memory_bank, topk, split):
    """Fill ``memory_bank`` with ``model`` features and mine top-k neighbours.

    Args:
        dataloader: Loader whose samples populate the memory bank.
        model: Network passed to ``fill_memory_bank``.
        memory_bank: ``MemoryBank`` sized for ``dataloader``'s dataset.
        topk: Number of neighbours to mine per sample.
        split: Split label used in the log messages ('train' or 'val').

    Returns:
        The indices returned by ``memory_bank.mine_nearest_neighbors``.
    """
    print(colored('Fill memory bank for mining the nearest neighbors ({}) ...'.format(split), 'blue'))
    fill_memory_bank(dataloader, model, memory_bank)
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on %s set is %.2f' % (topk, split, 100 * acc))
    return indices


def main():
    """Train a SimCLR pretext model and mine nearest neighbours for SCAN.

    Resumes from ``p['pretext_checkpoint']`` when it exists and trains until
    ``p['epochs']``, checkpointing after every epoch. It then saves the model
    to ``p['pretext_model']`` and stores the top-20 train and top-5
    validation neighbour indices.
    """
    # Epochs after which neighbours are mined for monitoring only (not saved).
    intermediate_mining_epochs = (50, 75)

    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(sum(param.numel() for param in model.parameters()) / 1e6))
    print(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'blue'))
    train_transforms = get_train_transformations(p)
    print('Train transforms:', train_transforms)
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    train_dataset = get_train_dataset(p, train_transforms, to_augmented_dataset=True,
                                      split='train+unlabeled')  # Split is for stl-10
    val_dataset = get_val_dataset(p, val_transforms)
    train_dataloader = get_train_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    base_dataset = get_train_dataset(p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_base.cuda()
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_val.cuda()

    # Criterion
    print(colored('Retrieve criterion', 'blue'))
    criterion = get_criterion(p)
    print('Criterion is {}'.format(criterion.__class__.__name__))
    criterion = criterion.cuda()

    # Optimizer and scheduler
    print(colored('Retrieve optimizer', 'blue'))
    optimizer = get_optimizer(p, model)
    print(optimizer)

    # Checkpoint
    if os.path.exists(p['pretext_checkpoint']):
        print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue'))
        checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
        model.cuda()
        start_epoch = checkpoint['epoch']
    else:
        print(colored('No checkpoint file at {}'.format(p['pretext_checkpoint']), 'blue'))
        start_epoch = 0
        model = model.cuda()

    # Training
    print(colored('Starting main loop', 'blue'))
    for epoch in range(start_epoch, p['epochs']):
        print(colored('Epoch %d/%d' % (epoch, p['epochs']), 'yellow'))
        print(colored('-' * 15, 'yellow'))

        # Adjust lr
        lr = adjust_learning_rate(p, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # Train
        print('Train ...')
        simclr_train(train_dataloader, model, criterion, optimizer, epoch)

        # Fill memory bank
        print('Fill memory bank for kNN...')
        fill_memory_bank(base_dataloader, model, memory_bank_base)

        # Evaluate (To monitor progress - Not for validation)
        print('Evaluate ...')
        top1 = contrastive_evaluate(val_dataloader, model, memory_bank_base)
        print('Result of kNN evaluation is %.2f' % (top1))

        # Checkpoint: write to a temporary file and atomically swap it in, so
        # an interrupted save cannot corrupt the previous resume point.
        print('Checkpoint ...')
        tmp_checkpoint = '{}.tmp'.format(p['pretext_checkpoint'])
        torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(),
                    'epoch': epoch + 1}, tmp_checkpoint)
        os.replace(tmp_checkpoint, p['pretext_checkpoint'])

        if epoch in intermediate_mining_epochs:
            # Monitoring only: neighbour accuracy is printed, indices are discarded.
            _simclr_mine_neighbors(base_dataloader, model, memory_bank_base, 20, 'train')
            _simclr_mine_neighbors(val_dataloader, model, memory_bank_val, 5, 'val')

    # Save final model
    torch.save(model.state_dict(), p['pretext_model'])

    # Mine the topk nearest neighbors at the very end (Train)
    # These will be served as input to the SCAN loss.
    indices = _simclr_mine_neighbors(base_dataloader, model, memory_bank_base, 20, 'train')
    np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    indices = _simclr_mine_neighbors(val_dataloader, model, memory_bank_val, 5, 'val')
    np.save(p['topk_neighbors_val_path'], indices)
コード例 #3
0
def main():
    """Mine validation-set nearest neighbours with a pretext model.

    Loads ``checkpoint['model']`` from ``p['pretext_checkpoint']`` when that
    file exists. It then fills a memory bank with validation features and
    saves the top-5 neighbour indices to ``p['topk_neighbors_val_path']``.
    This script does not train and does not mine train-split neighbours.
    """
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'green'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(
        sum(param.numel() for param in model.parameters()) / 1e6))
    print(model)
    model = model.to(device)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'green'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'green'))
    train_transforms = get_train_transformations(p)
    print('Train transforms:', train_transforms)
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    # The train split is only loaded to report its size below.
    train_dataset = get_train_dataset(p,
                                      train_transforms,
                                      to_augmented_dataset=True,
                                      split='train')  # Split is for stl-10
    val_dataset = get_val_dataset(p, val_transforms)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(
        len(train_dataset), len(val_dataset)))

    # Memory Bank (only the validation bank is used by this script)
    print(colored('Build MemoryBank', 'green'))
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'],
                                 p['criterion_kwargs']['temperature'])
    memory_bank_val.to(device)

    # Checkpoint
    if os.path.exists(p['pretext_checkpoint']):
        print(
            colored(
                'Restart from checkpoint {}'.format(p['pretext_checkpoint']),
                'green'))
        checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model.to(device)
    else:
        print(
            colored('No checkpoint file at {}'.format(p['pretext_checkpoint']),
                    'green'))
        # Execution continues, so make it explicit which weights are used.
        print(
            colored(
                'Warning: no pretext checkpoint loaded; neighbors are mined '
                'with the weights returned by get_model', 'yellow'))

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    print(
        colored('Fill memory bank for mining the nearest neighbors (val) ...',
                'green'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    topk = 5
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' %
          (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
コード例 #4
0
def extract_model_state(checkpoint):
    """Return the model weights stored in a loaded checkpoint.

    Accepts either a bare state dict or a training checkpoint that keeps the
    weights under a ``'model'`` key next to other entries (e.g. optimizer
    state and epoch counter).

    Args:
        checkpoint: Object returned by ``torch.load``.

    Returns:
        The mapping to pass to ``load_state_dict``.
    """
    # A bare state dict maps parameter names to tensors, so a 'model' entry
    # holding a dict can only come from a wrapping training checkpoint.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
        return checkpoint['model']
    return checkpoint


def _fill_and_mine(dataloader, model, memory_bank, topk, split, save_path):
    """Fill ``memory_bank`` from ``dataloader``, mine top-k neighbours and save them.

    Args:
        dataloader: Loader whose samples populate the memory bank.
        model: Network passed to ``fill_memory_bank``.
        memory_bank: ``MemoryBank`` sized for ``dataloader``'s dataset.
        topk: Number of neighbours to mine per sample.
        split: Split label used in the log messages ('train' or 'val').
        save_path: Destination passed to ``np.save`` for the indices.
    """
    print(
        colored(
            'Fill memory bank for mining the nearest neighbors ({}) ...'.format(split),
            'blue'))
    fill_memory_bank(dataloader, model, memory_bank)
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on %s set is %.2f' %
          (topk, split, 100 * acc))
    np.save(save_path, indices)


def main():
    """Load a pretext checkpoint, re-save it and mine nearest neighbours.

    Copies the weights from ``p['pretext_checkpoint']`` to
    ``p['pretext_model']``. It then saves the top-20 train (un-augmented)
    and top-5 validation neighbour indices.

    Raises:
        FileNotFoundError: If ``p['pretext_checkpoint']`` does not exist.
    """
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(
        sum(param.numel() for param in model.parameters()) / 1e6))
    print(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    val_dataset = get_val_dataset(p, val_transforms)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {} val samples'.format(len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    base_dataset = get_train_dataset(
        p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'],
                                  p['criterion_kwargs']['temperature'])
    memory_bank_base.cuda()
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'],
                                 p['criterion_kwargs']['temperature'])
    memory_bank_val.cuda()

    # Checkpoint (explicit check: an assert would be stripped under `python -O`)
    if not os.path.exists(p['pretext_checkpoint']):
        raise FileNotFoundError('No pretext checkpoint found at {}'.format(
            p['pretext_checkpoint']))
    print(
        colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']),
                'blue'))
    checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
    model.load_state_dict(extract_model_state(checkpoint))
    model.cuda()

    # Save model
    torch.save(model.state_dict(), p['pretext_model'])

    # Mine the topk nearest neighbors at the very end (Train)
    # These will be served as input to the SCAN loss.
    _fill_and_mine(base_dataloader, model, memory_bank_base, 20, 'train',
                   p['topk_neighbors_train_path'])

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    _fill_and_mine(val_dataloader, model, memory_bank_val, 5, 'val',
                   p['topk_neighbors_val_path'])
コード例 #5
0
def moco_to_plain_state_dict(state_dict):
    """Rename MoCo v2 query-encoder weights for a model without DataParallel.

    ``module.encoder_q.fc.*`` (the MoCo projection MLP) maps to
    ``contrastive_head.*``; every other ``module.encoder_q.*`` key maps to
    ``backbone.*``. The ``module.`` prefix is dropped because ``main`` does
    not wrap the model in ``torch.nn.DataParallel``.

    Args:
        state_dict: Mapping from MoCo parameter names to tensors.

    Returns:
        A new dict with the renamed keys and the original values.

    Raises:
        ValueError: If a key is not under ``module.encoder_q.``.
    """
    backbone_prefix = 'module.encoder_q.'
    head_prefix = 'module.encoder_q.fc.'
    new_state_dict = {}
    for k, v in state_dict.items():
        # The head prefix is more specific, so it must be tested first.
        if k.startswith(head_prefix):
            new_state_dict['contrastive_head.' + k[len(head_prefix):]] = v
        elif k.startswith(backbone_prefix):
            new_state_dict['backbone.' + k[len(backbone_prefix):]] = v
        else:
            raise ValueError('Unexpected key {}'.format(k))
    return new_state_dict


def main():
    """Load a local MoCo v2 checkpoint into the model and save it.

    Reads ``moco_v2_800ep_pretrain.pth.tar`` from ``main_dir + model_dir`` and
    copies its query-encoder weights into the model. It saves the model to
    ``p['pretext_model']``, then replaces the contrastive head with
    ``Identity`` so features are taken before the MLP.
    """
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model (not wrapped in DataParallel, so state-dict keys have no 'module.' prefix)
    print(colored('Retrieve model', 'green'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print(model)
    model = model.to(device)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'green'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    # NOTE(review): the dataloaders and memory banks below are not used in the
    # visible body — presumably the mining steps were cut off; confirm.
    print(colored('Retrieve dataset', 'green'))
    transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, transforms)
    val_dataset = get_val_dataset(p, transforms)
    train_dataloader = get_val_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'green'))
    memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_train.to(device)
    memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_val.to(device)

    # Load the official MoCoV2 checkpoint from a local copy (no download).
    print(colored('Downloading moco v2 checkpoint', 'green'))
    moco_state = torch.load(main_dir + model_dir + 'moco_v2_800ep_pretrain.pth.tar', map_location=device)

    # Transfer moco weights
    print(colored('Transfer MoCo weights to model', 'green'))
    model.load_state_dict(moco_to_plain_state_dict(moco_state['state_dict']))

    # Save final model
    print(colored('Save pretext model', 'green'))
    torch.save(model.state_dict(), p['pretext_model'])
    # In this case, we mine the neighbors before the MLP.
    model.contrastive_head = Identity()
コード例 #6
0
def main():
    """Evaluate a pretext (simclr/moco) or clustering (scan/selflabel) model.

    Reads the YAML config at ``args.config_exp`` and the weights at
    ``args.model``. Pretext setups report top-1/5/20 nearest-neighbour
    accuracy on the validation set. Clustering setups run the Hungarian-matched
    clustering evaluation and can optionally visualise prototypes.

    Raises:
        NotImplementedError: If ``config['setup']`` is not one of 'simclr',
            'moco', 'scan' or 'selflabel'.
    """
    # Read config file
    print(colored('Read config file {} ...'.format(args.config_exp), 'blue'))
    with open(args.config_exp, 'r', encoding='utf-8') as stream:
        config = yaml.safe_load(stream)
    config['batch_size'] = 512  # To make sure we can evaluate on a single 1080ti
    print(config)
    setup = config['setup']

    # Get dataset
    print(colored('Get validation dataset ...', 'blue'))
    transforms = get_val_transformations(config)
    dataset = get_val_dataset(config, transforms)
    dataloader = get_val_dataloader(config, dataset)
    print('Number of samples: {}'.format(len(dataset)))

    # Get model
    print(colored('Get model ...', 'blue'))
    model = get_model(config)
    print(model)

    # Read model weights: scan checkpoints wrap the weights under 'model'.
    print(colored('Load model weights ...', 'blue'))
    state_dict = torch.load(args.model, map_location='cpu')

    if setup in ['simclr', 'moco', 'selflabel']:
        model.load_state_dict(state_dict)
    elif setup == 'scan':
        model.load_state_dict(state_dict['model'])
    else:
        raise NotImplementedError('Unsupported setup {!r}'.format(setup))

    # CUDA
    model.cuda()

    # Perform evaluation
    if setup in ['simclr', 'moco']:
        print(
            colored(
                'Perform evaluation of the pretext task (setup={}).'.format(
                    setup), 'blue'))
        print('Create Memory Bank')
        if setup == 'simclr':  # Mine neighbors after MLP
            temperature = config['criterion_kwargs']['temperature']
        else:  # Mine neighbors before MLP
            # NOTE(review): the bank is still sized with features_dim and the
            # contrastive head is not replaced here — confirm this matches the
            # moco model's output size.
            temperature = config['temperature']
        memory_bank = MemoryBank(len(dataset),
                                 config['model_kwargs']['features_dim'],
                                 config['num_classes'],
                                 temperature)
        memory_bank.cuda()

        print('Fill Memory Bank')
        fill_memory_bank(dataloader, model, memory_bank)

        print('Mine the nearest neighbors')
        for topk in [1, 5, 20]:  # Similar to Fig 2 in paper
            _, acc = memory_bank.mine_nearest_neighbors(topk)
            print(
                'Accuracy of top-{} nearest neighbors on validation set is {:.2f}'
                .format(topk, 100 * acc))

    elif setup in ['scan', 'selflabel']:
        print(
            colored(
                'Perform evaluation of the clustering model (setup={}).'.
                format(setup), 'blue'))
        # scan checkpoints record which head to evaluate; selflabel uses head 0.
        head = state_dict['head'] if setup == 'scan' else 0
        predictions, features = get_predictions(config,
                                                dataloader,
                                                model,
                                                return_features=True)
        clustering_stats = hungarian_evaluate(head,
                                              predictions,
                                              dataset.classes,
                                              compute_confusion_matrix=True)
        print(clustering_stats)
        if args.visualize_prototypes:
            prototype_indices = get_prototypes(config, predictions[head],
                                               features, model)
            visualize_indices(prototype_indices, dataset,
                              clustering_stats['hungarian_match'])
    else:
        # Defensive: unsupported setups are already rejected when loading weights.
        raise NotImplementedError('Unsupported setup {!r}'.format(setup))