Example #1
File: solver.py Project: heewookl/TCR
    def reparameterize(self, p_pep, p_tcr, tau, k, num_sample):

        # sampling
        batch_size = p_pep.size(0)
        len_pep = p_pep.size(1) # p_pep: [batch_size, len_pep]
        len_tcr = p_tcr.size(1) # p_tcr: [batch_size, len_tcr]
        p_pep_ = p_pep.view(batch_size, 1, 1, len_pep).expand(batch_size, num_sample, k, len_pep)
        p_tcr_ = p_tcr.view(batch_size, 1, 1, len_tcr).expand(batch_size, num_sample, k, len_tcr)
        C_pep = RelaxedOneHotCategorical(tau, p_pep_)
        C_tcr = RelaxedOneHotCategorical(tau, p_tcr_)
        Z_pep, _ = torch.max(C_pep.sample(), -2) # batch_size, num_sample, len_pep
        Z_tcr, _ = torch.max(C_tcr.sample(), -2) # batch_size, num_sample, len_tcr
        
        # without sampling
        _, Z_fixed_pep = p_pep.topk(k, dim = -1) # batch_size, k
        _, Z_fixed_tcr = p_tcr.topk(k, dim = -1) # batch_size, k
        size_pep = p_pep.size()
        size_tcr = p_tcr.size()
        Z_fixed_pep = idxtobool(Z_fixed_pep, size_pep, self.cuda)
        Z_fixed_tcr = idxtobool(Z_fixed_tcr, size_tcr, self.cuda)

        return Z_pep, Z_tcr, Z_fixed_pep, Z_fixed_tcr
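The elementwise max over the k dimension turns k independent relaxed one-hot draws into a differentiable, approximately k-hot subset mask. A standalone sketch of just that trick (shapes and temperature are arbitrary choices, not values from the project):

import torch
from torch.distributions import RelaxedOneHotCategorical

batch_size, seq_len, k, num_sample = 2, 6, 3, 4
tau = torch.tensor(0.5)                      # relaxation temperature
p = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)
p_ = p.view(batch_size, 1, 1, seq_len).expand(batch_size, num_sample, k, seq_len)
# k relaxed one-hot draws per sample; max over the k axis gives a soft k-hot mask
Z, _ = torch.max(RelaxedOneHotCategorical(tau, p_).sample(), dim=-2)
print(Z.shape)  # torch.Size([2, 4, 6]): batch_size, num_sample, seq_len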
Example #2
File: explainer.py Project: tom-1221/VIBI
    def reparameterize(self, p_i, tau, k, num_sample=1):

        ## sampling
        p_i_ = p_i.view(p_i.size(0), 1, 1, -1)
        p_i_ = p_i_.expand(p_i_.size(0), num_sample, k, p_i_.size(-1))
        C_dist = RelaxedOneHotCategorical(tau, p_i_)
        V = torch.max(C_dist.sample(), -2)[0]  # [batch_size, num_sample, d]

        ## without sampling
        V_fixed_size = p_i.unsqueeze(1).size()
        _, V_fixed_idx = p_i.unsqueeze(1).topk(k, dim=-1)  # batch * 1 * k
        V_fixed = idxtobool(V_fixed_idx, V_fixed_size, is_cuda=self.args.cuda)
        V_fixed = V_fixed.type(torch.float)

        return V, V_fixed
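Both examples rely on idxtobool, a project helper not shown on this page. Judging from how it is called (an index tensor plus a target size in, a binary mask out), a plausible minimal reconstruction is:

import torch

def idxtobool(idx, size, is_cuda=False):
    # Scatter 1s at the selected indices to build a {0,1} mask of shape `size`.
    # `is_cuda` is kept for signature compatibility; the device follows `idx`.
    mask = torch.zeros(size, device=idx.device)
    return mask.scatter(-1, idx, 1.0)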
Example #3
    def forward(self, inputs, lengths, temp=None, y=None):
        enc_emb = self.lookup(inputs)
        dec_emb = self.lookup(inputs)
        hn = self.encoder(enc_emb, lengths)
        py = self.classifier(hn)
        if y is None:
            # No label given: draw one from the classifier's relaxed distribution.
            dist = RelaxedOneHotCategorical(temp, logits=py)
            y = dist.sample().max(1)[1]
        y_emb = self.y_lookup(y)
        h = torch.cat([hn, y_emb.unsqueeze(0)], dim=2)
        mu, logvar = self.fcmu(h), self.fclogvar(h)
        if self.training:
            z = self.reparameterize(mu, logvar)
        else:
            z = mu
        code = torch.cat([z, y_emb.unsqueeze(0)], dim=2)
        outputs, _ = self.decoder(dec_emb, code, lengths=lengths)
        outputs = self.fcout(outputs)
        bow = self.bow_predictor(code)
        return outputs, mu, logvar, bow, py
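When no label y is supplied, the forward pass draws a relaxed one-hot sample from the classifier logits and discretizes it with argmax; at a low temperature this behaves like a hard categorical draw. A standalone sketch of that step (shapes are arbitrary):

import torch
from torch.distributions import RelaxedOneHotCategorical

logits = torch.randn(8, 5)                    # [batch, n_classes]
dist = RelaxedOneHotCategorical(torch.tensor(0.1), logits=logits)
y = dist.sample().max(1)[1]                   # [batch] integer class labels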
Example #4
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)



# cat = Categorical(probs=weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momentum = 0
for step in range(max_steps):

    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    cluster_S = cat.sample()                 # relaxed (soft) sample
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)                 # hard assignment of the sample

    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=needsoftmax_mixtureweight, retain_graph=False)[0]

    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    f_val = f(one_hot)

    # Score-function (REINFORCE) gradient estimate: f(sample) * grad log p(sample)
    grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad
    momentum += grad * .1
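This fragment references H and f, plus n_cats and max_steps, defined elsewhere in the original script. A minimal sketch consistent with how they are used here (H as a hard argmax, f as a stand-in blackbox objective; all values are hypothetical):

import torch

n_cats, max_steps = 3, 1000    # hypothetical values

def H(z):
    # Hard assignment: index of the largest coordinate of the relaxed sample.
    return torch.argmax(z, dim=-1)

def f(one_hot):
    # Stand-in blackbox objective; any scalar function of the one-hot works.
    reward = torch.arange(n_cats, dtype=torch.float)
    return torch.sum(one_hot * reward)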
Example #5
    L2_losses = []
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            cat = Categorical(probs=torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
            f = pxz - logprob_cluster                  # per-sample objective
            # Score-function surrogate: gradient flows only through logprob_cluster.
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        loss = loss / batch_size
        net_loss = net_loss / batch_size

        loss.backward(retain_graph=True)
        optim.step()
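The net_loss term is the standard score-function surrogate: detaching f blocks the pathwise gradient, so backpropagating -f.detach() * logprob_cluster produces the REINFORCE estimate f * grad log p(cluster). A tiny self-contained check of that identity:

import torch
from torch.distributions import Categorical

theta = torch.zeros(4, requires_grad=True)
cat = Categorical(logits=theta)
a = cat.sample()
f_val = (a.float() - 1.5) ** 2        # reward with no gradient path to theta
surrogate = f_val.detach() * cat.log_prob(a)
surrogate.backward()
print(theta.grad)                     # equals f_val * grad_theta log p(a)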
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
# Evaluate the relaxed categorical's density over the grid and plot it.
logprob = cat.log_prob(aaa)
logprob = to_print2(logprob)
logprob = logprob.reshape(X.shape)
prob = np.exp(logprob)
cs = plt.contourf(X, Y, prob, cmap=cmap, alpha=alpha)

# Estimate how often each of the two modes is sampled.
samps0 = []
samps1 = []
samps = []
for i in range(300):
    samp = to_print2(cat.sample())
    if samp[0] > samp[1]:
        samps0.append(samp)
    else:
        samps1.append(samp)
    samps.append(samp)
samps0 = np.array(samps0)
samps1 = np.array(samps1)
samps = np.array(samps)

print(len(samps0) / len(samps))
print(len(samps1) / len(samps))
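The plotting snippet assumes X, Y, cmap, alpha, and to_print2 from earlier in the original script; since cat.log_prob is evaluated at 2-D points, the visualization implies n_cats = 2. A hypothetical minimal setup would be:

import numpy as np

xs = np.linspace(0.01, 0.99, 100)
X, Y = np.meshgrid(xs, xs)        # 2-D evaluation grid (hypothetical range)
cmap, alpha = 'Blues', 0.5        # plotting style (hypothetical)

def to_print2(t):
    # Detach a tensor and convert it to NumPy for plotting.
    return t.detach().cpu().numpy()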