Example #1
# Pin host memory and use a worker process only when CUDA is available.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(train,
                                           batch_size=mb_size,
                                           shuffle=True,
                                           **kwargs)
# Note: the validation loader reuses the `test` dataset, shuffled.
val_loader = torch.utils.data.DataLoader(test,
                                         batch_size=mb_size,
                                         shuffle=True,
                                         **kwargs)
test_loader = torch.utils.data.DataLoader(test,
                                          batch_size=mb_size,
                                          shuffle=False,
                                          **kwargs)

contrastiveloss = contrastive.ContrastiveLoss(margin=1.0)
#KLloss = contrastive.KL_avg_sigma()

enc = net.VariationalEncoder(dim=z_dim)
#enc = net.Encoder(dim=z_dim)
#dec = net.Decoder(output_dim=(28, 28))
dec = net.Decoder(dim=z_dim)

if use_cuda:
    enc.cuda()
    dec.cuda()


def reset_grad():
    # Zero accumulated gradients on both networks between updates.
    enc.zero_grad()
    dec.zero_grad()
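
The `contrastive.ContrastiveLoss` used above is not shown in these examples. As a minimal sketch, here is a margin-based (Hadsell-style) contrastive loss matching the call signature `loss(x1, x2, target)` used throughout; this is an assumption about the module's behavior, not its actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLossSketch(nn.Module):
    # Hypothetical stand-in for contrastive.ContrastiveLoss.
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, x1, x2, target):
        # target == 1: the pair should be close; target == 0: the pair should
        # be at least `margin` apart in Euclidean distance.
        d = F.pairwise_distance(x1, x2)
        loss = target * d.pow(2) + (1 - target) * F.relu(self.margin - d).pow(2)
        return 0.5 * loss.mean()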
Example #2
def train(rawdata, charcounts, maxlens, unique_onehotvals):
    # Relies on module-level globals: discrete_cols, continuous_cols, onehot_cols,
    # text_cols, margin, use_cuda, net, Dataseq, discretize and are_equal.
    mb_size = 256
    lr = 2.0e-4
    cnt = 0
    latent_dim = 32
    recurrent_hidden_size = 24

    epoch_len = 8
    max_veclen = 0.0
    patience = 12 * epoch_len
    patience_duration = 0

    # mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

    input_dict = {}
    input_dict['discrete'] = discrete_cols
    input_dict['continuous'] = continuous_cols

    input_dict['onehot'] = {}
    for k in onehot_cols:
        # Embedding size = bits needed to distinguish the column's unique values.
        dim = int(np.ceil(np.log2(len(unique_onehotvals[k]))))
        input_dict['onehot'][k] = dim

    if len(charcounts) > 0:
        text_dim = int(np.ceil(np.log2(len(charcounts))))
        input_dict['text'] = {t: text_dim for t in text_cols}
    else:
        text_dim = 0
        input_dict['text'] = {}

    data = Dataseq(rawdata, charcounts, input_dict, unique_onehotvals, maxlens)
    data_idx = np.arange(len(data))
    np.random.shuffle(data_idx)
    n_folds = 6
    fold_size = len(data) / n_folds
    folds = [data_idx[int(i * fold_size):int((i + 1) * fold_size)] for i in range(n_folds)]

    # Each rotation trains on four folds, early-stops on one ('es') and
    # validates on the held-out sixth ('val').
    fold_groups = {}
    fold_groups[0] = {'train': [0, 1, 2, 4], 'es': [3], 'val': [5]}
    fold_groups[1] = {'train': [0, 2, 3, 5], 'es': [1], 'val': [4]}
    fold_groups[2] = {'train': [1, 3, 4, 5], 'es': [2], 'val': [0]}
    fold_groups[3] = {'train': [0, 2, 3, 4], 'es': [5], 'val': [1]}
    fold_groups[4] = {'train': [0, 1, 3, 5], 'es': [4], 'val': [2]}
    fold_groups[5] = {'train': [1, 2, 4, 5], 'es': [0], 'val': [3]}

    for fold in range(1):  # only the first of the six rotations is run here

        train_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['train']])))
        es_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['es']])))
        val_idx = np.array(folds[fold_groups[fold]['val'][0]])

        train = Subset(data, train_idx)
        es = Subset(data, es_idx)
        val = Subset(data, val_idx)

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_iter = torch.utils.data.DataLoader(train, batch_size=mb_size, shuffle=True, **kwargs)
        es_iter = torch.utils.data.DataLoader(es, batch_size=mb_size, shuffle=True, **kwargs)
        val_iter = torch.utils.data.DataLoader(val, batch_size=mb_size, shuffle=True, **kwargs)

        embeddings = {}
        reverse_embeddings = {}
        onehot_embedding_weights = {}
        for k in onehot_cols:
            dim = input_dict['onehot'][k]
            onehot_embedding_weights[k] = net.get_embedding_weight(len(unique_onehotvals[k]), dim)
            if use_cuda:
                onehot_embedding_weights[k] = onehot_embedding_weights[k].cuda()
            #embeddings[k] = nn.Embedding(len(unique_onehotvals[k]), dim, max_norm=1.0)
            embeddings[k] = nn.Embedding(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k], max_norm=1.0)
            reverse_embeddings[k] = net.EmbeddingToIndex(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k])

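        # All text columns share a single character-level embedding table and a
        # single EmbeddingToIndex that maps decoder output back to characters.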
        if text_dim > 0:
            text_embedding_weights = net.get_embedding_weight(len(charcounts) + 1, text_dim)
            if use_cuda:
                text_embedding_weights = text_embedding_weights.cuda()
            #text_embedding = nn.Embedding(len(charcounts)+1, text_dim, max_norm=1.0)
            text_embedding = nn.Embedding(len(charcounts) + 1, text_dim, _weight=text_embedding_weights, max_norm=1.0)
            text_embeddingtoindex = net.EmbeddingToIndex(len(charcounts) + 1, text_dim, _weight=text_embedding_weights)
            for k in text_cols:
                embeddings[k] = text_embedding
                reverse_embeddings[k] = text_embeddingtoindex

        enc = net.Encoder(input_dict, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)
        dec = net.Decoder(input_dict, maxlens, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)
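        # recurrent_hidden_size presumably sizes the recurrent units the encoder
        # and decoder use for the variable-length text columns (an assumption
        # about net's API, which is not shown here).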

        if use_cuda:
            embeddings = {k: embeddings[k].cuda() for k in embeddings.keys()}
            reverse_embeddings = {k: reverse_embeddings[k].cuda() for k in embeddings.keys()}
            enc.cuda()
            dec.cuda()


        #print(enc.parameters)
        #print(dec.parameters)


        contrastivec = contrastive.ContrastiveLoss(margin=margin)


        #solver = optim.RMSprop([p for em in embeddings.values() for p in em.parameters()] + [p for p in enc.parameters()] + [p for p in dec.parameters()], lr=lr)
        # Optimize the embedding tables, encoder and decoder jointly.
        solver = optim.Adam(
            [p for em in embeddings.values() for p in em.parameters()]
            + list(enc.parameters())
            + list(dec.parameters()),
            lr=lr)

        # A fixed probe batch from the early-stopping split, used for periodic logging.
        Tsample = next(iter(es_iter))
        if use_cuda:
            Tsample = {col: Variable(tt[0:128]).cuda() for col, tt in Tsample.items()}
        else:
            Tsample = {col: Variable(tt[0:128]) for col, tt in Tsample.items()}

        print({col: tt[0] for col, tt in Tsample.items()})

        print('starting training')
        loss = 0.0
        for it in range(1000000):
            # enumerate() restarts the DataLoader on every call, so this always
            # draws one fresh shuffled batch per step.
            T = next(iter(train_iter))
            if use_cuda:
                T = {col: Variable(tt).cuda() for col, tt in T.items()}
            else:
                T = {col: Variable(tt) for col, tt in T.items()}

            # Embed discrete/text columns; continuous columns pass through as floats.
            X = {}
            for col, tt in T.items():
                if col in embeddings.keys():
                    X[col] = embeddings[col](tt)
                else:
                    X[col] = tt.float()

            mu = enc(X)
            X2 = dec(mu)

            T2 = {}
            X2d = {col: (1.0 * tt).detach() for col, tt in X2.items()}

            for col, embedding in embeddings.items():
                # Snap each decoded vector onto its nearest embedding: T2 holds the
                # recovered indices, X2 becomes half-snapped but keeps gradients,
                # and X2d is the hard, fully detached reconstruction.
                T2[col] = reverse_embeddings[col](X2[col])
                X2[col] = 0.5 * X2[col] + 0.5 * embeddings[col](T2[col])
                X2d[col] = embeddings[col](T2[col].detach())

            '''
            X2d = {col: (1.0*tt).detach() for col, tt in X2.items()}
            T2 = discretize(X2d, embeddings, maxlens)
            for col, embedding in embeddings.items():
                X2d[col] = embeddings[col](T2[col].detach())
            '''
            '''
            T2 = discretize(X2, embeddings, maxlens)
            X2d = {col: (1.0*tt).detach() for col, tt in X2.items()}

            for col, embedding in embeddings.items():
                X2[col] = embeddings[col](T2[col]) #+0.05 X2[col]
                X2d[col] = embeddings[col](T2[col].detach())
            '''


            # Re-encode the soft and hard reconstructions; flatten all encodings.
            mu2 = enc(X2).view(mb_size, -1)
            mu2d = enc(X2d).view(mb_size, -1)
            mu = mu.view(mb_size, -1)

            # Consecutive batch items form pairs; are_same marks identical rows.
            are_same = are_equal({col: x[::2] for col, x in T.items()}, {col: x[1::2] for col, x in T.items()})
            # Build the targets on the right device: the original called .cuda()
            # unconditionally, which breaks CPU-only runs.
            ones = torch.ones(mb_size)
            zeros = torch.zeros(mb_size)
            if use_cuda:
                ones, zeros = ones.cuda(), zeros.cuda()
            enc_loss = contrastivec(mu[::2], mu[1::2], are_same)
            enc_loss += 1.0 * contrastivec(mu, mu2, ones)    # input vs. soft reconstruction: similar
            enc_loss += 2.0 * contrastivec(mu, mu2d, zeros)  # input vs. hard reconstruction: dissimilar

            '''
            adotb = torch.matmul(mu, mu.permute(1, 0))  # batch_size x batch_size
            adota = torch.matmul(mu.view(-1, 1, latent_dim), mu.view(-1, latent_dim, 1))  # batch_size x 1 x 1
            diffsquares = (adota.view(-1, 1).repeat(1, mb_size) + adota.view(1, -1).repeat(mb_size, 1) - 2 * adotb) / latent_dim

            # did I f**k up something here? diffsquares can apparently be less than 0....
            mdist = torch.sqrt(torch.clamp(torch.triu(diffsquares, diagonal=1),  min=0.0))
            mdist = torch.clamp(margin - mdist, min=0.0)
            number_of_pairs = mb_size * (mb_size - 1) / 2

            enc_loss = 0.5 * torch.sum(torch.triu(torch.pow(mdist, 2), diagonal=1)) / number_of_pairs

            target = torch.ones(mu.size(0), 1)
            if use_cuda:
                target.cuda()
            enc_loss += contrastivec(mu, mu2, target.cuda())

            target = torch.zeros(mu.size(0), 1)
            if use_cuda:
                target.cuda()
            enc_loss += 2.0 * contrastivec(mu, mu2d, target.cuda())
            '''


            enc_loss.backward()
            solver.step()

            # Zero all gradients: encoder, decoder and every embedding table.
            enc.zero_grad()
            dec.zero_grad()
            for col in embeddings.keys():
                embeddings[col].zero_grad()

            loss += enc_loss.data.cpu().numpy()
            veclen = torch.mean(torch.pow(mu, 2))
            if it % epoch_len == 0:
                print(it, loss / epoch_len, veclen.data.cpu().numpy())

                Xsample = {}
                for col, tt in Tsample.items():
                    if col in embeddings.keys():
                        Xsample[col] = embeddings[col](tt)
                    else:
                        Xsample[col] = tt.float()

                mu = enc(Xsample)
                X2sample = dec(mu)
                X2sampled = {col: tt.detach() for col, tt in X2sample.items()}
                T2sample = discretize(X2sample, embeddings, maxlens)
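                # discretize presumably snaps decoder output back to discrete values
                # via the nearest embeddings; see the sketch after this example.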

                mu2 = enc(X2sample)
                mu2d = enc(X2sampled)


                if 'Fare' in continuous_cols and 'Age' in continuous_cols:
                    print([np.mean(np.abs(Xsample[col].data.cpu().numpy()-X2sample[col].data.cpu().numpy())) for col in ['Fare', 'Age']])

                print({col: tt[0:2].data.cpu().numpy() for col, tt in T2sample.items()})

                if 'Survived' in onehot_cols:
                    survived_pred = T2sample['Survived'].data.cpu().numpy()
                    survived_true = Tsample['Survived'].data.cpu().numpy()
                    print('% survived correct: ',
                          np.mean(survived_pred == survived_true),
                          np.mean(survived_true == np.ones_like(survived_true)))

                if 'Cabin' in text_cols:
                    print(embeddings['Cabin'].weight[data.charindex['1']])



                are_same = are_equal({col: x[::2] for col, x in Tsample.items()}, {col: x[1::2] for col, x in Tsample.items()})
                # Early-stopping diagnostic: distance between the probe batch and
                # its soft reconstruction (device-safe target, as above).
                ones = torch.ones(mu.size(0))
                if use_cuda:
                    ones = ones.cuda()
                es_loss = 1.0 * contrastivec(mu, mu2, ones)
                print('es loss ', es_loss)

                loss = 0.0
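
Neither `net.EmbeddingToIndex` nor `discretize` is shown above; both appear to map decoder outputs in embedding space back to discrete indices. A minimal nearest-neighbour sketch under that assumption (the constructor signature is copied from the calls above; everything else is a guess):

import torch
import torch.nn as nn

class EmbeddingToIndexSketch(nn.Module):
    # Hypothetical stand-in for net.EmbeddingToIndex: maps each vector in
    # embedding space to the index of its nearest embedding row.
    def __init__(self, num_embeddings, dim, _weight=None):
        super().__init__()
        self.weight = _weight if _weight is not None else torch.randn(num_embeddings, dim)

    def forward(self, x):
        flat = x.reshape(-1, x.size(-1))  # (N, dim)
        # Squared Euclidean distance to every embedding row: (N, num_embeddings).
        d = (flat.unsqueeze(1) - self.weight.unsqueeze(0)).pow(2).sum(-1)
        return d.argmin(dim=1).reshape(x.shape[:-1])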
Example #3
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(train,
                                           batch_size=mb_size,
                                           shuffle=True,
                                           **kwargs)
# As in Example #1, the validation loader reuses the `test` dataset.
val_loader = torch.utils.data.DataLoader(test,
                                         batch_size=mb_size,
                                         shuffle=True,
                                         **kwargs)
test_loader = torch.utils.data.DataLoader(test,
                                          batch_size=mb_size,
                                          shuffle=False,
                                          **kwargs)

contrastiveloss = contrastive.ContrastiveLoss()
#KLloss = contrastive.KL()

#enc = net.VariationalEncoder(dim=z_dim)
enc = net.Encoder(dim=z_dim)  # deterministic encoder, unlike Example #1
#dec = net.Decoder(output_dim=(28, 28))
dec = net.Decoder(dim=z_dim)

if use_cuda:
    enc.cuda()
    dec.cuda()


def reset_grad():
    enc.zero_grad()
    dec.zero_grad()
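
For completeness, a minimal sketch of how Example #3's pieces would typically be driven; the loop body is an assumption (the example stops before the training loop), mirroring the input-vs-reconstruction pairing of Example #2:

import torch
import torch.optim as optim

solver = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-3)

for X in train_loader:  # assumed: each batch is a single tensor
    if use_cuda:
        X = X.cuda()
    mu = enc(X)
    mu2 = enc(dec(mu))
    # An input and its reconstruction should encode to nearby points (target 1).
    target = torch.ones(X.size(0))
    if use_cuda:
        target = target.cuda()
    loss = contrastiveloss(mu, mu2, target)
    loss.backward()
    solver.step()
    reset_grad()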