Example no. 1
    def __init__(self, args, output_channels=40):
        super(PaiNet, self).__init__()
        self.args = args
        self.k = args.k
        num_kernel = 9  # xyz*2+1 # args.k

        self.kernels = nn.Parameter(torch.tensor(
            fibonacci_sphere(num_kernel)).transpose(0, 1),
                                    requires_grad=False)
        self.one_padding = nn.Parameter(torch.zeros(self.k, num_kernel),
                                        requires_grad=False)
        self.one_padding.data[0, 0] = 1
        self.activation = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = Sparsemax(dim=-1)

        self.conv1 = PaiConvDG(3, 64, self.k, num_kernel)
        self.conv2 = PaiConvDG(64, 64, self.k, num_kernel)
        self.conv3 = PaiConvDG(64, 128, self.k, num_kernel)
        self.conv4 = PaiConvDG(128, 256, self.k, num_kernel)

        self.bn5 = nn.BatchNorm1d(args.emb_dims)
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False), self.bn5)
        self.linear1 = nn.Linear(args.emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)
Example no. 2
    def __init__(self, embeddings_dim=1024, output_dim=11, dropout=0.25, kernel_size=7, conv_dropout: float = 0.25):
        super(LogSparseSoftmax, self).__init__()

        self.conv1 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)
        self.attend1 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)
        self.conv2 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)
        self.attend2 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)
        self.conv3 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)
        self.attend3 = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1, padding=kernel_size // 2)

        self.softmax = nn.Softmax(dim=-1)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.sparsemax = Sparsemax(dim=-1)

        self.dropout1 = nn.Dropout(conv_dropout)
        self.dropout2 = nn.Dropout(conv_dropout)
        self.dropout3 = nn.Dropout(conv_dropout)

        self.linear = nn.Sequential(
            nn.Linear(4 * embeddings_dim, 32),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.output = nn.Linear(32, output_dim)
Example no. 3
    def __init__(self, hx_dim, cls_dim, h_dim, num_classes, args):
        super().__init__()

        self.args = args

        self.num_classes = num_classes
        self.in_class = self.num_classes
        self.hdim = h_dim
        self.cls_emb = nn.Embedding(self.in_class, cls_dim)

        in_dim = hx_dim + cls_dim

        self.net = nn.Sequential(
            nn.Linear(in_dim, self.hdim), nn.Tanh(),
            nn.Linear(self.hdim, self.hdim), nn.Tanh(),
            nn.Linear(self.hdim,
                      num_classes + int(self.args.skip),
                      bias=(not self.args.tie)))

        if self.args.sparsemax:
            from sparsemax import Sparsemax
            self.sparsemax = Sparsemax(-1)

        self.init_weights()

        if self.args.tie:
            print('Tying cls emb to output cls weight')
            self.net[-1].weight = self.cls_emb.weight
Example no. 4
    def __init__(self, layer_sizes_q, layer_sizes_p, latent_size, conditional,
                 num_labels):

        super().__init__()

        self.num_labels = num_labels

        self.conditional = conditional
        if self.conditional:
            layer_sizes_q[0] += num_labels

        self.MLP_q = nn.Sequential()

        # conv architecture inspired by: https://github.com/rasbt/deeplearning-models/blob/master/pytorch_ipynb/autoencoder/ae-cnn-cvae_no-out-concat.ipynb
        self.MLP_q.add_module(name="C1",
                              module=nn.Conv2d(in_channels=1 + self.num_labels,
                                               out_channels=32,
                                               kernel_size=6,
                                               stride=2,
                                               padding=0))
        self.MLP_q.add_module(name="A1", module=nn.ReLU())
        self.MLP_q.add_module(name="C2",
                              module=nn.Conv2d(in_channels=32,
                                               out_channels=64,
                                               kernel_size=4,
                                               stride=2,
                                               padding=1))
        self.MLP_q.add_module(name="A2", module=nn.ReLU())
        self.MLP_q.add_module(name="C3",
                              module=nn.Conv2d(in_channels=64,
                                               out_channels=128,
                                               kernel_size=2,
                                               stride=2,
                                               padding=1))
        self.MLP_q.add_module(name="A3", module=nn.ReLU())
        self.MLP_q.add_module(name="F", module=View((-1, 128 * 4 * 4)))

        self.MLP_p = nn.Sequential()

        if self.conditional:

            layer_sizes_p[0] = num_labels

            for i, (in_size, out_size) in enumerate(
                    zip(layer_sizes_p[:-1], layer_sizes_p[1:])):
                self.MLP_p.add_module(name="L%i" % (i),
                                      module=nn.Linear(in_size, out_size))
                self.MLP_p.add_module(name="A%i" % (i), module=nn.ReLU())

        self.linear_latent_q = nn.Linear(128 * 4 * 4, latent_size)
        self.softmax_q = nn.Softmax(dim=-1)

        self.linear_latent_p = nn.Linear(layer_sizes_p[-1], latent_size)
        self.softmax_p = nn.Softmax(dim=-1)

        self.sparsemax_p = Sparsemax(dim=-1)
Example no. 5
    def __init__(self, input_size, output_size, relax_coef=2):
        super(AttentiveTransformer, self).__init__()

        self.fc = torch.nn.Linear(input_size, output_size)
        self.fc_bn = torch.nn.BatchNorm1d(output_size)

        self.prior = None
        self.relax_coef = relax_coef

        self.sparse = Sparsemax()
Example no. 6
    def __init__(self, H_clusters, H_neurons_per_cluster):
        super(Net, self).__init__()
        self.H_clusters = H_clusters
        self.H_neurons_per_cluster = H_neurons_per_cluster
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, self.H_clusters * self.H_neurons_per_cluster)

        self.sparsemaxActivation = Sparsemax(self.H_clusters,
                                             self.H_neurons_per_cluster)
Example no. 7
    def __init__(self, step_id, **kwargs):
        super(AttentiveTransformer, self).__init__()
        self.step_id = step_id
        self.params.update(kwargs)

        self.fc = nn.Linear(self.params["n_dims_a"],
                            self.params["n_input_dims"])
        self.bn = nn.BatchNorm1d(
            num_features=self.params["n_input_dims"],
            momentum=self.params["batch_norm_momentum"],
        )
        self.sparsemax = Sparsemax(dim=-1)
Example no. 8
    def forward(self, input_data):

        input_data = self.fc_bn(self.fc(input_data))

        if self.prior is None:
            self.prior = torch.ones_like(input_data)

        # apply the sparsemax module created in __init__ (not the Sparsemax class itself)
        mask = self.sparse(input_data * self.prior)

        self.prior *= (self.relax_coef - mask)

        return mask
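
A minimal smoke-test sketch, assuming this `forward` and the `__init__` from Example no. 5 belong to the same `AttentiveTransformer` class (the sizes below are illustrative):

import torch

at = AttentiveTransformer(input_size=16, output_size=16)
x = torch.randn(8, 16)
mask1 = at(x)  # first call: prior is all ones, so this is a plain sparsemax of the batch-normed logits
mask2 = at(x)  # second call: features already picked in mask1 are down-weighted via prior *= (relax_coef - mask1)
# sparsemax outputs are non-negative and sum to 1 along the reduced dimension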
Example no. 9
    def __init__(self,
                 n_features,
                 h=None,
                 na=None,
                 ghost_size=0,
                 sparsemax=True):
        super().__init__()
        if h is None:
            h = nn.Linear(na, n_features, bias=False)
        self.h = h
        self.bn = (GhostNorm(n_features, ghost_size)
                   if ghost_size else nn.BatchNorm1d(n_features))
        self.sm = Sparsemax() if sparsemax else nn.Softmax(dim=-1)
Example no. 10
def main():
    params = parser.parse_args()
    params.lowercase = params.lowercase == 'True'
    print(params)
    model_name = 'bert-base-uncased' if params.lowercase else 'bert-base-cased'
    print(model_name)
    src_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir='/home/georgios.vernikos/workspace/LMMT/MonoEgo/cache',
        use_fast=True)
    tgt_tokenizer = BertTokenizerFast(vocab_file=params.tgt_vocab,
                                      do_lower_case=params.lowercase,
                                      strip_accents=False)

    src_embs, src_subwords = get_subword_embeddings(src_tokenizer,
                                                    params.src_aligned_vec,
                                                    params.topn,
                                                    params.lowercase)
    tgt_embs, tgt_subwords = get_subword_embeddings(tgt_tokenizer,
                                                    params.tgt_aligned_vec,
                                                    params.topn,
                                                    params.lowercase)
    src_embs = renorm(src_embs, 1)
    tgt_embs = renorm(tgt_embs, 1)
    # initialize sparse-max

    sparsemax = Sparsemax(1)

    print(f'| # src subwords found: {len(src_subwords)}')
    print(f'| # tgt subwords found: {len(tgt_subwords)}')
    print('| compute translation probability')
    scores = tgt_embs @ src_embs.t()
    a = sparsemax(scores)  # (Vf, Ve)
    print('| generating translation table!')
    probs = {}

    for i, tt in tqdm(enumerate(tgt_subwords), total=len(tgt_subwords)):
        probs[tt] = {}
        ix = torch.nonzero(a[i]).view(-1)
        px = a[i][ix].tolist()
        wx = [src_subwords[j] for j in ix.tolist()]
        probs[tt] = {w: p for w, p in zip(wx, px)}
    n_avg = np.mean([len(ss) for ss in probs.values()])
    print(f'| average # source / target: {n_avg:.2f} ')
    print(f"| save translation probabilities: {params.save}")
    torch.save(probs, params.save)
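
A hedged sketch of how the saved translation table could be read back (the path is hypothetical; the script above writes to `params.save`):

import torch

probs = torch.load('translation_probs.pt')  # hypothetical path; {tgt_subword: {src_subword: prob, ...}, ...}
tgt = next(iter(probs))
print(tgt, probs[tgt])  # sparsemax keeps only a few source subwords per target; their probabilities sum to 1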
Example no. 11
    def __init__(self, args, output_channels=40):
        super(PaiNet, self).__init__()
        self.args = args
        self.k = args.k
        num_kernel = 9  # xyz*3 + 1
        self.activation = nn.LeakyReLU(negative_slope=0.2)

        map_size = 32
        num_bases = 16
        self.B = nn.Parameter(torch.randn(7 * self.k, map_size),
                              requires_grad=False)
        self.mlp = nn.Linear(map_size * 2, num_bases, bias=False)
        self.permatrix = nn.Parameter(torch.randn(num_bases, self.k,
                                                  num_kernel),
                                      requires_grad=True)
        self.permatrix.data = torch.cat([
            torch.eye(num_kernel),
            torch.zeros(self.k - num_kernel, num_kernel)
        ],
                                        dim=0).unsqueeze(0).expand_as(
                                            self.permatrix)
        self.softmax = Sparsemax(dim=-1)

        self.conv1 = PaiConv(3, 64, self.k, num_kernel)
        self.conv2 = PaiConv(64, 64, self.k, num_kernel)
        self.conv3 = PaiConv(64, 128, self.k, num_kernel)
        self.conv4 = PaiConv(128, 256, self.k, num_kernel)

        self.bn5 = nn.BatchNorm1d(args.emb_dims)
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False), self.bn5)

        self.linear1 = nn.Linear(args.emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)
Example no. 12
def main():
    params = parser.parse_args()
    print(params)
    src_tokenizer = BertTokenizer(params.src_vocab, do_lower_case=False)
    tgt_tokenizer = BertTokenizer(params.tgt_vocab, do_lower_case=False)

    src_embs, src_subwords = get_subword_embeddings(src_tokenizer,
                                                    params.src_aligned_vec,
                                                    params.topn)
    tgt_embs, tgt_subwords = get_subword_embeddings(tgt_tokenizer,
                                                    params.tgt_aligned_vec,
                                                    params.topn)
    src_embs = renorm(src_embs, 1)
    tgt_embs = renorm(tgt_embs, 1)
    # initialize sparse-max

    sparsemax = Sparsemax(1)

    print(f'| # src subwords found: {len(src_subwords)}')
    print(f'| # tgt subwords found: {len(tgt_subwords)}')
    print('| compute translation probability')
    scores = tgt_embs @ src_embs.t()
    a = sparsemax(scores)  # (Vf, Ve)
    print('| generating translation table!')
    probs = {}

    for i, tt in tqdm(enumerate(tgt_subwords), total=len(tgt_subwords)):
        probs[tt] = {}
        ix = torch.nonzero(a[i]).view(-1)
        px = a[i][ix].tolist()
        wx = [src_subwords[j] for j in ix.tolist()]
        probs[tt] = {w: p for w, p in zip(wx, px)}
    n_avg = np.mean([len(ss) for ss in probs.values()])
    print(f'| average # source / target: {n_avg:.2f} ')
    print(f"| save translation probabilities: {params.save}")
    torch.save(probs, params.save)
Example no. 13
    def __init__(self, layer_sizes_q, layer_sizes_p, latent_size, conditional,
                 num_labels):

        super().__init__()

        self.conditional = conditional
        if self.conditional:
            layer_sizes_q[0] += num_labels

        self.MLP_q = nn.Sequential()

        for i, (in_size, out_size) in enumerate(
                zip(layer_sizes_q[:-1], layer_sizes_q[1:])):
            self.MLP_q.add_module(name="L%i" % (i),
                                  module=nn.Linear(in_size, out_size))
            self.MLP_q.add_module(name="A%i" % (i), module=nn.ReLU())

        self.MLP_p = nn.Sequential()

        if self.conditional:

            layer_sizes_p[0] = num_labels

            for i, (in_size, out_size) in enumerate(
                    zip(layer_sizes_p[:-1], layer_sizes_p[1:])):
                self.MLP_p.add_module(name="L%i" % (i),
                                      module=nn.Linear(in_size, out_size))
                self.MLP_p.add_module(name="A%i" % (i), module=nn.ReLU())

        self.linear_latent_q = nn.Linear(layer_sizes_q[-1], latent_size)
        self.softmax_q = nn.Softmax(dim=-1)

        self.linear_latent_p = nn.Linear(layer_sizes_p[-1], latent_size)
        self.softmax_p = nn.Softmax(dim=-1)

        self.sparsemax_p = Sparsemax(dim=-1)
Example no. 14
def test_sparsemax_invalid_dimension():
    sparsemax = Sparsemax(-7)
    input = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True)
    with pytest.raises(IndexError):
        gradcheck(sparsemax, input, eps=1e-6, atol=1e-4)
Example no. 15
def test_sparsemax(dimension):
    sparsemax = Sparsemax(dimension)
    input = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True)
    assert gradcheck(sparsemax, input, eps=1e-6, atol=1e-4)
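
For context, a small self-contained sketch of the behaviour these gradcheck tests exercise: unlike softmax, sparsemax projects onto the probability simplex and can return exact zeros (outputs below computed by hand for these logits, shown rounded):

import torch
from sparsemax import Sparsemax

logits = torch.tensor([[1.5, 1.0, 0.1, -1.0]])
print(torch.softmax(logits, dim=-1))  # all entries strictly positive, roughly [0.52, 0.31, 0.13, 0.04]
print(Sparsemax(dim=-1)(logits))      # low-scoring entries clipped to exactly 0: [0.75, 0.25, 0.00, 0.00]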
Example no. 16
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
import numpy as np

from sparsemax import Sparsemax
sparsemax = Sparsemax(dim=1)

# GLU: gate the first n_units channels by a sigmoid of the remaining channels (in place)
def glu(act, n_units):
    act[:, :n_units] = act[:, :n_units].clone() * torch.sigmoid(act[:, n_units:].clone())
    return act
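
# Hypothetical quick check of the glu helper above: with a feature width of
# 2 * n_units, the first n_units columns are gated in place and the tensor is
# returned with its shape unchanged.
_act = torch.randn(4, 8)
_out = glu(_act, n_units=4)
assert _out.shape == (4, 8)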

class TabNetModel(nn.Module):
    
    def __init__(
        self,
        columns=3,
        num_features=3,
        feature_dims=128,
        output_dim=64,
        num_decision_steps=6,
        relaxation_factor=0.5,
        batch_momentum=0.001,
        virtual_batch_size=2,
from torch import nn
import copy
import torch
from torchsummary import summary
import numpy as np
import torch.nn.functional as F
from sparsemax import Sparsemax
import sys
sys.path.append('../NCKHECG')
from ModelArch.BackBone import SingleBackBoneNet, MultiBackBoneNet
sparsemax = Sparsemax(dim=-1)


class PrototypeNet(nn.Module):
    def __init__(self,
                 SingleBackBone,
                 attention_dim,
                 feature_dim,
                 hidden_dim,
                 class_hidden_dim,
                 input_channel=1,
                 num_class=4):
        super(PrototypeNet, self).__init__()
        self.attention_dim = attention_dim
        self.BackBone = SingleBackBone
        self.proj_feature = nn.Linear(in_features=feature_dim,
                                      out_features=hidden_dim)
        self.layernorm = nn.LayerNorm(hidden_dim)
        self.encoded_key = nn.Linear(in_features=hidden_dim,
                                     out_features=attention_dim)
        self.encoded_querry = nn.Linear(in_features=hidden_dim,
def main():

    random = 123

    dict = pkl.load(open('../../dst_test.pkl', "rb"))
    print("weight", dict['weight'].shape)
    print("z", dict['z_one_hot'].shape)
    print("x", len(dict['x']), dict['x'][0].shape)
    print("alpha_p", dict['alpha_p'].data.shape)
    print("features", dict['features'].shape)
    print("bias", dict['bias'])
    print("c", len(dict['c']), dict['c'][0][0].shape, dict['c'][1][0].shape,
          len(dict['c'][1][1]))
    print("gt", len(dict['gt']), dict['gt'][0].shape, dict['gt'][1].shape)
    batch_size = dict['features'].shape[0]

    K = 5

    m_Z = []
    keep_nums = {}

    if not os.path.exists('results/'):
        os.makedirs('results/')

    # initialize filtered alpha p
    filtered_alpha_p = np.zeros((batch_size, 2, 5))
    filtered_alpha_p_sparsemax = np.zeros((batch_size, 2, 5))

    for batch in tqdm(range(batch_size)):

        # [N, J]
        features = np.expand_dims(
            dict['features'].data.cpu().numpy()[batch, :], axis=0)
        # [25,12,2]
        x = dict['x'][batch]
        # [9,6]
        c = dict['c'][batch][0]
        # [12, 2]
        gt = dict['gt'][batch]

        num_predictions = gt.shape[0]

        filtered_mask = np.zeros((x.shape[0]))
        filtered_mask_sparsemax = np.zeros((x.shape[0]))

        for num in range(2):

            indices = range(num, num + 1)

            # [J, K]
            weights = np.reshape(dict['weight'], (32, 2, 5))[:, num, :]
            # [K, 1]
            bias = np.expand_dims(np.reshape(dict['bias'], (2, 5))[num, :],
                                  axis=-1)
            # [N, K]
            alpha_p = dict['alpha_p'][batch, indices, :]
            # [25, N, K]
            z_one_hot = np.reshape(dict['z_one_hot'],
                                   (25, batch_size, 2, 5))[:, batch,
                                                           indices, :]
            # [25, 1]
            z = np.argmax(z_one_hot, axis=-1)

            print("weights", weights.shape)
            print("bias", bias.shape)
            print("features", features.shape)
            print("x", x.shape)
            print("c", c.shape)
            print("alpha_p", alpha_p.shape)
            print("z_one_hot", z.shape)
            print("gt", gt.shape)

            # Compute the sparsemax baseline
            sparsemax = Sparsemax(dim=-1)
            sparsemax_logits = (weights.T.dot(features.T) + bias).flatten()
            filtered_alpha_p_sparsemax_torch = sparsemax(
                torch.from_numpy(sparsemax_logits).cuda())
            print('logits', sparsemax_logits)

            # Populate the filtered sparsemax distribution
            filtered_alpha_p_sparsemax[
                batch, num, :] = filtered_alpha_p_sparsemax_torch.data.cpu(
                ).numpy().flatten()
            print("filtered alpha p sparsemax",
                  filtered_alpha_p_sparsemax[batch])

            indices_filtered_sparsemax = np.where(
                filtered_alpha_p_sparsemax[batch, num, :] == 0)
            print('filtered indices', indices_filtered_sparsemax[0].shape,
                  filtered_mask_sparsemax.shape)
            for j in range(indices_filtered_sparsemax[0].shape[0]):
                filtered_mask_sparsemax[z.flatten() ==
                                        indices_filtered_sparsemax[0][j]] = 1.

            # Compute DST filter
            mean = features.flatten()

            dst_obj = DST()
            dst_obj.weights_from_linear_layer(weights, bias, features, mean)
            dst_obj.get_output_mass(num_classes=K)

            m_Z.append(dst_obj.output_mass[tuple(range(K))])

            print('sum of singletons',
                  sum(dst_obj.output_mass_singletons.flatten()))

            norm_singletons = deepcopy(alpha_p)
            norm_singletons[dst_obj.output_mass_singletons == 0.] = 0.
            norm_singletons = norm_singletons / np.sum(norm_singletons)

            indices_filtered = np.where(norm_singletons[0] == 0)
            print('filtered indices', indices_filtered[0].shape, z.shape,
                  filtered_mask.shape)
            for j in range(indices_filtered[0].shape[0]):
                filtered_mask[z.flatten() == indices_filtered[0][j]] = 1.

            # save the filtered alpha_p
            filtered_alpha_p[batch, num, :] = norm_singletons.flatten()
            print("filtered alpha p", filtered_alpha_p[batch])

            if num == 1:

                plt.figure()
                width = 0.5
                p1 = plt.bar(np.arange(K),
                             alpha_p.flatten(),
                             width,
                             color='blue',
                             alpha=0.5)
                p2 = plt.bar(np.arange(K),
                             filtered_alpha_p_sparsemax[batch,
                                                        num, :].flatten(),
                             width,
                             color='orange',
                             alpha=0.5)
                p3 = plt.bar(
                    np.arange(K),
                    norm_singletons.flatten(),
                    width,
                    color='green',
                    alpha=0.5
                )  # /1./np.sum(dst_obj.output_mass_singletons.flatten())

                plt.xlabel('Z')
                plt.ylabel('Values')
                plt.ylim(0, 1.0)
                plt.title('Values for Odd')
                plt.legend(['Probabilities', 'Singleton Masses'])
                plt.savefig(
                    'results/odd_dec_z_epochs_20_latent_10_p_30_filtered_prob_random_'
                    + str(random) + '_old_batch_' + str(batch) + '.png',
                    dpi=600)
                matplotlib2tikz.save(
                    'results/odd_dec_z_epochs_20_latent_10_p_30_filtered_prob_random_'
                    + str(random) + '_old_' + str(batch) + '.tex')
                plt.close()

                print("odd evidential weights pos k",
                      dst_obj.evidential_weights_pos_k)
                print("odd evidential weights neg k",
                      dst_obj.evidential_weights_neg_k)

            if num == 0:

                plt.figure()
                width = 0.5
                p1 = plt.bar(np.arange(K),
                             alpha_p.flatten(),
                             width,
                             color='blue',
                             alpha=0.5)
                p2 = plt.bar(np.arange(K),
                             filtered_alpha_p_sparsemax[batch,
                                                        num, :].flatten(),
                             width,
                             color='orange',
                             alpha=0.5)
                p3 = plt.bar(
                    np.arange(K),
                    norm_singletons.flatten(),
                    width,
                    color='green',
                    alpha=0.5
                )  # /1./np.sum(dst_obj.output_mass_singletons.flatten())

                plt.xlabel('Z')
                plt.ylabel('Values')
                plt.ylim(0, 1.0)
                plt.title('Values for Even')
                plt.legend(['Probabilities', 'Singleton Masses'])
                plt.savefig(
                    'results/even_dec_z_epochs_20_latent_10_p_30_filtered_prob_random_'
                    + str(random) + '_old_' + str(batch) + '.png',
                    dpi=600)
                matplotlib2tikz.save(
                    'results/even_dec_z_epochs_20_latent_10_p_30_filtered_prob_random_'
                    + str(random) + '_old_' + str(batch) + '.tex')
                plt.close()

            # print("labels", c)
            print("singletons", dst_obj.output_mass_singletons)
            print("filtered probabilities", norm_singletons)
            print("probabilities", alpha_p)

        # plot the trajectories filtered out
        # [25, N, K]
        z_one_hot = np.reshape(dict['z_one_hot'],
                               (25, batch_size, 2, 5))[:, batch, :, :]
        z = np.argmax(z_one_hot, axis=-1)
        print("z mask", filtered_mask, np.sum(filtered_mask))
        print("check", np.where(filtered_mask))

        keep_mask = 1 - filtered_mask
        keep_mask_sparsemax = 1 - filtered_mask_sparsemax

        color_index_blue = np.linspace(0, 1, 25)
        color_index_green = np.linspace(0, 1, np.sum(keep_mask))

        counter = 0
        for i in range(x.shape[0]):
            plt.plot(x[i, :num_predictions, 0],
                     x[i, :num_predictions, 1],
                     color='blue',
                     alpha=0.15)
            counter += 1

        counter = 0
        for i in np.where(keep_mask)[0]:
            plt.plot(x[i, :num_predictions, 0],
                     x[i, :num_predictions, 1],
                     color='green')
            counter += 1

        counter = 0
        for i in np.where(keep_mask_sparsemax)[0]:
            plt.plot(x[i, :num_predictions, 0],
                     x[i, :num_predictions, 1],
                     color='orange',
                     alpha=0.5)
            counter += 1

        if not os.path.exists('figures_paper/'):
            os.makedirs('figures_paper/')

        plt.plot(gt[:, 0], gt[:, 1], color='black')
        plt.plot(c[:, 0], c[:, 1], '.', color='gray')
        plt.xlabel('x (m)')
        plt.ylabel('y (m)')
        plt.savefig('figures_paper/keep_predictions_new_' + str(batch) +
                    '.png')
        matplotlib2tikz.save('figures_paper/keep_predictions_new_tikz_' +
                             str(batch) + '.tex')
        plt.close()

        # count the different filtered dimensions
        if np.sum(keep_mask) in keep_nums.keys():
            keep_nums[np.sum(keep_mask)] += 1
        else:
            keep_nums[np.sum(keep_mask)] = 1

    pkl.dump(filtered_alpha_p, open("filtered_alpha_p.pkl", "wb"))
    pkl.dump(filtered_alpha_p_sparsemax,
             open("filtered_alpha_p_sparsemax.pkl", "wb"))
    print("keep nums", keep_nums)