# Example #1
def main(args):
    """Train (or verify) an MLP with perturbed weights using auto_LiRPA.

    Seeds every RNG source for reproducibility, optionally resumes from a
    checkpoint, wraps the model (and a loss-fused copy) in BoundedModule,
    then either runs a single verification pass (``args.verify``) or the
    full certified-training loop, saving a checkpoint after every epoch.
    """
    # Seed all RNG sources so runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Load the model with BoundedParameter for weight perturbation.
    model_ori = models.Models['mlp_3layer_weight_perturb']()

    epoch = 0
    ## Load a checkpoint, if requested.
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        # Refuse to resume from a checkpoint containing NaN/Inf weights.
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    dummy_input = torch.randn(1, 1, 28, 28)
    train_data, test_data = mnist_loaders(datasets.MNIST, batch_size=args.batch_size, ratio=args.ratio)
    train_data.mean = test_data.mean = torch.tensor([0.0])
    train_data.std = test_data.std = torch.tensor([1.0])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={'relu': args.bound_opts, 'loss_fusion': True}, device=args.device)

    # after CrossEntropyWrapper, the final name will change because of one more input node in CrossEntropyWrapper
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
    if args.multigpu:
        model_loss = BoundDataParallel(model_loss)
    model_loss.ptb = model.ptb = model_ori.ptb  # Perturbation on the parameters

    ## Step 4 prepare optimizer, epsilon scheduler and learning rate scheduler
    if args.opt == 'ADAM':
        opt = optim.Adam(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    elif args.opt == 'SGD':
        opt = optim.SGD(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    else:
        # FIX: an unrecognized optimizer name previously left `opt` unbound,
        # causing a confusing NameError on its first use below.
        raise ValueError('Unsupported optimizer: {}'.format(args.opt))

    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    # NOTE(review): eval() on a command-line string is unsafe with untrusted
    # input; args.scheduler_name is expected to name a scheduler class in scope.
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # Skip epochs if we continue training from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        # Replay both schedulers so eps and lr match the checkpointed epoch.
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: start training.
    if args.verify:
        # Verification only: one pass over the test set at a fixed eps.
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'CROWN-IBP', loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_loss = 1e10
        # Main training loop
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()

            # Training one epoch
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict (hack for saving models so far...)
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]

            # Test one epoch.
            with torch.no_grad():
                m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type,
                          loss_fusion=False, final_node_name=final_name2)

            # Save checkpoints.
            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if not os.path.exists('saved_models'):
                os.mkdir('saved_models')
            if t < int(eps_scheduler.params['start']):
                # Natural (pre-robust) training phase.
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                # Eps ramp finished: additionally track the best model by test loss.
                current_loss = m.avg('Loss')
                if current_loss < best_loss:
                    best_loss = current_loss
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_loss)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
def train(wandb_track,
          experiment_name,
          epochs,
          task,
          gpu_num=0,
          pretrained='',
          margin=0.4,
          losstype='deepcca'):
    """Train joint embedding networks.

    Args:
        wandb_track: 1 to log metrics and gradients to Weights & Biases.
        experiment_name: name of the results directory (and wandb project).
        epochs: number of training epochs (coerced to int).
        task: dataset/model selector: 'cifar10', 'mnist' or 'uw'.
        gpu_num: GPU index to run on (coerced to int).
        pretrained: 'pretrained' to resume from previously saved weights.
        margin: margin hyperparameter (coerced to float; unused by 'deepcca').
        losstype: 'deepcca' trains the two networks with the Deep CCA loss;
            'CCA' skips NN training (use epochs=0) and fits a classical CCA
            on the raw training data after the loop.
    """

    epochs = int(epochs)
    gpu_num = int(gpu_num)
    margin = float(margin)

    # Setup the results and device.
    results_dir = setup_dirs(experiment_name)
    if not os.path.exists(results_dir + 'train_results/'):
        os.makedirs(results_dir + 'train_results/')
    train_results_dir = results_dir + 'train_results/'
    device = setup_device(gpu_num)

    #### Hyperparameters #####
    #Initialize wandb
    if wandb_track == 1:
        import wandb
        wandb.init(project=experiment_name)
        config = wandb.config
        config.epochs = epochs

    # Record the exact command and settings for reproducibility.
    with open(results_dir + 'hyperparams_train.txt', 'w') as f:
        f.write('Command used to run: python ')
        f.write(' '.join(sys.argv))
        f.write('\n')
        f.write('device in use: ' + str(device))
        f.write('\n')
        f.write('--experiment_name ' + str(experiment_name))
        f.write('\n')
        f.write('--epochs ' + str(epochs))
        f.write('\n')

    # Setup data loaders and models.
    if task == 'cifar10':
        train_loader, test_loader = cifar10_loaders()
        model_A = CIFAREmbeddingNet()
        model_B = CIFAREmbeddingNet()
    elif task == 'mnist':
        train_loader, test_loader = mnist_loaders()
        model_A = MNISTEmbeddingNet()
        model_B = MNISTEmbeddingNet()
    elif task == 'uw':
        uw_data = 'bert'
        train_loader, test_loader = uw_loaders(uw_data)
        if uw_data == 'bert':
            model_A = RowNet(3072, embed_dim=1024)  # Language.
            model_B = RowNet(4096, embed_dim=1024)  # Vision.

    # Finish model setup.
    if pretrained == 'pretrained':  # If we want to load pretrained models to continue training.
        print('Starting from pretrained networks.')
        model_A.load_state_dict(
            torch.load(train_results_dir + 'model_A_state.pt'))
        model_B.load_state_dict(
            torch.load(train_results_dir + 'model_B_state.pt'))
    else:
        # FIX: this message used to print unconditionally, even when we had
        # just loaded pretrained weights above.
        print('Starting from scratch to train networks.')

    model_A.to(device)
    model_B.to(device)

    # Initialize the optimizers and loss function.
    optimizer_A = torch.optim.Adam(model_A.parameters(), lr=0.00001)
    optimizer_B = torch.optim.Adam(model_B.parameters(), lr=0.00001)

    # Add learning rate scheduling.
    def lr_lambda(e):
        if e < 50:
            return 0.001
        elif e < 100:
            return 0.0001
        else:
            return 0.00001

    scheduler_A = torch.optim.lr_scheduler.LambdaLR(optimizer_A, lr_lambda)
    scheduler_B = torch.optim.lr_scheduler.LambdaLR(optimizer_B, lr_lambda)

    # Track batch losses.
    loss_hist = []

    # Put models into training mode.
    model_A.train()
    model_B.train()

    # Train.
    # wandb
    if wandb_track == 1:
        wandb.watch(model_A, log="all")
        wandb.watch(model_B, log="all")
    epoch_list = []  # in order to save epoch in a pickle file
    loss_list = []  # in order to save loss in a pickle file
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0.0
        counter = 0
        for data in train_loader:
            data_a = data[0].to(device)
            data_b = data[1].to(device)
            #label = data[2]

            # Zero the parameter gradients.
            optimizer_A.zero_grad()
            optimizer_B.zero_grad()

            # Forward.
            if losstype == 'deepcca':  # Based on Galen Andrew's Deep CCA
                # data_a is from domain A, and data_b is the paired data from domain B.
                embedding_a = model_A(data_a)
                embedding_b = model_B(data_b)
                loss = deepcca(embedding_a,
                               embedding_b,
                               device,
                               use_all_singular_values=True,
                               outdim_size=128)
            else:
                # FIX: `loss` used to be left undefined here, so the
                # backward() call below crashed with an opaque NameError.
                # For losstype='CCA' run with epochs=0; the CCA fit happens
                # after this loop.
                raise ValueError(
                    "losstype '%s' is not supported in the training loop; "
                    "only 'deepcca' is." % losstype)

            # Backward.
            loss.backward()

            # Update.
            optimizer_A.step()
            optimizer_B.step()

            # Save batch loss. Since we are minimizing -corr the loss is negative.
            loss_hist.append(-1 * loss.item())

            epoch_loss += embedding_a.shape[0] * loss.item()

            #reporting progress
            counter += 1
            if counter % 64 == 0:
                print('epoch:', epoch, 'loss:', loss.item())
                if wandb_track == 1:
                    wandb.log({"epoch": epoch, "loss": loss})

        # Save network state at each epoch.
        torch.save(model_A.state_dict(),
                   train_results_dir + 'model_A_state.pt')
        torch.save(model_B.state_dict(),
                   train_results_dir + 'model_B_state.pt')

        #since the batch size is 1 therefore: len(trainloader)==counter
        print('*********** epoch is finished ***********')
        epoch_loss = -1 * epoch_loss
        print('epoch: ', epoch, 'loss(correlation): ', (epoch_loss) / counter)
        epoch_list.append(epoch + 1)
        loss_list.append(epoch_loss / counter)
        pickle.dump(([epoch_list, loss_list]),
                    open(train_results_dir + 'epoch_loss.pkl', "wb"))
        Visualize(train_results_dir + 'epoch_loss.pkl', 'Correlation History',
                  True, 'epoch', 'Correlation (log scale)', None, 'log', None,
                  (14, 7), train_results_dir + 'Figures/')
        # Update learning rate schedulers.
        scheduler_A.step()
        scheduler_B.step()

    # Plot and save batch loss history.
    pickle.dump(([loss_hist[::10]]),
                open(train_results_dir + 'epoch_corr.pkl', "wb"))
    Visualize(train_results_dir + 'epoch_corr.pkl', 'Correlation Batch', False,
              'Batch', 'Correlation (log scale)', None, 'log', None, (14, 7),
              train_results_dir + 'Figures/')

    #### Learn the transformations for CCA ####
    if losstype == "CCA":
        a_base = []
        b_base = []
        no_model = True

        if no_model:  # without using model: using raw data without featurization
            for data in train_loader:
                x = data[0].to(device)
                y = data[1].to(device)
                if task == 'uw':
                    a_base.append(x)
                    b_base.append(y)
                else:
                    a_base.append(x.cpu().detach().numpy())
                    b_base.append(y.cpu().detach().numpy())
        else:
            import torchvision.models as models
            #Either use these models, or use trained models with triplet loss
            res18_model = models.resnet18(pretrained=True)
            #changing the first layer of ResNet to accept images with 1 channgel instead of 3.
            res18_model.conv1 = torch.nn.Conv2d(1,
                                                64,
                                                kernel_size=7,
                                                stride=2,
                                                padding=3,
                                                bias=False)
            # Select the desired layers
            model_A = torch.nn.Sequential(*list(res18_model.children())[:-2])
            model_B = torch.nn.Sequential(*list(res18_model.children())[:-2])
            model_A.eval()
            model_B.eval()
            for data in train_loader:
                x = data[0].to(device)  # Domain A
                y = data[1].to(device)  # Domain B
                a_base.append(model_A(x).cpu().detach().numpy())
                b_base.append(model_B(y).cpu().detach().numpy())

        # Concatenate predictions.
        a_base = np.concatenate(a_base, axis=0)
        b_base = np.concatenate(b_base, axis=0)
        a_base = np.squeeze(a_base)
        b_base = np.squeeze(b_base)

        if no_model:
            # Flatten each sample to a vector before fitting CCA.
            new_a_base = []
            new_b_base = []
            for i in range(len(a_base)):
                new_a_base.append(a_base[i, :, :].flatten())
                new_b_base.append(b_base[i, :, :].flatten())
            new_a_base = np.asarray(new_a_base)
            new_b_base = np.asarray(new_b_base)
            a_base = new_a_base
            b_base = new_b_base

            print('Finished reshaping data, the shape is:', new_a_base.shape)

        from sklearn.cross_decomposition import CCA
        from joblib import dump
        components = 128
        cca = CCA(n_components=components)
        cca.max_iter = 5000
        cca.fit(a_base, b_base)
        dump(cca, 'Learned_CCA.joblib')
    #### End of CCA fit to find the transformations ####

    print('Training Done!')
def test(experiment_name,
         task,
         gpu_num=0,
         pretrained='',
         margin=0.4,
         losstype='deepcca'):
    """Evaluate trained joint embedding networks.

    Loads the models saved by train(), embeds the train and test data
    (caching embeddings on disk to speed up re-runs), then reports
    object-identification metrics (precision/recall/F1, MRR, KNN, corr),
    a retrieval CSV, ROC curves, cross-domain distance correlation plots,
    embedding-distance histograms and UMAP projections, saving figures
    under <results_dir>/test_results/.
    """
    cosined = False
    embed_dim = 1024
    gpu_num = int(gpu_num)
    margin = float(margin)

    # Setup the results and device.
    results_dir = setup_dirs(experiment_name)
    if not os.path.exists(results_dir + 'test_results/'):
        os.makedirs(results_dir + 'test_results/')
    test_results_dir = results_dir + 'test_results/'

    device = setup_device(gpu_num)

    #### Hyperparameters #####
    #Initialize wandb
    #import wandb
    #wandb.init(project=experiment_name)
    #config = wandb.config

    # Record the exact command and settings for reproducibility.
    with open(results_dir + 'hyperparams_test.txt', 'w') as f:
        f.write('Command used to run: python ')
        f.write(' '.join(sys.argv))
        f.write('\n')
        f.write('device in use: ' + str(device))
        f.write('\n')
        f.write('--experiment_name ' + str(experiment_name))
        f.write('\n')

    # Setup data loaders and models based on task.
    if task == 'cifar10':
        train_loader, test_loader = cifar10_loaders()
        model_A = CIFAREmbeddingNet()
        model_B = CIFAREmbeddingNet()
    elif task == 'mnist':
        train_loader, test_loader = mnist_loaders()
        model_A = MNISTEmbeddingNet()
        model_B = MNISTEmbeddingNet()
    elif task == 'uw':
        uw_data = 'bert'
        train_loader, test_loader = uw_loaders(uw_data)
        if uw_data == 'bert':
            model_A = RowNet(3072, embed_dim=1024)  # Language.
            model_B = RowNet(4096, embed_dim=1024)  # Vision.

    # Finish model setup.
    model_A.load_state_dict(
        torch.load(results_dir + 'train_results/model_A_state.pt'))
    model_B.load_state_dict(
        torch.load(results_dir + 'train_results/model_B_state.pt'))
    model_A.to(device)
    model_B.to(device)
    # Put models into evaluation mode.
    model_A.eval()
    model_B.eval()
    """For UW data."""
    ## we use train data to calculate the threshhold for distance.
    # loading saved embeddings to be faster
    a_train = load_embeddings(test_results_dir + 'lang_embeds_train.npy')
    b_train = load_embeddings(test_results_dir + 'img_embeds_train.npy')

    # Iterate through the train data.
    if a_train is None or b_train is None:
        a_train = []
        b_train = []
        print(
            "Computing embeddings for train data to calculate threshhold for distance"
        )
        for data in train_loader:
            anchor_data = data[0].to(device)
            positive_data = data[1].to(device)
            label = data[2]
            a_train.append(
                model_A(anchor_data.to(device)).cpu().detach().numpy())
            b_train.append(
                model_B(positive_data.to(device)).cpu().detach().numpy())
        print("Finished Computing embeddings for train data")
    #saving embeddings if not already saved
    save_embeddings(test_results_dir + 'lang_embeds_train.npy', a_train)
    save_embeddings(test_results_dir + 'img_embeds_train.npy', b_train)

    a_train = np.concatenate(a_train, axis=0)
    b_train = np.concatenate(b_train, axis=0)

    # Test data
    # For accumulating predictions to check embedding visually using test set.
    # a is embeddings from domain A, b is embeddings from domain B, ys is their labels
    a = []
    b = []
    ys = []
    instance_data = []

    # loading saved embeddings to be faster
    # FIX: `compute_test_embeddings` used to be assigned only inside the
    # branch below, raising a NameError later whenever cached embeddings
    # existed on disk.
    compute_test_embeddings = False
    a = load_embeddings(test_results_dir + 'lang_embeds.npy')
    b = load_embeddings(test_results_dir + 'img_embeds.npy')
    if a is None or b is None:
        compute_test_embeddings = True
        a = []
        b = []

    # Iterate through the test data.
    print("computing embeddings for test data")
    for data in test_loader:
        language_data, vision_data, object_name, instance_name = data
        language_data = language_data.to(device)
        vision_data = vision_data.to(device)
        instance_data.extend(instance_name)
        if compute_test_embeddings:
            a.append(
                model_A(language_data).cpu().detach().numpy())  # Language.
            b.append(model_B(vision_data).cpu().detach().numpy())  # Vision.
        ys.extend(object_name)
    print("finished computing embeddings for test data")
    # Convert string labels to ints.
    labelencoder = LabelEncoder()
    labelencoder.fit(ys)
    ys = labelencoder.transform(ys)

    #saving embeddings if not already saved
    save_embeddings(test_results_dir + 'lang_embeds.npy', a)
    save_embeddings(test_results_dir + 'img_embeds.npy', b)

    # Concatenate predictions.
    a = np.concatenate(a, axis=0)
    b = np.concatenate(b, axis=0)
    ab = np.concatenate((a, b), axis=0)

    # Object identification: language query -> vision candidates.
    ground_truth, predicted, distance = object_identification_task_classifier(
        a, b, ys, a_train, b_train, lamb_std=1, cosine=cosined)

    #### Retrieval task by giving an image and finding the closest word descriptions ####
    ground_truth_word, predicted_word, distance_word = object_identification_task_classifier(
        b, a, ys, b_train, a_train, lamb_std=1, cosine=cosined)
    with open('retrieval_non_pro.csv', mode='w') as retrieval_non_pro:
        csv_file_writer = csv.writer(retrieval_non_pro,
                                     delimiter=',',
                                     quotechar='"',
                                     quoting=csv.QUOTE_MINIMAL)
        csv_file_writer.writerow(
            ['image', 'language', 'predicted', 'ground truth'])
        for i in range(50):
            csv_file_writer.writerow([
                instance_data[0], instance_data[i], predicted_word[0][i],
                ground_truth_word[0][i]
            ])

    precisions = []
    recalls = []
    f1s = []
    precisions_pos = []
    recalls_pos = []
    f1s_pos = []
    #print(classification_report(oit_res[i], 1/np.arange(1,len(oit_res[i])+1) > 0.01))
    for i in range(len(ground_truth)):
        p, r, f, s = precision_recall_fscore_support(ground_truth[i],
                                                     predicted[i],
                                                     warn_for=(),
                                                     average='micro')
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
        p, r, f, s = precision_recall_fscore_support(ground_truth[i],
                                                     predicted[i],
                                                     warn_for=(),
                                                     average='binary')
        precisions_pos.append(p)
        recalls_pos.append(r)
        f1s_pos.append(f)

    print('\n ')
    print(experiment_name + '_' + str(embed_dim))
    print('MRR,    KNN,    Corr,   Mean F1,    Mean F1 (pos only)')
    print('%.3g & %.3g & %.3g & %.3g & %.3g' %
          (mean_reciprocal_rank(
              a, b, ys, cosine=cosined), knn(a, b, ys, k=5, cosine=cosined),
           corr_between(a, b, cosine=cosined), np.mean(f1s), np.mean(f1s_pos)))

    # One faint ROC curve per query; overlaid to show the distribution.
    plt.figure(figsize=(14, 7))
    for i in range(len(ground_truth)):
        fpr, tpr, thres = roc_curve(ground_truth[i],
                                    [1 - e for e in distance[i]],
                                    drop_intermediate=True)
        plt.plot(fpr, tpr, alpha=0.08, color='r')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.savefig(test_results_dir + '_' + str(embed_dim) + '_ROC.svg')

    # Pick a pair, plot distance in A vs distance in B. Should be correlated.
    a_dists = []
    b_dists = []
    for _ in range(3000):
        i1 = random.randrange(len(a))
        i2 = random.randrange(len(a))
        a_dists.append(euclidean(a[i1], a[i2]))
        b_dists.append(euclidean(b[i1], b[i2]))
    #     a_dists.append(cosine(a[i1], a[i2]))
    #     b_dists.append(cosine(b[i1], b[i2]))

    # Plot.
    plt.figure(figsize=(14, 14))
    #plt.title('Check Distance Correlation Between Domains')
    plt.xlim([0, 3])
    plt.ylim([0, 3])
    # plt.xlim([0,max(a_dists)])
    # plt.ylim([0,max(b_dists)])
    # plt.xlabel('Distance in Domain A')
    # plt.ylabel('Distance in Domain B')
    plt.xlabel('Distance in Language Domain')
    plt.ylabel('Distance in Vision Domain')
    #plt.plot(a_dists_norm[0],b_dists_norm[0],'.')
    #plt.plot(np.arange(0,2)/20,np.arange(0,2)/20,'k-',lw=3)
    plt.plot(a_dists, b_dists, 'o', alpha=0.5)
    plt.plot(np.arange(0, 600), np.arange(0, 600), 'k--', lw=3, alpha=0.5)
    #plt.text(-0.001, -0.01, 'Corr: %.3f'%(pearsonr(a_dists,b_dists)[0]),  fontsize=20)
    plt.savefig(test_results_dir + '_' + str(embed_dim) + '_CORR.svg')

    # Inspect embedding distances.
    clas = 5  # Base class.
    i_clas = [i for i in range(len(ys)) if ys[i].item() == clas]
    i_clas_2 = np.random.choice(i_clas, len(i_clas), replace=False)

    clas_ref = 4  # Comparison class.
    i_clas_ref = [i for i in range(len(ys)) if ys[i].item() == clas_ref]

    ac = np.array([a[i] for i in i_clas])
    bc = np.array([b[i] for i in i_clas])

    ac2 = np.array([a[i] for i in i_clas_2])
    bc2 = np.array([b[i] for i in i_clas_2])

    ac_ref = np.array([a[i] for i in i_clas_ref])
    aa_diff_ref = norm(ac[:min(len(ac), len(ac_ref))] -
                       ac_ref[:min(len(ac), len(ac_ref))],
                       ord=2,
                       axis=1)

    ab_diff = norm(ac - bc2, ord=2, axis=1)
    aa_diff = norm(ac - ac2, ord=2, axis=1)
    bb_diff = norm(bc - bc2, ord=2, axis=1)

    # aa_diff_ref = [cosine(ac[:min(len(ac),len(ac_ref))][i],ac_ref[:min(len(ac),len(ac_ref))][i]) for i in range(len(ac[:min(len(ac),len(ac_ref))]))]

    # ab_diff = [cosine(ac[i],bc2[i]) for i in range(len(ac))]
    # aa_diff = [cosine(ac[i],ac2[i]) for i in range(len(ac))]
    # bb_diff = [cosine(bc[i],bc2[i]) for i in range(len(ac))]

    bins = np.linspace(0, 0.1, 100)

    plt.figure(figsize=(14, 7))
    plt.hist(ab_diff, bins, alpha=0.5, label='between embeddings')
    plt.hist(aa_diff, bins, alpha=0.5, label='within embedding A')
    plt.hist(bb_diff, bins, alpha=0.5, label='within embedding B')

    plt.hist(aa_diff_ref,
             bins,
             alpha=0.5,
             label='embedding A, from class ' + str(clas_ref))

    plt.title('Embedding Distances - Class: ' + str(clas))
    plt.xlabel('L2 Distance')
    plt.ylabel('Count')
    plt.legend()

    #labelencoder.classes_
    classes_to_keep = [36, 6, 9, 46, 15, 47, 50, 22, 26, 28]
    print(labelencoder.inverse_transform(classes_to_keep))

    # ab stacks language then vision embeddings; i % len(ys) maps both halves
    # back to the same label list.
    ab_norm = [
        e for i, e in enumerate(ab) if ys[i % len(ys)] in classes_to_keep
    ]
    ys_norm = [e for e in ys if e in classes_to_keep]

    color_index = {list(set(ys_norm))[i]: i
                   for i in range(len(set(ys_norm)))}  #set(ys_norm)
    markers = ["o", "v", "^", "s", "*", "+", "x", "D", "h", "4"]
    marker_index = {
        list(set(ys_norm))[i]: markers[i]
        for i in range(len(set(ys_norm)))
    }

    embedding = umap.UMAP(n_components=2).fit_transform(
        ab_norm)  # metric='cosine'
    # Plot UMAP embedding of embeddings for all classes.
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # First half of `embedding` is language, second half is vision.
    mid = len(ys_norm)

    ax1.set_title('Language UMAP')
    for e in list(set(ys_norm)):
        x1 = [
            embedding[:mid, 0][i] for i in range(len(ys_norm))
            if ys_norm[i] == e
        ]
        x2 = [
            embedding[:mid, 1][i] for i in range(len(ys_norm))
            if ys_norm[i] == e
        ]
        ax1.scatter(
            x1,
            x2,
            marker=marker_index[int(e)],
            alpha=0.5,
            c=[sns.color_palette("colorblind", 10)[color_index[int(e)]]],
            label=labelencoder.inverse_transform([int(e)])[0])
    ax1.set_xlim([min(embedding[:, 0]) - 4, max(embedding[:, 0]) + 4])
    ax1.set_ylim([min(embedding[:, 1]) - 4, max(embedding[:, 1]) + 4])
    ax1.grid(True)
    ax1.legend(loc='upper center',
               bbox_to_anchor=(1.1, -0.08),
               fancybox=True,
               shadow=True,
               ncol=5)

    ax2.set_title('Vision UMAP')
    for e in list(set(ys_norm)):
        x1 = [
            embedding[mid::, 0][i] for i in range(len(ys_norm))
            if ys_norm[i] == e
        ]
        x2 = [
            embedding[mid::, 1][i] for i in range(len(ys_norm))
            if ys_norm[i] == e
        ]
        ax2.scatter(
            x1,
            x2,
            marker=marker_index[int(e)],
            alpha=0.5,
            c=[sns.color_palette("colorblind", 10)[color_index[int(e)]]])
    ax2.set_xlim([min(embedding[:, 0]) - 4, max(embedding[:, 0]) + 4])
    ax2.set_ylim([min(embedding[:, 1]) - 4, max(embedding[:, 1]) + 4])
    ax2.grid(True)

    plt.savefig(test_results_dir + '_' + str(embed_dim) + '_UMAP_wl.svg',
                bbox_inches='tight')