Example #1
    def get_loss():

        x = sample_true(batch_size).cuda()  #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits / 100., dim=1)
        cat = RelaxedOneHotCategorical(probs=probs,
                                       temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(
            x,
            component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        f = logpxz  # - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        surr_pred = surrogate.net(surr_input)

        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        return surr_loss
Example #2
def sample_simplax(probs):
    dist = RelaxedOneHotCategorical(probs=probs,
                                    temperature=torch.Tensor([1.]))
    z = dist.rsample()
    logprob = dist.log_prob(z)
    b = torch.argmax(z, dim=1)
    return z, b, logprob
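These snippets repeatedly call an H(...) helper that is never shown. Judging from how its result is used (indexing mixtureweights, converting to a NumPy scalar, filling a one-hot vector), it maps a relaxed sample to the index of its largest coordinate; a minimal sketch under that assumption:

import torch

def H(soft_sample):
    # Hard assignment: index of the largest coordinate of the relaxed
    # one-hot sample, row-wise. Shape [B, C] -> [B] (LongTensor).
    return torch.argmax(soft_sample, dim=1)

This mirrors the b = torch.argmax(z, dim=1) line in sample_simplax above.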
Example #3
def show_surr_preds():

    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10 + cols, 4 + rows),
                     facecolor='white')  #, dpi=150)

    for i in range(rows):

        x = sample_true(1).cuda()  #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits / 100., dim=1)
        cat = RelaxedOneHotCategorical(probs=probs,
                                       temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        z = cluster_S

        n_evals = 40
        x1 = np.linspace(-9, 205, n_evals)
        x = torch.from_numpy(x1).view(n_evals, 1).float().cuda()
        z = z.repeat(n_evals, 1)
        cluster_H = cluster_H.repeat(n_evals, 1)
        xz = torch.cat([z, x], dim=1)

        logpxz = logprob_undercomponent(
            x,
            component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        f = logpxz  #- logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col = 0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows, cols), (row, col),
                              frameon=False,
                              colspan=1,
                              rowspan=1)

        ax.plot(x1, surr_pred, label='Surr')
        ax.plot(x1, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
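show_surr_preds and the other snippets also rely on logprob_undercomponent, along with module-level objects such as encoder, surrogate, needsoftmax_mixtureweight and exp_dir, none of which are shown here. For a 1-D Gaussian mixture with 20 components (which the [B, 21] surrogate input and the linspace(-9, 205) evaluation grid suggest), a rough stand-in could look like the sketch below; the component means and scales are placeholder assumptions, not the project's values.

import torch
from torch.distributions import Normal

n_components = 20   # assumed: cluster_S has 20 dims and x is 1-D, giving [B, 21] after concat

def logprob_undercomponent(x, component, needsoftmax_mixtureweight=None, cuda=False):
    # log p(x | z = component) for a 1-D Gaussian mixture.
    # Placeholder parameters: component c is centred at 10*c with scale 5.
    means = torch.arange(n_components).float() * 10.
    stds = torch.ones(n_components) * 5.
    if cuda:
        means, stds = means.cuda(), stds.cuda()
    dist = Normal(means[component].view(-1, 1), stds[component].view(-1, 1))
    return dist.log_prob(x)   # [B, 1]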
Example #4
def simplax(surrogate, x, logits, mixtureweights, k=1):

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B,1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += - torch.mean((f.detach() - surr_pred.detach()) * logq  + surr_pred)


        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq =  torch.mean( torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)

        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr =  torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss += torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))
   
    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif   
    outputs['grad_path'] = grad_path   
    outputs['grad_score'] = grad_score   

    return outputs #net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score
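A sketch of how the two returned losses might drive training, assuming an encoder module that produces the logits, one Adam optimizer per network, and the sample_true / mixtureweights / batch_size objects from the surrounding project (the names and hyperparameters here are assumptions, not the project's actual settings):

optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-4)
optim_surr = torch.optim.Adam(surrogate.parameters(), lr=1e-4)

for step in range(10000):   # step count is arbitrary here
    x = sample_true(batch_size).cuda()

    # Encoder update with the SIMPLAX estimator.
    logits = encoder.net(x)
    outputs = simplax(surrogate, x, logits, mixtureweights, k=1)
    optim_net.zero_grad()
    outputs['net_loss'].backward()
    optim_net.step()

    # Surrogate update with the squared-gradient objective, on a fresh graph
    # so the two backward passes stay independent.
    logits = encoder.net(x)
    outputs = simplax(surrogate, x, logits, mixtureweights, k=1)
    optim_surr.zero_grad()
    outputs['surr_loss'].backward()
    optim_surr.step()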
Example #5
def simplax(surrogate, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += -torch.mean((f.detach() - surr_pred.detach()) * logq +
                                surr_pred)

        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq =  torch.mean( torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        surr_loss += torch.mean(
            ((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_score = torch.autograd.grad(
            [torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits],
            create_graph=True,
            retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score
Example #6
def show_surr_preds():

    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    for i in range(rows):

        x = sample_true(1).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        z = cluster_S

        n_evals = 40
        x1 = np.linspace(-9,205, n_evals)
        x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
        z = z.repeat(n_evals,1)
        cluster_H = cluster_H.repeat(n_evals,1)
        xz = torch.cat([z,x], dim=1) 

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col =0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        ax.plot(x1,surr_pred, label='Surr')
        ax.plot(x1,f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()


    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir+'gmm_surr.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()
Example #7
def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.
        net_loss += - torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss/ k

    return net_loss, f, logpx_given_z, logpz, logq
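A single update step with this estimator might look like the following sketch; the optimizer settings and the encoder, sample_true and batch_size objects are assumed from the surrounding project.

optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)

x = sample_true(batch_size).cuda()
logits = encoder.net(x)
net_loss, f, logpx_given_z, logpz, logq = reinforce_pz(x, logits, mixtureweights, k=1)

optim.zero_grad()
net_loss.backward()   # score-function (REINFORCE) gradient w.r.t. the encoder logits
optim.step()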
Example #8
def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f = logpxz - logq - 1.
        net_loss += -torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq
Example #9
    def get_loss():

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)
        
        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(f.detach()-surr_pred))

        return surr_loss
Example #10
        logits = encoder.net(x)
        probs = torch.softmax(logits/100., dim=1)
        # print (probs)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster
            # print (f)

            surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)
            
            # print (f.shape)
            # print (surr_pred.shape)
            # print (logprob_cluster.shape)
            surr_loss_1 = torch.mean(torch.abs(f.detach()-surr_pred))
            # if surr_loss_1 > 1.:
Example #11
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))

val = 1.
val2 = 0
val3 = 0
cmap = 'Blues'
alpha = 1.
xlimits = [val3, val]
ylimits = [val2, val]
numticks = 51
x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)

logprob = to_print2(logprob)
logprob = logprob.reshape(X.shape)
prob = np.exp(logprob)
cs = plt.contourf(X, Y, prob, cmap=cmap, alpha=alpha)

# ax.text(0,0,str(to_print2(weights)))

samps0 = []
samps1 = []
samps = []
for i in range(300):
    samp = cat.sample()
    samp = to_print2(samp)
    if samp[0] > samp[1]:
Example #12
train_ = 1
n_steps = 1000  #0 #0 #1000 #50000 #
B = 1  #32 #0
k = 3
if train_:
    optim = torch.optim.Adam(surrogate.parameters(),
                             lr=1e-4,
                             weight_decay=1e-7)
    #Train surrogate
    for i in range(n_steps + 1):
        warmup = 1.

        cat = RelaxedOneHotCategorical(logits=logits.repeat(B, 1),
                                       temperature=torch.tensor([1.]))
        z = cat.rsample()
        logprob = cat.log_prob(z.detach()).view(B, 1)
        b = H(z)
        reward = f(b).view(B, 1)
        cz = surrogate.net(z)

        # estimator = (reward - cz) * logprob + cz
        # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0]

        gradlogprob = torch.autograd.grad([torch.mean(logprob)], [logits],
                                          create_graph=True,
                                          retain_graph=True)[0]
        gradcz = torch.autograd.grad([torch.mean(cz)], [logits],
                                     create_graph=True,
                                     retain_graph=True)[0]

        # print (reward.shape, cz.shape, gradlogprob.shape, gradcz.shape)
Example #13

            cluster_H = H(cluster_S)
            
            comp = cluster_H.cpu().numpy()[0]
            # print (comp)
            counts[comp] +=1
            # print (x, 'samp', cluster_H)
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.

            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster
            # print (jj%20, f, logprob_cluster)
            print ()
            print ('H:', comp, 'f', f, )
            print ('logq', logprob_cluster)

            surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)
            
            # print (f.shape)
            # print (surr_pred.shape)
            # print (logprob_cluster.shape)
Example #14
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))

val = 1.
val2 = 0
val3 = 0
cmap='Blues'
alpha =1.
xlimits=[val3, val]
ylimits=[val2, val]
numticks = 51
x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)

logprob = to_print2(logprob)
logprob = logprob.reshape(X.shape)
prob = np.exp(logprob)
cs = plt.contourf(X, Y, prob, cmap=cmap, alpha=alpha)

# ax.text(0,0,str(to_print2(weights)))

samps0 = []
samps1 = []
samps = []
for i in range (300):
    samp = cat.sample()
    samp =to_print2(samp)
    if samp[0]>samp[1]:
Example #15
surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2)

train_ = 1
n_steps = 1000#0 #0 #1000 #50000 #
B = 1 #32 #0
k=3
if train_:
    optim = torch.optim.Adam(surrogate.parameters(), lr=1e-4, weight_decay=1e-7)
    #Train surrogate
    for i in range(n_steps+1):
        warmup = 1.

        cat = RelaxedOneHotCategorical(logits=logits.repeat(B,1), temperature=torch.tensor([1.]))
        z = cat.rsample()
        logprob = cat.log_prob(z.detach()).view(B,1)
        b = H(z)
        reward = f(b).view(B,1) 
        cz = surrogate.net(z)

        # estimator = (reward - cz) * logprob + cz
        # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0]

        gradlogprob =  torch.autograd.grad([torch.mean(logprob)], [logits], create_graph=True, retain_graph=True)[0]
        gradcz =  torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]

        # print (reward.shape, cz.shape, gradlogprob.shape, gradcz.shape)
        # grad = (reward-cz) *gradlogprob + gradcz

        loss = torch.mean(((reward.view(B, 1) - cz) * gradlogprob.repeat(B, 1) + gradcz.repeat(B, 1))**2)
Example #16
def sample_simplax(probs):
    dist = RelaxedOneHotCategorical(probs=probs, temperature=torch.Tensor([1.]))
    z = dist.rsample()
    logprob = dist.log_prob(z)
    b = torch.argmax(z, dim=1)
    return z, b, logprob
Example #17
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)


        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += - torch.mean(     alpha.detach()*(f_z.detach()  - surr_pred.detach()) * logq_z  
                                    + alpha.detach()*surr_pred 
                                    + (1-alpha.detach())*(f_b.detach()  ) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b =  torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        surr_loss += torch.mean(
                                    (alpha*(f_z.detach() - surr_pred) * grad_logq_z 
                                    + alpha*grad_surr
                                    + (1-alpha)*(f_b.detach()) * grad_logq_b )**2
                                    )

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))


    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
Example #18
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

# cat = Categorical(probs= weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):

    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S) #

    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(needsoftmax_mixtureweight), retain_graph=False)[0]

    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    f_val = f(one_hot)

    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad

    grad = f_val * logprobgrad
    momem += grad*.1
    needsoftmax_mixtureweight = needsoftmax_mixtureweight+ momem*lr
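The loop above (and the near-identical loop that closes this section) scores the sampled category with an f(one_hot) objective that is not defined in these excerpts. Purely for illustration, a toy reward with the same interface could be the following; the category count and target values are made up.

import torch

n_cats = 3                                # assumed
target = torch.tensor([0.1, 0.3, 0.6])    # assumed target the mixture weights should move toward

def f(one_hot):
    # Toy reward: its expectation under p(b) is sum_c p_c * target_c,
    # which is maximized by putting all probability on the last category.
    return torch.dot(one_hot, target)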
Example #19
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B, 1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += -torch.mean(alpha.detach() *
                                (f_z.detach() - surr_pred.detach()) * logq_z +
                                alpha.detach() * surr_pred +
                                (1 - alpha.detach()) * (f_b.detach()) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        grad_logq_z = torch.mean(torch.autograd.grad([torch.mean(logq_z)],
                                                     [logits],
                                                     create_graph=True,
                                                     retain_graph=True)[0],
                                 dim=1,
                                 keepdim=True)
        grad_logq_b = torch.mean(torch.autograd.grad([torch.mean(logq_b)],
                                                     [logits],
                                                     create_graph=True,
                                                     retain_graph=True)[0],
                                 dim=1,
                                 keepdim=True)
        grad_surr = torch.mean(torch.autograd.grad([torch.mean(surr_pred)],
                                                   [logits],
                                                   create_graph=True,
                                                   retain_graph=True)[0],
                               dim=1,
                               keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        surr_loss += torch.mean(
            (alpha * (f_z.detach() - surr_pred) * grad_logq_z +
             alpha * grad_surr + (1 - alpha) *
             (f_b.detach()) * grad_logq_b)**2)

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_score = torch.autograd.grad(
            [torch.mean(
                (f_z.detach() - surr_pred.detach()) * logq_z)], [logits],
            create_graph=True,
            retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(
        alpha)
Example #20
        probs = torch.softmax(logits / 100., dim=1)
        # print (probs)
        cat = RelaxedOneHotCategorical(probs=probs,
                                       temperature=torch.tensor([temp]).cuda())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(
                batch_size, 1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(
                x,
                component=cluster_H,
                needsoftmax_mixtureweight=needsoftmax_mixtureweight,
                cuda=True)
            f = logpxz  #- logprob_cluster
            # print (f)

            surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
            surr_pred = surrogate.net(surr_input)

            # print (f.shape)
            # print (surr_pred.shape)
Example #21
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

# cat = Categorical(probs= weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):

    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights,
                                   temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)  #

    logprobgrad = torch.autograd.grad(outputs=logprob,
                                      inputs=(needsoftmax_mixtureweight),
                                      retain_graph=False)[0]

    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    f_val = f(one_hot)

    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad

    grad = f_val * logprobgrad
    momem += grad * .1
    needsoftmax_mixtureweight = needsoftmax_mixtureweight + momem * lr