def sample_relax_given_class(logits, samp):
    """Draw two conditional Gumbel-style perturbations that agree on class ``samp``.

    Uses the free names ``B`` and ``C`` (presumably the module-level batch size
    and number of classes) for the noise shape, and ``Categorical``.

    Args:
        logits: unnormalised class scores; softmaxed along dim 1 and broadcast
            against ``(B, C)`` uniform noise.
        samp: class index per example, reshaped to ``(B, 1)``.

    Returns:
        z: ``(B, C)``; column ``samp`` is ``-log(-log(u_b))`` and the other
            columns are ``-log(-log(u) / p - log(u_b))`` for uniform ``u``.
        z_tilde: same construction with fresh noise for the other columns but
            the same ``u_b`` for column ``samp``.
        logprob: ``(B, 1)`` log-probability of ``samp`` under
            ``Categorical(logits=logits)``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    u_b = torch.gather(input=u, dim=1, index=index)
    z_tilde_b = -torch.log(-torch.log(u_b))

    # First sample: the non-selected columns reuse the same uniform draw u.
    z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    z.scatter_(dim=1, index=index, src=z_tilde_b)

    # Second sample: fresh noise for the other columns, shared u_b for column b.
    u_fresh = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    z_tilde = -torch.log((-torch.log(u_fresh) / probs) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, z_tilde, logprob
def sample_relax(logits):  # , k=1):
    """Draw a Gumbel-max class sample plus an independent conditional perturbation.

    Uses the free names ``B`` and ``C`` (presumably the module-level batch size
    and number of classes) for the noise shape, and ``Categorical``.

    Args:
        logits: unnormalised class scores of shape ``(1, C)`` or ``(B, C)``.
            The softmax is broadcast against the ``(B, C)`` noise. The previous
            ``.repeat(B, 1)`` tiling only worked for a single row.

    Returns:
        z: ``(B, C)`` logits plus Gumbel noise.
        b: ``(B,)`` ``argmax(z, dim=1)``.
        logprob: ``(B, 1)`` log-probability of ``b`` under ``Categorical(logits)``.
        z_tilde: ``(B, C)`` with column ``b`` equal to ``-log(-log(v_k))`` and the
            other columns ``-log(-log(v) / p - log(v_k))``, where ``v_k`` is
            ``(B, 1)`` and ``v`` is ``(B, C)`` fresh uniform noise.
    """
    eps = 1e-12
    u = torch.rand(B, C).clamp(eps, 1. - eps)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)
    index = b.view(B, 1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B, 1)

    # Fresh uniform for the selected column. NOTE: reusing u (or z) at column b
    # instead was observed to look biased in earlier experiments.
    v_k = torch.rand(B, 1).clamp(eps, 1. - eps)
    z_tilde_b = -torch.log(-torch.log(v_k))

    v = torch.rand(B, C).clamp(eps, 1. - eps)
    # (1, C) or (B, C) probabilities broadcast against (B, C) noise.
    probs = torch.softmax(logits, dim=1)
    z_tilde = -torch.log((-torch.log(v) / probs) - torch.log(v_k))
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, b, logprob, z_tilde
Esempio n. 3
0
 def updateOutput(self, input):
     """Return ``softmax(-input)`` along ``self._get_dim(input)`` and cache it.

     The negated input is written into the reusable buffer ``self.mininput``,
     which is lazily created with ``input.new()``. The result is stored in
     ``self.output`` and returned.
     """
     buf = self.mininput
     if buf is None:
         buf = input.new()
         self.mininput = buf
     buf.resize_as_(input)
     buf.copy_(input)
     buf.mul_(-1)
     dim = self._get_dim(input)
     self.output = torch.softmax(buf, dim)
     return self.output
Esempio n. 4
0
def simplax(surrogate, x, logits, mixtureweights, k=1):
    """SIMPLAX encoder loss and surrogate loss for the GMM cluster posterior.

    For each of ``k`` relaxed-categorical samples (temperature 1), scores
    ``f = log p(x | z) + log p(z) - log q(z) - 1`` and builds:

    * ``net_loss``: ``-mean((f - c) * log q + c)`` with ``c`` the surrogate
      prediction (detached in the score term);
    * ``surr_loss``: mean squared gradient estimate w.r.t. ``logits``.

    Both are averaged over the ``k`` samples.

    Args:
        surrogate: object whose ``.net`` maps ``cat([cluster_S, x, logits], 1)``
            to a prediction broadcastable with ``f`` (``(B, 1)`` presumably).
        x: data batch, concatenated with the relaxed sample along dim 1.
        logits: ``(B, C)`` encoder outputs. Gradients are taken w.r.t. this
            tensor, so it must be part of an autograd graph.
        mixtureweights: mixture weights indexed by the hard cluster.
        k: number of samples averaged into ``net_loss`` and ``surr_loss``.

    Returns:
        dict with ``net_loss`` and ``surr_loss`` (averaged over ``k``), plus
        ``f``, ``logpx_given_z``, ``logpz``, ``logq``, ``surr_dif``,
        ``grad_path`` and ``grad_score`` taken from the last sample.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    # Keep the temperature on the logits' device (was hard-coded .cuda()).
    temperature = torch.tensor([1.], device=logits.device)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=temperature)

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for _ in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1)  # [B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += - torch.mean((f.detach() - surr_pred.detach()) * logq + surr_pred)

        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        # Accumulate across samples. Previously this was overwritten each
        # iteration, so the division by k below was wrong for k > 1.
        surr_loss += torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr) ** 2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))

        # The path-wise gradient is exactly grad_surr, so reuse it.
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_surr))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif
    outputs['grad_path'] = grad_path
    outputs['grad_score'] = grad_score

    return outputs
Esempio n. 5
0
    def sample_relax_given_b(logits, b):
        """Sample a perturbation whose column ``b`` is a fresh Gumbel draw.

        Uses the free names ``B`` and ``C`` (batch size and number of classes,
        presumably from the enclosing scope). Noise is drawn on the CPU and
        then moved to CUDA.

        Returns:
            ``(B, C)`` tensor with column ``b`` equal to ``-log(-log(v_b))`` and
            the other columns ``-log(-log(v) / softmax(logits) - log(v_b))``.
        """
        lo, hi = 1e-10, 1. - 1e-10

        v_b = torch.rand(B, 1).clamp(lo, hi).cuda()
        top = -torch.log(-torch.log(v_b))

        v = torch.rand(B, C).clamp(lo, hi).cuda()
        p = torch.softmax(logits, dim=1)
        perturbed = -torch.log(-torch.log(v) / p - torch.log(v_b))
        perturbed.scatter_(dim=1, index=b.view(B, 1), src=top)
        return perturbed
Esempio n. 6
0
def logprob_givenmixtureeweights(x, needsoftmax_mixtureweight, num_components=None):
    """Log-density of ``x`` under the 1-D Gaussian mixture.

    Component ``c`` is ``Normal(10 * c, 5)``, weighted by
    ``softmax(needsoftmax_mixtureweight, dim=0)[c]``.

    Args:
        x: values to score. The result has the broadcast shape of ``x`` and ``[1]``.
        needsoftmax_mixtureweight: unnormalised mixture logits (softmax on dim 0).
        num_components: number of components to include. Defaults to the
            module-level ``n_components``.

    Returns:
        ``log sum_c w_c N(x | 10c, 5)``. This is computed with ``logsumexp``, so
        points far from every component give a finite value instead of the
        ``-inf`` that ``log(sum(exp(...)))`` produced.
    """
    if num_components is None:
        num_components = n_components
    log_weights = torch.log_softmax(needsoftmax_mixtureweight, dim=0)
    scale = torch.tensor([5.0]).float()
    log_terms = [
        Normal(torch.tensor([c * 10.]).float(), scale).log_prob(x) + log_weights[c]
        for c in range(num_components)
    ]
    return torch.logsumexp(torch.stack(log_terms), dim=0)
def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of conditional Gumbel-style perturbations for class ``samp``.

    Each pair is built as in ``sample_relax_given_class``: ``z`` reuses one
    uniform draw for every column, while ``z_tilde`` uses fresh noise for the
    non-selected columns and shares ``u_b`` for column ``samp``. Uses the free
    names ``B`` and ``C`` (presumably module-level batch size and class count).

    Args:
        logits: unnormalised class scores, softmaxed along dim 1.
        samp: class index per example, reshaped to ``(B, 1)``.
        k: number of pairs to average (must be >= 1 for ``torch.stack``).

    Returns:
        ``(z, z_tilde, logprob)``: the element-wise means of the ``k`` samples
        and the ``(B, 1)`` log-probability of ``samp`` under ``Categorical``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    zs = []
    z_tildes = []
    for _ in range(k):
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        u_b = torch.gather(input=u, dim=1, index=index)
        z_tilde_b = -torch.log(-torch.log(u_b))

        z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        z.scatter_(dim=1, index=index, src=z_tilde_b)

        u_fresh = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        z_tilde = -torch.log((-torch.log(u_fresh) / probs) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    z = torch.mean(torch.stack(zs), dim=0)
    z_tilde = torch.mean(torch.stack(z_tildes), dim=0)

    return z, z_tilde, logprob
Esempio n. 8
0
def show_surr_preds():
    """Plot surrogate predictions against ``f`` over a grid of x values.

    For each of 3 rows, the function:

    * draws one x with ``sample_true(1)``;
    * samples a relaxed cluster from the encoder's distribution;
    * evaluates ``f = logprob_undercomponent(...) - log q`` and the surrogate
      on 40 x values in [-9, 205], keeping that cluster fixed.

    Saves ``gmm_surr.png`` under ``exp_dir``.

    Relies on these module-level names: ``encoder``, ``surrogate``,
    ``needsoftmax_mixtureweight``, ``exp_dir``, ``H``, ``check_nan``,
    ``logprob_undercomponent``, ``sample_true``, ``plt`` and ``np``.
    """

    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    for i in range(rows):

        x = sample_true(1).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        z = cluster_S

        # Evaluation grid: x is rebound to the grid, while the sampled cluster
        # (and its log q) is held fixed and repeated for every grid point.
        n_evals = 40
        x1 = np.linspace(-9,205, n_evals)
        x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
        z = z.repeat(n_evals,1)
        cluster_H = cluster_H.repeat(n_evals,1)
        xz = torch.cat([z,x], dim=1) 

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col =0
        row = i
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        ax.plot(x1,surr_pred, label='Surr')
        ax.plot(x1,f, label='f')
        # The title is the str() of a tensor (the fixed cluster index).
        ax.set_title(str(cluster_H[0]))
        ax.legend()


    plt_path = exp_dir+'gmm_surr.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()
Esempio n. 9
0
def plot_dist(x=None):
    """Plot the current GMM density (each component and the total) and mark a point.

    Relies on these module-level names: ``sample_true``,
    ``needsoftmax_mixtureweight``, ``exp_dir``, ``plt``, ``np`` and ``Normal``.
    Saves ``gmm_plot_dist.png`` under ``exp_dir``.

    Args:
        x: optional batch whose first element is marked on the plot. When
            omitted, a point is drawn with ``sample_true(1)``.
    """
    if x is None:
        # Convert to a host numpy array exactly like the ``else`` branch does.
        # The previous ``.cuda()`` tensor could not be passed to matplotlib.
        x1 = sample_true(1)[0].cpu().numpy()
    else:
        x1 = x[0].cpu().numpy()

    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    rows = 1
    cols = 1
    plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')

    ax = plt.subplot2grid((rows, cols), (0, 0), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-9, 205, 300)
    sum_ = np.zeros(len(xs))

    C = 20  # number of plotted components (hard-coded here)
    for c in range(C):
        m = Normal(torch.tensor([c * 10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x_val in xs:
            component_i = (torch.exp(m.log_prob(x_val)) * mixture_weights[c]).detach().cpu().numpy()
            ys.append(component_i)

        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='')

    ax.plot(xs, sum_, label='')

    # Short near-vertical tick marking the chosen point.
    ax.plot([x1, x1 + .001], [0., .002])

    plt_path = exp_dir + 'gmm_plot_dist.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Esempio n. 10
0
def true_posterior(x, needsoftmax_mixtureweight, num_components=None):
    """Exact posterior over mixture components for the first entry of ``x``.

    Component ``c`` is ``Normal(10 * c, 5)`` with weight
    ``softmax(needsoftmax_mixtureweight, dim=0)[c]``. The joint is computed in
    log space and normalised with a softmax over every element. This matches
    the previous ``p / p.sum()``, but points far from all components no longer
    underflow to ``0 / 0 = nan``.

    Args:
        x: tensor to score. Only index ``[0]`` of each per-component log-density
            is kept (``x`` is presumably ``(1, 1)``).
        needsoftmax_mixtureweight: unnormalised mixture logits.
        num_components: number of components. Defaults to the module-level
            ``n_components``.

    Returns:
        Stacked posterior probabilities, ``(num_components, 1)`` for a
        ``(1, 1)`` input, summing to 1.
    """
    if num_components is None:
        num_components = n_components
    log_weights = torch.log_softmax(needsoftmax_mixtureweight, dim=0)
    # Parameters follow x's device (was hard-coded .cuda()).
    scale = torch.tensor([5.0], device=x.device).float()
    log_joint = torch.stack([
        (Normal(torch.tensor([c * 10.], device=x.device).float(), scale).log_prob(x)
         + log_weights[c])[0]
        for c in range(num_components)
    ])
    return torch.softmax(log_joint.reshape(-1), dim=0).view_as(log_joint)
Esempio n. 11
0
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE with a learned baseline for the GMM cluster posterior.

    Each of the ``k`` iterations overwrites the outputs, so the returned losses
    come from the last sample. When ``get_grad`` is set, the per-sample
    gradients of ``net_loss`` w.r.t. ``logits`` are stacked and summarised.

    Args:
        surrogate: object whose ``.net(x)`` gives the baseline prediction.
        x: data batch.
        logits: ``(B, C)`` encoder outputs; gradients are taken w.r.t. them.
        mixtureweights: mixture weights indexed by the sampled cluster.
        k: number of sampling iterations.
        get_grad: also record ``grad_avg`` / ``grad_std`` of the loss gradient.

    Returns:
        dict with ``logq``, ``logpx_given_z``, ``logpz``, ``f``, ``net_loss``,
        optionally ``grad_avg`` and ``grad_std``, and ``surr_loss``.
    """
    batch = logits.shape[0]
    dist = Categorical(probs=torch.softmax(logits, dim=1))
    outputs = {}
    per_sample_grads = []

    for _ in range(k):
        hard = dist.sample()
        logq = dist.log_prob(hard).view(batch, 1)
        outputs['logq'] = logq
        logpx_given_z = logprob_undercomponent(x, component=hard)
        outputs['logpx_given_z'] = logpx_given_z
        logpz = torch.log(mixtureweights[hard]).view(batch, 1)
        outputs['logpz'] = logpz
        logpxz = logpx_given_z + logpz  # [B,1]

        baseline = surrogate.net(x)

        f = logpxz - logq - 1.
        outputs['f'] = f
        net_loss = -torch.mean((f.detach() - baseline.detach()) * logq)
        outputs['net_loss'] = net_loss

        dlogq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - baseline) * dlogq) ** 2)

        if get_grad:
            per_sample_grads.append(
                torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            )

    if get_grad:
        stacked = torch.stack(per_sample_grads)
        outputs['grad_avg'] = torch.mean(torch.mean(stacked, dim=0), dim=0)
        # NOTE: with k == 1 the unbiased std over one sample is nan.
        outputs['grad_std'] = torch.std(stacked, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
Esempio n. 12
0
def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    """Joint log-density ``log N(x | 10 * component, 5) + log w[component]``, per row.

    Args:
        x: batch whose first dimension is the batch size ``B`` (presumably
            ``(B, 1)``).
        component: one integer component index per example, reshaped to
            ``(B, 1)``.
        needsoftmax_mixtureweight: unnormalised mixture logits (softmax on dim 0).
        cuda: move the Normal's parameters to CUDA before scoring ``x``.

    Returns:
        ``Normal(loc, scale).log_prob(x)`` plus ``log w[component]`` reshaped to
        ``(B, 1)``.
    """
    n = x.shape[0]
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    loc = (component.float() * 10.).view(n, 1)
    scale = (torch.ones([n]) * 5.).view(n, 1)
    if cuda:
        loc, scale = loc.cuda(), scale.cuda()
    dist = Normal(loc, scale)

    log_lik = dist.log_prob(x)
    log_prior = torch.log(weights[component]).view(n, 1)
    return log_lik + log_prior
Esempio n. 13
0
def inference_error():
    """L2 distance between the true posterior and the encoder's ``q(z | x)`` for one x.

    Draws a single ``x`` with ``sample_true(1)`` (moved to CUDA). Relies on
    these module-level names: ``needsoftmax_mixtureweight``, ``n_components``,
    ``encoder``, ``true_posterior`` and ``L2_mixtureweights``.
    """
    x = sample_true(1).cuda()
    target = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
    q = torch.softmax(encoder.net(x), dim=1).view(n_components)
    return L2_mixtureweights(target.data.cpu().numpy(), q.data.cpu().numpy())
Esempio n. 14
0
def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    """Joint log-density ``log N(x | 10 * component, 5) + log w[component]`` for one component.

    Args:
        x: values to score.
        component: a single component index (int or 0-d tensor).
        needsoftmax_mixtureweight: unnormalised mixture logits (softmax on dim 0).
        cuda: build the Normal's parameters on CUDA.
    """
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    loc = torch.tensor([component * 10.]).float()
    scale = torch.tensor([5.0]).float()
    if cuda:
        loc, scale = loc.cuda(), scale.cuda()
    gaussian = Normal(loc, scale)
    return gaussian.log_prob(x) + torch.log(mixture_weights[component])
Esempio n. 15
0
def inference_error():
    """Mean L2 and KL between the true posterior and ``q(z | x)`` over 10 draws.

    Each draw is one ``x`` from ``sample_true(1)`` (moved to CUDA).
    NOTE(review): the encoder logits are divided by 100 before the softmax,
    unlike the single-sample variant; confirm this temperature is intended.
    Relies on these module-level names: ``needsoftmax_mixtureweight``,
    ``n_components``, ``encoder``, ``true_posterior``, ``L2_mixtureweights``
    and ``KL_mixutreweights``.

    Returns:
        ``(mean_l2, mean_kl)``.
    """
    n = 10
    to_np = lambda t: t.data.cpu().numpy()
    l2_total = 0
    kl_total = 0
    for _ in range(n):
        x = sample_true(1).cuda()
        post = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
        q = torch.softmax(encoder.net(x) / 100., dim=1).view(n_components)

        l2_total += L2_mixtureweights(to_np(post), to_np(q))
        kl_total += KL_mixutreweights(to_np(post), to_np(q))

    return l2_total / n, kl_total / n
Esempio n. 16
0
def reinforce(x, logits, mixtureweights, k=1):
    """Plain REINFORCE encoder loss averaged over ``k`` categorical samples.

    Each sample contributes ``-mean((f - 1) * log q)`` with
    ``f = log p(x | z) + log p(z) - log q(z)`` (detached).

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. All values except
        ``net_loss`` come from the last sample.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    total = 0
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        f = (logpx_given_z + logpz) - logq  # [B,1]
        total = total + (-torch.mean((f.detach() - 1.) * logq))

    net_loss = total / k
    return net_loss, f, logpx_given_z, logpz, logq
Esempio n. 17
0
    def forward(self, hidden, attn, src_map):
        """Mix the generator's vocabulary distribution with a copy distribution.

        The output covers the target vocabulary followed by the dynamic
        (source-copy) vocabulary described by ``src_map``.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attention over source positions
               ``(batch x tlen, slen)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocabulary.
               ``(src_len, batch, extra_words)``

        Returns:
           FloatTensor ``(batch x tlen, vocab + extra_words)``: generation
           probabilities scaled by ``1 - p_copy`` concatenated with copy
           probabilities scaled by ``p_copy``.
        """
        n_rows, _ = hidden.size()
        n_rows_attn, slen = attn.size()
        slen_map, batch, cvocab = src_map.size()
        aeq(n_rows, n_rows_attn)
        aeq(slen, slen_map)

        # Generation distribution with the padding token masked out.
        gen_logits = self.linear(hidden)
        gen_logits[:, self.pad_idx] = -float('inf')
        gen_prob = torch.softmax(gen_logits, 1)

        # Gate: probability of copying from the source.
        copy_gate = torch.sigmoid(self.linear_copy(hidden))
        gen_part = gen_prob * (1 - copy_gate)
        gated_attn = attn * copy_gate

        # Regroup the rows by batch and project attention onto the extended vocab.
        attn_by_batch = gated_attn.view(-1, batch, slen).transpose(0, 1)
        copy_part = torch.bmm(attn_by_batch, src_map.transpose(0, 1)).transpose(0, 1)
        copy_part = copy_part.contiguous().view(-1, cvocab)
        return torch.cat([gen_part, copy_part], 1)
Esempio n. 18
0
def reinforce_pz(x, logits, mixtureweights, k=1):
    """REINFORCE loss using relaxed (temperature 1) categorical samples.

    The hard cluster is ``H(relaxed_sample)``, and ``log q`` is the relaxed
    density of the detached sample. Each of the ``k`` samples contributes
    ``-mean(f * log q)`` with ``f = log p(x | z) + log p(z) - log q - 1``
    (detached).

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. All values except
        ``net_loss`` come from the last sample.
    """
    B = logits.shape[0]
    relaxed = RelaxedOneHotCategorical(
        probs=torch.softmax(logits, dim=1),
        temperature=torch.tensor([1.]).cuda(),
    )

    net_loss = 0
    for _ in range(k):
        cluster_S = relaxed.rsample()
        cluster_H = H(cluster_S)
        logq = relaxed.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        f = (logpx_given_z + logpz) - logq - 1.
        net_loss = net_loss + (-torch.mean(f.detach() * logq))

    net_loss = net_loss / k
    return net_loss, f, logpx_given_z, logpz, logq
Esempio n. 19
0
def plot_posteriors(x=None, name=''):
    """Bar-plot the true posterior next to the encoder's ``q(z | x)`` for one x.

    Args:
        x: optional batch; its first element (reshaped to ``(1, 1)``) is used.
            When omitted, one point is drawn with ``sample_true(1)`` on CUDA.
        name: suffix for the saved file ``exp_dir + 'posteriors' + name + '.png'``.

    Relies on these module-level names: ``needsoftmax_mixtureweight``,
    ``n_components``, ``encoder``, ``true_posterior``, ``exp_dir``, ``plt``
    and ``np``.
    """
    x = sample_true(1).cuda() if x is None else x[0].view(1, 1)

    true_post = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
    q_tensor = torch.softmax(encoder.net(x), dim=1).view(n_components)

    true_post = true_post.data.cpu().numpy()
    q = q_tensor.data.cpu().numpy()

    plt.figure(figsize=(9, 9), facecolor='white')
    ax = plt.subplot2grid((1, 1), (0, 0), frameon=False, colspan=1, rowspan=1)

    bar_width = .3
    positions = np.array(range(len(q)))
    ax.bar(range(len(q)), true_post, width=bar_width, label='True')
    ax.bar(positions + bar_width, q, width=bar_width, label='q')
    ax.legend()
    ax.grid(True, alpha=.3)
    ax.set_title(str(x))

    plt_path = exp_dir + 'posteriors' + name + '.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Esempio n. 20
0
    def get_loss():
        """L1 error between ``f`` and the surrogate on a fresh batch.

        ``f = logprob_undercomponent(...) - log q`` for a relaxed cluster sample
        at temperature ``temp``. Only the surrogate side carries gradient
        (``f`` is detached). Uses ``batch_size``, ``temp``, ``encoder``,
        ``surrogate``, ``needsoftmax_mixtureweight``, ``H``, ``check_nan`` and
        ``sample_true`` from the enclosing/module scope.
        """
        x = sample_true(batch_size).cuda()
        probs = torch.softmax(encoder.net(x), dim=1)
        relaxed = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = relaxed.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = relaxed.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(
            x,
            component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True,
        )
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(torch.cat([cluster_S, x], dim=1))
        return torch.mean(torch.abs(f.detach() - surr_pred))
Esempio n. 21
0
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """Tversky loss ``1 - (tp + smooth_nr) / (tp + alpha*fp + beta*fn + smooth_dr)``.

        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            AssertionError: When ``target`` and ``input`` shapes differ after the
                optional activation / one-hot / background handling.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``. AssertionError is kept for existing callers.
            raise AssertionError(
                f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
            )

        p0 = input
        p1 = 1 - p0
        g0 = target
        g1 = 1 - g0

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        tp = torch.sum(p0 * g0, reduce_axis)
        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
        numerator = tp + self.smooth_nr
        denominator = tp + fp + fn + self.smooth_dr

        score: torch.Tensor = 1.0 - numerator / denominator

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return score  # returns [N, n_classes] losses
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
Esempio n. 22
0
 def forward(self, x):
     """Attention-weighted pooling over dim 2.

     Args:
         x: input, ``(n_samples, n_in, n_time)`` per the original comment.

     Returns:
         ``(pooled, norm_att, cla)``:

         * ``norm_att`` is the softmax over the last dim of ``tanh(self.att(x))``;
         * ``cla`` is ``self.nonlinear_transform(self.cla(x))``;
         * ``pooled`` is their product summed over dim 2.
     """
     weights = torch.tanh(self.att(x)).softmax(dim=-1)
     values = self.nonlinear_transform(self.cla(x))
     pooled = (weights * values).sum(dim=2)
     return pooled, weights, values
Esempio n. 23
0
def train(method, n_components, true_mixture_weights, exp_dir, needsoftmax_mixtureweight=None):
    """Fit GMM mixture weights and an amortised encoder with a chosen gradient estimator.

    Each step samples ``batch_size`` points from the true mixture, then updates:

    * the encoder on ``net_loss``;
    * the mixture logits on ``-mean(f)``;
    * the surrogate(s) on ``surr_loss``, for surrogate-based methods.

    Metrics are collected in ``data_dict``. That dict is periodically plotted
    and pickled to ``exp_dir + "data.p"``.

    Relies on names not defined in this function, presumably module-level:
    ``n_steps``, ``batch_size``, ``print_steps``, ``plot_steps``,
    ``params_step``, ``NN3``, ``sample_gmm``, the estimator functions, and the
    printing/plotting helpers.

    Args:
        method: one of ``'reinforce'``, ``'reinforce_pz'``, ``'simplax'``,
            ``'HLAX'``, ``'relax'``.
        n_components: number of mixture components.
        true_mixture_weights: weights of the data-generating mixture.
        exp_dir: output path prefix (string-concatenated with file names).
        needsoftmax_mixtureweight: optional initial unnormalised mixture logits.
            Random normal when omitted.
    """

    print('Method:', method)

    true_mixture_weights = torch.tensor(true_mixture_weights, 
                                            requires_grad=True, device="cuda")

    if needsoftmax_mixtureweight is None:
        needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")
    else:
        needsoftmax_mixtureweight = torch.tensor(needsoftmax_mixtureweight, 
                                            requires_grad=True, device="cuda")
    
    # Optimiser for the learnable (pre-softmax) mixture logits.
    optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=1e-5, weight_decay=1e-7)

    encoder = NN3(input_size=1, output_size=n_components, n_residual_blocks=3).cuda()
    optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-4, weight_decay=1e-7)

    # Surrogate networks exist only for the control-variate methods.
    if method in ['simplax', 'relax']:
        surrogate = NN3(input_size=1+n_components, output_size=1, n_residual_blocks=4).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=1e-3)

    if method == 'HLAX':
        # surrogate = NN4(input_size=1+n_components, output_size=1, n_residual_blocks=4).cuda()
        surrogate = NN3(input_size=1+n_components, output_size=1, n_residual_blocks=4).cuda()
        surrogate2 = NN3(input_size=1, output_size=1, n_residual_blocks=2).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=1e-3)
        optim_surr2 = torch.optim.Adam(surrogate2.parameters(), lr=1e-3)

    data_dict = {}
    data_dict['steps'] = []
    data_dict['theta_losses'] = []
    data_dict['f'] = []
    data_dict['lpx_given_z'] = []
    data_dict['lpz'] = []
    data_dict['lqz'] = []
    data_dict['inference_L2'] = []
    # data_dict['grad_var'] = []
    data_dict['grad_avg'] = []

    if method in ['simplax', 'HLAX', 'relax']:
        data_dict['surr_loss'] = []
        data_dict['surr_dif'] = []
    if method in ['simplax', 'HLAX']:
        data_dict['grad_split'] = {}
        data_dict['grad_split']['score'] = []
        data_dict['grad_split']['path'] = []
    if method=='HLAX':
        data_dict['alpha'] = []

    for step in range(0,n_steps+1):

        mixtureweights = torch.softmax(needsoftmax_mixtureweight, dim=0).float() #[C]

        x = sample_gmm(batch_size, mixture_weights=true_mixture_weights)
        logits = encoder.net(x)

        # NOTE(review): an unrecognised method leaves net_loss/f unset, so the
        # code below fails with a NameError.
        if method == 'reinforce':
            net_loss, f, logpx_given_z, logpz, logq = reinforce(x, logits, mixtureweights, k=1)
        elif method == 'reinforce_pz':
            net_loss, f, logpx_given_z, logpz, logq = reinforce_pz(x, logits, mixtureweights, k=1)
        elif method == 'simplax':
            # NOTE(review): the simplax() shown elsewhere in this file returns a
            # dict. Tuple-unpacking a dict yields its keys, so confirm which
            # version is in use here.
            net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score = simplax(surrogate, x, logits, mixtureweights, k=1)
        elif method == 'HLAX':
            net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score, alpha = HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1)
        elif method=='relax':
            net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif = relax(surrogate, x, logits, mixtureweights, k=1)


        #Grad variance (currently logs the mean absolute d net_loss / d logits)
        grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
        # grad_var = torch.mean(torch.std(grad, dim=0))
        grad_avg = torch.mean(torch.abs(grad))


        # Update encoder
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()

        # Update generator (the mixture logits), maximising mean f
        loss = - torch.mean(f)
        optim.zero_grad()
        loss.backward(retain_graph=True)  
        optim.step()

        # Update surrogate
        if method in ['simplax', 'HLAX', 'relax']:
            optim_surr.zero_grad()
            surr_loss.backward(retain_graph=True)
            optim_surr.step()

        # NOTE(review): surrogate2 is stepped using the same surr_loss as the
        # first surrogate; confirm a separate loss was not intended.
        if method == 'HLAX':
            optim_surr2.zero_grad()
            surr_loss.backward(retain_graph=True)
            optim_surr2.step()

        if step%print_steps==0:
            # print (step, to_print(net_loss), to_print(logpxz - logq), to_print(logpx_given_z), to_print(logpz), to_print(logq))

            # current_theta, pz_give_x and probs are only bound in this branch
            # but are reused by the plotting branch below.
            current_theta = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
            theta_loss = L2_error(to_print2(true_mixture_weights),
                                    to_print2(current_theta))

            # NOTE(review): the true_posterior() shown elsewhere in this file is
            # called as (x, needsoftmax_mixtureweight); confirm this keyword form.
            pz_give_x = true_posterior(x, mixture_weights=mixtureweights)
            probs = torch.softmax(logits, dim=1)
            inference_L2 = L2_batch(pz_give_x, probs)


            print( 
                'S:{:5d}'.format(step),
                'Theta_loss:{:.3f}'.format(theta_loss),
                'Loss:{:.3f}'.format(to_print1(net_loss)),
                'ELBO:{:.3f}'.format(to_print1(f)),
                'lpx|z:{:.3f}'.format(to_print1(logpx_given_z)),
                'lpz:{:.3f}'.format(to_print1(logpz)),
                'lqz:{:.3f}'.format(to_print1(logq)),
                )

            if step> 0:
                data_dict['steps'].append(step)
                data_dict['theta_losses'].append(theta_loss)
                data_dict['f'].append(to_print1(f))
                data_dict['lpx_given_z'].append(to_print1(logpx_given_z))
                data_dict['lpz'].append(to_print1(logpz))
                data_dict['lqz'].append(to_print1(logq))
                data_dict['inference_L2'].append(to_print2(inference_L2))
                # data_dict['grad_var'].append(to_print2(grad_var))
                data_dict['grad_avg'].append(to_print2(grad_avg))
                if method in ['simplax', 'HLAX', 'relax']:
                    data_dict['surr_loss'].append(to_print2(surr_loss))
                    data_dict['surr_dif'].append(to_print2(surr_dif))
                if method in ['simplax', 'HLAX']:
                    data_dict['grad_split']['score'].append(to_print2(grad_score))
                    data_dict['grad_split']['path'].append(to_print2(grad_path))
                if method == 'HLAX':
                    data_dict['alpha'].append(to_print2(alpha))

            check_nan(net_loss)



        # NOTE(review): this branch reads pz_give_x, probs and current_theta,
        # which are only set on print steps. It assumes every plotting step is
        # also a print step.
        if step%plot_steps==0 and step!=0:

            plot_curve2(data_dict, exp_dir)

            # list_of_posteriors = []
            # for ii in range(5):
            #     pz_give_x = true_posterior(x, mixture_weights=mixtureweights)
            #     probs = torch.softmax(logits, dim=1)
            #     inference_L2 = L2_batch(pz_give_x, probs)    
            #     list_of_posteriors.append([to_print2(pz_give_x), to_print2(probs)])      

            if step% (plot_steps*2) ==0:
            # if step% plot_steps ==0:
                plot_posteriors2(n_components, trueposteriors=to_print2(pz_give_x), qs=to_print2(probs), exp_dir=exp_dir, name=str(step))
            
                plot_dist2(n_components, mixture_weights=to_print2(current_theta), true_mixture_weights=to_print2(true_mixture_weights), exp_dir=exp_dir, name=str(step))

        if step % params_step==0 and step>0:

            # NOTE: this rebinds f (the ELBO tensor) to the file handle; f is
            # recomputed at the next step.
            with open( exp_dir+"data.p", "wb" ) as f:
                pickle.dump(data_dict, f)
            print ('saved data')
    def forward(self, outputs, targets):
        """Multi-scale classification loss (with hard-negative mining) plus bbox regression.

        Args:
            outputs: flat list alternating, per scale ``i``:

                * score logits ``outputs[2 * i]``, softmaxed over dim 1, where
                  channel 1 is treated as the negative-class probability;
                * bbox predictions ``outputs[2 * i + 1]``.

            targets: flat list alternating, per scale ``i``:

                * ``gt_mask = targets[2 * i]``, whose channels 2:6 mask the bbox
                  loss;
                * ``gt_label = targets[2 * i + 1]``, whose channel 0 is the
                  positive flag, channel 1 the negative flag, and channels 2:6
                  the bbox targets.

        Returns:
            ``(loss, loss_branch)``: the summed classification + regression loss
            over all scales, and the per-scale ``[score_loss, bbox_loss]`` list.
        """
        def lowest_k(prob, k):
            # Bool mask of entries <= the k-th smallest value of ``prob``.
            # Returns an empty mask when k <= 0; indexing [k - 1] would
            # otherwise wrap around to the largest value.
            if k <= 0:
                return torch.zeros_like(prob, dtype=torch.bool)
            prob_sort, _ = torch.sort(prob.reshape(1, -1), descending=False)
            return prob <= prob_sort[0][k - 1]

        loss_cls = 0
        loss_reg = 0
        loss_branch = []
        for i in range(self.num_output_scales):
            pred_score = outputs[i * 2]
            pred_bbox = outputs[i * 2 + 1]
            # Targets follow the predictions' device (previously forced .cuda()).
            gt_mask = targets[i * 2].to(pred_score.device)
            gt_label = targets[i * 2 + 1].to(pred_score.device)

            pred_score_softmax = torch.softmax(pred_score, dim=1)
            # Boolean mask on the same device. A float mask cannot index a tensor.
            loss_mask = torch.ones_like(pred_score_softmax, dtype=torch.bool)

            if self.hnm_ratio > 0:
                pos_flag = (gt_label[:, 0, :, :] > 0.5)
                pos_num = torch.sum(pos_flag)  # number of positive examples

                if pos_num > 0:
                    neg_flag = (gt_label[:, 1, :, :] > 0.5)
                    neg_num = torch.sum(neg_flag)
                    neg_num_selected = min(int(self.hnm_ratio * pos_num), int(neg_num))
                    # Non-negative cells get probability 1 so the ascending sort
                    # ranks real negatives first. Filling them with 0 made
                    # positive/ignored cells look like the "hardest" negatives.
                    neg_score = pred_score_softmax[:, 1, :, :]
                    neg_prob = torch.where(neg_flag, neg_score, torch.ones_like(neg_score))
                    neg_grad_flag = neg_flag & lowest_k(neg_prob, neg_num_selected)
                else:
                    # No positives: keep the lowest 10% of negative-class scores.
                    neg_choice_ratio = 0.1
                    neg_prob = pred_score_softmax[:, 1, :, :]
                    neg_num_selected = int(neg_prob.numel() * neg_choice_ratio)
                    neg_grad_flag = lowest_k(neg_prob, neg_num_selected)
                loss_mask = torch.cat([pos_flag.unsqueeze(1), neg_grad_flag.unsqueeze(1)], dim=1)

            # cross entropy with mask
            pred_score_log = torch.log(pred_score_softmax[loss_mask])
            score_cross_entropy = -gt_label[:, :2, :, :][loss_mask] * pred_score_log
            # Guard the normaliser: an empty selection would otherwise give 0/0.
            loss_score = torch.sum(score_cross_entropy) / max(score_cross_entropy.numel(), 1)

            mask_bbox = gt_mask[:, 2:6, :, :]
            if torch.sum(mask_bbox) == 0:
                loss_bbox = torch.zeros_like(loss_score)
            else:
                predict_bbox = pred_bbox * mask_bbox
                label_bbox = gt_label[:, 2:6, :, :] * mask_bbox
                loss_bbox = F.mse_loss(predict_bbox, label_bbox, reduction='sum') / torch.sum(mask_bbox)

            loss_cls += loss_score
            loss_reg += loss_bbox
            loss_branch.append(loss_score)
            loss_branch.append(loss_bbox)
        loss = loss_cls + loss_reg
        return loss, loss_branch
Esempio n. 25
0
def test(args, model, classifier, test_loader):
    """Evaluate ``classifier(model(x))`` over ``test_loader`` without gradients.

    Both modules are put in eval mode. Each batch is cast (inputs to float,
    targets to long) and moved to CUDA. The running loss, accuracy and
    per-batch wall time are tracked with ``AverageMeter`` objects, and a
    progress line is printed every ``args.print_freq`` batches.

    Returns:
        A ``(predictions, targets, pred_scores)`` tuple of CPU tensors, each
        concatenated over all batches. ``pred_scores`` are the softmax
        probabilities over the last axis of the classifier output.
    """
    # switch to evaluate mode
    for module in (model, classifier):
        module.eval()

    time_meter, loss_meter, acc_meter = AverageMeter(), AverageMeter(), AverageMeter()
    pred_chunks, target_chunks, score_chunks = [], [], []

    with torch.no_grad():
        tick = time.time()

        for step, (inputs, labels) in enumerate(tqdm(test_loader, disable=False)):
            # Cast, then move to the GPU.
            inputs = inputs.float().cuda()
            labels = labels.long().cuda()

            logits = classifier(model(inputs))
            # In-place detach, so `logits` itself is detached for the uses below.
            scores = torch.softmax(logits.detach_(), dim=-1)
            loss = F.cross_entropy(logits, labels, reduction='mean')

            n_items = labels.size(0)
            loss_meter.update(loss.item(), n_items)

            predicted = logits.argmax(dim=1)
            n_correct = (labels == predicted).sum().item()
            acc_meter.update(n_correct / n_items, n_items)

            # Keep per-batch results for the metrics computed by the caller.
            pred_chunks.append(predicted)
            target_chunks.append(labels)
            score_chunks.append(scores)

            time_meter.update(time.time() - tick)
            tick = time.time()

            if (step + 1) % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                      'acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                          step, len(test_loader),
                          batch_time=time_meter, loss=loss_meter, acc=acc_meter))

        # Pred and target for performance metrics
        final_predictions, final_targets, final_pred_score = (
            torch.cat(chunks).to('cpu')
            for chunks in (pred_chunks, target_chunks, score_chunks))

    return final_predictions, final_targets, final_pred_score
Esempio n. 26
0
    def forward(self, hidden, enc_output):
        """Additive attention over ``enc_output`` followed by a sigmoid output layer.

        The score is ``self.V(tanh(self.W1(enc_output) + self.W2(hidden)))``.
        It is softmax-normalised over dim 1, and the resulting weights pool
        ``enc_output`` into a context vector that is fed through ``self.fc``.

        Args:
            hidden: decoder state. It is reshaped to
                ``(self.batch_sz, -1, self.directions * self.dec_units)``.
                Presumably an RNN hidden state of shape
                ``(directions, batch, dec_units)`` -- TODO confirm.
            enc_output: encoder outputs, used as the second operand of
                ``torch.bmm``. Assumed ``(batch, src_len, enc_units)`` -- TODO
                confirm.

        Returns:
            ``(torch.sigmoid(self.fc(context)), attention_weights)``.
            ``attention_weights`` is the softmax over dim 1 of the scores,
            permuted with ``(0, 2, 1)``.

        Raises:
            RuntimeError: if ``self.V`` rejects the shape of the combined
                ``tanh(W1(enc_output) + W2(hidden))`` features.
        """
        hidden = hidden.view(self.batch_sz, -1, self.directions * self.dec_units)

        q = self.W1(enc_output)
        k = self.W2(hidden)
        # Broadcast-add the projected decoder state across the source positions.
        K = torch.tanh(q + k)

        try:
            score = self.V(K)
        except RuntimeError as err:
            # Surface the failure with the offending shape instead of swallowing
            # it (which would only fail later with an unrelated NameError).
            raise RuntimeError(
                f"attention scoring layer V failed for K with shape {tuple(K.shape)}"
            ) from err

        attention_weights = torch.softmax(score, dim=1)  # alpha, over dim 1

        attention_weights = attention_weights.permute(0, 2, 1)

        context_vector = torch.bmm(attention_weights, enc_output)
        context_vector = context_vector.view(self.batch_sz, -1, self.enc_units)
        context_vector = torch.sum(context_vector, dim=1)

        x = self.fc(context_vector)

        return torch.sigmoid(x), attention_weights
# Toy categorical problem. Presumably used to compare gradient estimators of
# E_{b ~ Cat(softmax(logits))}[f(b)] -- TODO confirm against the code using it.
# C: number of categories. N: not used in this chunk; presumably a sample count.
C = 3
N = 5000

# Alternative parameterisations tried earlier:
# theta = .5
# bern_param = torch.tensor([theta], requires_grad=True).view(B,1)
# aa = 1 - bern_param
# probs = torch.cat([aa, bern_param], dim=1)
# logits = torch.log(probs)
# logits =  torch.ones((B,C), requires_grad=True) # * -0.6931
# logits =  torch.ones((1,C), requires_grad=True) #* -0.6931
# logits =  torch.zeros((1,C), requires_grad=True) #* -0.6931

# Uniform probabilities over the C classes, stored as normalised log-probs.
probs = torch.softmax(torch.ones((1,C)), dim=1)
logits = torch.log(probs)
# Copy into a fresh leaf tensor that tracks gradients. clone().detach() is the
# documented way to copy a tensor; torch.tensor(tensor) does the same but
# emits a UserWarning.
logits = logits.clone().detach().requires_grad_(True)

# Reward of each category; f(b) looks these up.
# Alternatives: torch.tensor([-1., 0., 1., 4.]), torch.tensor([-1., 2.])
rewards = torch.tensor([-1., 1., 2.]) #* 100.

# Appears to be a reference value of d E[f(b)] / d logits. For uniform probs
# the exact gradient is p_i * (r_i - E[r]) = [-5/9, 1/9, 4/9]
# ~= [-.556, .111, .444]; these numbers look like an estimate of it.
true = np.array([-.5478, .1122, .4422])
# true = np.array([0,0])


def f(ind):
    """Return the reward(s) ``rewards[ind]`` for category index/indices ``ind``."""
    return rewards[ind]
Esempio n. 28
0
def main(args):
    """Translate the test set with beam search and optionally write the hypotheses.

    Loads the checkpoint at ``args.checkpoint_path``. Its saved ``args``
    override the command-line ones, except ``data``. It then builds the
    source/target dictionaries and the test ``Seq2SeqDataset`` from
    ``args.data`` and decodes every batch with one ``BeamSearch`` per sentence.
    The best hypothesis of each sentence, cut at its first EOS, is converted to
    a string. If ``args.output`` is set, the hypotheses are written one per
    line in sample-id order.
    """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    # On key clashes the checkpoint args win (the later ** wins); only the data
    # directory is taken from the command line.
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    # NOTE(review): the positional BatchSampler arguments are presumably
    # (dataset, max_tokens, batch_size, num_shards, shard_id) -- confirm
    # against the BatchSampler definition.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):

        # Create a beam search object for every input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        searches = [
            BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx)
            for i in range(batch_size)
        ]

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            #print("src_tokens:", type(sample['src_tokens']))
            # NOTE(review): this stringifies the *source* tokens of the whole
            # batch with the *target* dictionary. If Dictionary.string returns
            # a str, len() counts characters, not words -- verify this is the
            # length the penalty below is meant to use.
            outlen = len(tgt_dict.string(sample['src_tokens']))
            print("-words len:", outlen)
            #print("-size:", list(sample['src_tokens'].size()))
            #print("-content:", sample['src_tokens'])
            # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent?
            # A (batch_size, 1) tensor filled with the EOS index: one start
            # token per sentence, fed as the decoder's first input.
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            #print("Decoder_out:", type(decoder_out)) # <class 'torch.Tensor'>
            #print("-size:", list(decoder_out.size()))
            print("-content:", decoder_out)

            # NOTE(review): `a` is not defined in this function. It must be a
            # module-level length-penalty exponent, or this raises NameError.
            # The GNMT length penalty is ((5 + |Y|) ** a) / ((5 + 1) ** a);
            # here only `outlen` is raised to `a`, so check the parenthesisation.
            lp_y = 1 / (((5 + outlen**a) / ((5 + 1)**a)))
            # NOTE(review): decoder_out goes through a softmax below, so scaling
            # it here acts like a softmax temperature rather than normalising
            # sequence log-probabilities -- confirm this is intended.
            decoder_out = torch.mul(decoder_out, lp_y)
            print("-normalized:", decoder_out)

            # __QUESTION 2: Why do we keep one top candidate more than the beam size?
            # So that when candidate j is UNK, candidate j + 1 can replace it
            # (see the torch.where backoff below).
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

        # Create number of beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                # Replace an UNK prediction by the next-ranked candidate.
                next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                    backoff_log_p, best_log_p)
                # Keep the score at the last decoder time step.
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                # (sentence i is selected along dim 1 of these tensors)
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    # src_mask is None (indexing None raises TypeError)
                    mask = None

                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden,
                                      final_cell, mask,
                                      torch.cat(
                                          (go_slice[i], next_word)), log_p, 1)
                # __QUESTION 3: Why do we add the node with a negative score?
                # Presumably BeamSearch keeps a min-priority queue, so negating
                # node.eval() makes the best-scoring nodes come out first --
                # verify in BeamSearch.add.
                searches[i].add(-node.eval(), node)

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if nodes == []:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            # (one row per live node across all sentences of the batch)
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack(
                [node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes],
                                       dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes],
                                     dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack(
                    [node.mask for node in nodes], dim=0)
            except TypeError:
                # node.mask is None, so there is nothing to stack
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # see __QUESTION 2
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

            # Create number of beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                            backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    # Prefix without the leading start token, plus the token
                    # predicted at the last time step.
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction?
                    # add_final is used when the newest token is EOS; add is used
                    # for nodes expanded in the next step. Presumably add_final
                    # keeps finished hypotheses out of further expansion --
                    # verify in BeamSearch.

                    # Store the node as final if EOS is generated
                    # (the EOS token's log-prob and length are not added)
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp, node.length)
                        search.add_final(-node.eval(), node)

                    # Add the node to current nodes for next iteration
                    else:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp + log_p, node.length + 1)
                        search.add(-node.eval(), node)

            # __QUESTION 5: What happens internally when we prune our beams?
            # How do we know we always maintain the best sequences?
            # (Implemented in BeamSearch.prune, not visible here.)
            for search in searches:
                search.prune()

        # Segment into sentences
        # Drop the leading start token of each best sequence.
        best_sents = torch.stack(
            [search.get_best()[1].sequence[1:].cpu() for search in searches])
        decoded_batch = best_sents.numpy()

        output_sentences = [
            decoded_batch[row, :] for row in range(decoded_batch.shape[0])
        ]

        # __QUESTION 6: What is the purpose of this for loop?
        # It truncates each sentence just before its first EOS token.
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append(sent)
        output_sentences = temp

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file
    # NOTE(review): assumes sample ids are exactly 0..len(all_hyps)-1;
    # otherwise this raises KeyError.
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
Esempio n. 29
0
def train(epoch, net, net2, optimizer, labeled_trainloader,
          unlabeled_trainloader):
    """Train ``net`` for one epoch with MixMatch targets co-guessed with ``net2``.

    Each labeled batch holds two augmented views, the labels and per-sample
    weights ``w_x``. It is paired with an unlabeled batch of two views; the
    unlabeled loader is restarted whenever it runs out. Under ``no_grad``:

    * unlabeled targets are the average softmax of both networks on both
      views, temperature-sharpened with ``args.T`` and renormalised;
    * labeled targets blend the one-hot label with ``net``'s averaged
      prediction using ``w_x``, then are sharpened and renormalised.

    Inputs and targets are mixed with ``l = max(l, 1 - l)``, where
    ``l ~ Beta(args.alpha, args.alpha)``. One optimizer step is taken per batch
    on ``Lx + lamb * Lu + penalty``.

    Relies on module-level ``args``, ``criterion`` and ``warm_up`` defined
    elsewhere in the file, and on CUDA being available.
    """
    net.train()
    net2.eval()  # fix one network and train the other

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
    for batch_idx, (inputs_x, inputs_x2, labels_x,
                    w_x) in enumerate(labeled_trainloader):
        # Cycle the unlabeled loader independently of the labeled one.
        # - Use the builtin next(): current DataLoader iterators have no
        #   `.next()` method.
        # - Restart only on exhaustion, so that real loader errors surface
        #   instead of being retried.
        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(
            1, labels_x.view(-1, 1), 1)
        w_x = w_x.view(-1, 1).type(torch.FloatTensor)

        inputs_x, inputs_x2 = inputs_x.cuda(), inputs_x2.cuda()
        labels_x, w_x = labels_x.cuda(), w_x.cuda()
        inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

        with torch.no_grad():
            # label co-guessing of unlabeled samples
            outputs_u11 = net(inputs_u)
            outputs_u12 = net(inputs_u2)
            outputs_u21 = net2(inputs_u)
            outputs_u22 = net2(inputs_u2)
            # average the probabilities of both networks over both views
            pu = (torch.softmax(outputs_u11, dim=1) +
                  torch.softmax(outputs_u12, dim=1) +
                  torch.softmax(outputs_u21, dim=1) +
                  torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu**(1 / args.T)  # temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True)  # normalize
            targets_u = targets_u.detach()

            # label refinement of labeled samples
            outputs_x = net(inputs_x)
            outputs_x2 = net(inputs_x2)

            px = (torch.softmax(outputs_x, dim=1) +
                  torch.softmax(outputs_x2, dim=1)) / 2
            # blend the one-hot label with the prediction, weighted by w_x
            px = w_x * labels_x + (1 - w_x) * px
            ptx = px**(1 / args.T)  # temperature sharpening

            # normalize; classes other than the refined label are weakened
            targets_x = ptx / ptx.sum(dim=1, keepdim=True)
            targets_x = targets_x.detach()

        # mixmatch
        # (the two networks also stay divergent because they are trained on
        # different data)
        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1 - l)  # keep each mix closer to the un-permuted sample

        # 2 * batch_size labeled rows, followed by the two unlabeled views
        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2],
                               dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u],
                                dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        logits = net(mixed_input)
        logits_x = logits[:batch_size * 2]
        logits_u = logits[batch_size * 2:]

        Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size * 2],
                                 logits_u, mixed_target[batch_size * 2:],
                                 epoch + batch_idx / num_iter, warm_up)

        # regularization
        if args.noise_mode == 'asym_two_unbalanced_classes':
            # No prior penalty for this mode. With imbalanced classes the
            # network is heavily affected by the skewed distribution, and
            # predicted probabilities drift from about [0.9, 0.1] toward
            # [1, 0]. A variant with a [0.9, 0.1] prior was tried and left
            # disabled.
            penalty = 0
        else:
            # KL(prior || mean prediction) against a uniform prior. The mean
            # softmax over the batch (keeping the class axis) serves as an
            # estimate of the predicted class distribution.
            prior = torch.ones(args.num_class) / args.num_class
            prior = prior.cuda()
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior * torch.log(prior / pred_mean))

        loss = Lx + lamb * Lu + penalty
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write(
            '%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f  Unlabeled loss: %.2f'
            % (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs,
               batch_idx + 1, num_iter, Lx.item(), Lu.item()))
        sys.stdout.flush()
# Empirical check of a conditional Gumbel construction. Presumably this compares
# z | (argmax == 0), obtained by rejection sampling, with the closed-form
# z_tilde built from the same noise -- TODO confirm against the plotting code,
# which is not visible here.
# Relies on B, C, logits and plt defined elsewhere in the file. The `if b==0`
# test and the .view(n_samps, C) calls below only work when B == 1.
zs = []
z_tildes=[]
n_samps = 1000
count=0
# Rejection sampling: draw Gumbel-perturbed logits until n_samps draws have
# argmax == 0.
while count < n_samps:
    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)
    if b==0:
        zs.append(z)
        count+=1

        # z_tilde_b = -log(-log u_b); z_tilde_i = -log(-log(u_i)/p_i - log u_b)
        # for i != b (the conditional-sampling formula used by REBAR/RELAX).
        # NOTE(review): this reuses the uniforms `u` that produced z rather
        # than fresh noise -- confirm that is what the experiment intends.
        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits,dim=1)) - torch.log(u_b))
        # overwrite the entry of the selected class b
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        z_tildes.append(z_tilde)


zs = torch.stack(zs).view(n_samps,C)
z_tildes = torch.stack(z_tildes).view(n_samps,C)
# print (zs.shape)
# fsdfasd

# Histogram / figure settings for the plot that follows (not visible here).
n_bins = 80

rows = C
cols = 2
fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)
Esempio n. 31
0
def train(args):
    """Train the rating-prediction ``Net`` with validation-driven LR decay and early stopping.

    Builds the dataset and network, then runs full-graph training for
    iterations ``1 .. args.train_max_iter - 1``:

    * Every ``args.train_log_interval`` iterations, the running loss and
      training squared error are logged and printed.
    * Every ``args.train_valid_interval`` iterations, validation and test RMSE
      are computed and logged. The LR is decayed after
      ``args.train_decay_patience`` evaluations without improvement.
    * Training stops early once patience is exhausted and the LR has reached
      ``args.train_min_lr``.

    Curves are written to CSV files in ``args.save_dir``.

    Raises:
        ValueError: if ``args.loss_func`` is not one of "CE", "Hinge", "MSE".
    """
    print(args)
    if args.loss_func not in ("CE", "Hinge", "MSE"):
        raise ValueError("Unknown loss_func: {!r}".format(args.loss_func))

    dataset = DataSetLoader(args.data_name,
                            args.device,
                            use_one_hot_fea=args.use_one_hot_fea,
                            symm=args.gcn_agg_norm_symm,
                            test_ratio=args.data_test_ratio,
                            valid_ratio=args.data_valid_ratio)
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    # NOTE(review): set after Net(...) is constructed, so the constructor does
    # not see it -- confirm this is intended.
    args.decoder = "MLP"
    net = net.to(args.device)
    nd_possible_rating_values = th.FloatTensor(
        dataset.possible_rating_values).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(net.parameters(),
                                                    lr=learning_rate)
    print("Loading network finished ...\n")

    ### prepare training data
    # train_gt_labels are used as indices into nd_possible_rating_values (see
    # the MSE loss); train_gt_ratings are compared with predicted ratings.
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(
        ['iter', 'loss', 'rmse'], ['%d', '%.4f', '%.4f'],
        os.path.join(args.save_dir, 'train_loss%d.csv' % args.save_id))
    valid_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                     os.path.join(
                                         args.save_dir,
                                         'valid_loss%d.csv' % args.save_id))
    test_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                    os.path.join(
                                        args.save_dir,
                                        'test_loss%d.csv' % args.save_id))

    ### declare the loss information
    best_valid_rmse = np.inf
    best_test_rmse = np.inf  # stays inf if validation never improves / runs
    no_better_valid = 0
    best_iter = -1
    count_rmse = 0
    count_num = 0
    count_loss = 0

    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)
    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)
    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)
    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)
    print("Start training ...")
    dur = []
    for iter_idx in range(1, args.train_max_iter):
        if iter_idx > 3:
            t0 = time.time()  # the first iterations are excluded from timing
        net.train()
        # Two-stage mode (formerly switched on after 250 iterations) is disabled.
        two_stage = False
        pred_ratings, reg_loss = net(dataset.train_enc_graph,
                                     dataset.train_dec_graph,
                                     dataset.user_feature,
                                     dataset.movie_feature, two_stage)
        if args.loss_func == "CE":
            loss = rating_loss_net(
                pred_ratings, train_gt_labels).mean() + args.ARR * reg_loss
        elif args.loss_func == "Hinge":
            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                                 nd_possible_rating_values.view(1, -1)).sum(
                                     dim=1)
            # Expected ratings must be compared with ratings, not class indices.
            gap = (real_pred_ratings - train_gt_ratings)**2
            loss = th.where(gap > 1.0, gap * gap, gap).mean()
        else:  # "MSE" (validated above)
            loss = th.mean((pred_ratings[:, 0] -
                            nd_possible_rating_values[train_gt_labels])**
                           2) + args.ARR * reg_loss
        count_loss += loss.item()
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
        optimizer.step()

        if iter_idx > 3:
            dur.append(time.time() - t0)

        if iter_idx == 1:
            print("Total #Param of net: %d" % (torch_total_param_num(net)))
            print(
                torch_net_info(net,
                               save_path=os.path.join(
                                   args.save_dir, 'net%d.txt' % args.save_id)))

        if args.loss_func == "CE":
            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                                 nd_possible_rating_values.view(1, -1)).sum(
                                     dim=1)
        elif args.loss_func == "MSE":
            real_pred_ratings = pred_ratings[:, 0]
        # ("Hinge" reuses real_pred_ratings computed for its loss above.)
        # NOTE: accumulated as a sum of squared errors; the logged "rmse" is
        # therefore a mean squared error (no square root is taken).
        rmse = ((real_pred_ratings - train_gt_ratings)**2).sum()
        count_rmse += rmse.item()
        count_num += pred_ratings.shape[0]

        logging_str = ""
        if iter_idx % args.train_log_interval == 0:
            # count_loss sums the losses of iterations 1..iter_idx.
            avg_loss = count_loss / iter_idx
            train_rmse = count_rmse / count_num
            train_loss_logger.log(iter=iter_idx, loss=avg_loss, rmse=train_rmse)
            logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                iter_idx, avg_loss, train_rmse,
                np.average(dur) if dur else float("nan"))
            count_rmse = 0
            count_num = 0

        if iter_idx % args.train_valid_interval == 0:
            if not logging_str:
                # validation iteration that is not also a log iteration
                logging_str = "Iter={}".format(iter_idx)
            valid_rmse = evaluate(args=args,
                                  net=net,
                                  dataset=dataset,
                                  segment='valid')
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            test_rmse = evaluate(args=args,
                                 net=net,
                                 dataset=dataset,
                                 segment='test')
            logging_str += ', Test RMSE={:.4f}'.format(test_rmse)
            test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
            logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse)

            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                test_rmse = evaluate(args=args,
                                     net=net,
                                     dataset=dataset,
                                     segment='test',
                                     debug=True,
                                     idx=iter_idx)
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ', Test RMSE={:.4f}'.format(test_rmse)
            else:
                no_better_valid += 1
                if no_better_valid > args.train_early_stopping_patience\
                    and learning_rate <= args.train_min_lr:
                    print(logging_str)
                    logging.info(
                        "Early stopping threshold reached. Stop training.")
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(learning_rate * args.train_lr_decay_factor,
                                 args.train_min_lr)
                    if new_lr < learning_rate:
                        learning_rate = new_lr
                        logging.info("\tChange the LR to %g" % new_lr)
                        for p in optimizer.param_groups:
                            p['lr'] = learning_rate
                        no_better_valid = 0
        if logging_str:
            print(logging_str)
    print('Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.
          format(best_iter, best_valid_rmse, best_test_rmse))
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()
Esempio n. 32
0
def train(method, n_components, true_mixture_weights, exp_dir, needsoftmax_mixtureweight=None):
    """Train mixture weights, an inference encoder and (optionally) a surrogate on a 1-D GMM.

    Each step samples ``x`` from a uniform mixture, computes encoder logits and
    runs the chosen gradient estimator. It then updates the generator (mixture
    logits, after step 300000), the encoder (after step 100000) and any
    surrogate networks. Metrics are logged every ``print_steps``, plotted every
    ``plot_steps`` and pickled every ``params_step`` (all module-level globals).

    Args:
        method: estimator name; one of 'reinforce', 'reinforce_pz', 'simplax',
            'HLAX', 'relax', 'reinforce_baseline'.
        n_components: number of mixture components C.
        true_mixture_weights: ground-truth weights, used only for reporting
            ``theta_loss`` and for plotting.
        exp_dir: path prefix; ``exp_dir + "data.p"`` receives the pickled metrics.
        needsoftmax_mixtureweight: optional initial unnormalised mixture logits;
            drawn from a standard normal when None.

    Raises:
        ValueError: if ``method`` is not a supported estimator.
    """
    print('Method:', method)
    supported_methods = ('reinforce', 'reinforce_pz', 'simplax', 'HLAX',
                         'relax', 'reinforce_baseline')
    if method not in supported_methods:
        raise ValueError('Unknown method: {}'.format(method))
    C = n_components

    true_mixture_weights = torch.tensor(true_mixture_weights,
                                        requires_grad=True, device="cuda")

    if needsoftmax_mixtureweight is None:
        needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")
    else:
        needsoftmax_mixtureweight = torch.tensor(needsoftmax_mixtureweight,
                                                 requires_grad=True, device="cuda")

    # Learning rate of the surrogate networks; generator/encoder use 1e-4.
    lr = 1e-3
    optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=1e-4, weight_decay=1e-7)

    encoder = NN3(input_size=1, output_size=n_components, n_residual_blocks=3).cuda()
    optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-4, weight_decay=1e-7)

    # Only some estimators learn a surrogate; None means "no surrogate network".
    surrogate = None
    if method in ['simplax', 'relax']:
        surrogate = NN3(input_size=1 + n_components + n_components, output_size=1, n_residual_blocks=4).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=lr)

    if method in ['reinforce_baseline']:
        surrogate = NN3(input_size=1, output_size=1, n_residual_blocks=4).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=lr)

    if method == 'HLAX':
        surrogate = NN3(input_size=1 + n_components, output_size=1, n_residual_blocks=4).cuda()
        surrogate2 = NN3(input_size=1, output_size=1, n_residual_blocks=2).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=lr)
        optim_surr2 = torch.optim.Adam(surrogate2.parameters(), lr=lr)

    data_dict = {}
    data_dict['steps'] = []
    data_dict['theta_losses'] = []
    data_dict['f'] = []
    data_dict['lpx_given_z'] = []
    data_dict['lpz'] = []
    data_dict['lqz'] = []
    data_dict['inference_L2'] = []

    if method in ['simplax', 'HLAX', 'relax']:
        data_dict['surr_loss'] = []
        data_dict['surr_dif'] = []
    if method in ['simplax', 'HLAX']:
        data_dict['grad_split'] = {}
        data_dict['grad_split']['score'] = []
        data_dict['grad_split']['path'] = []
    if method == 'HLAX':
        data_dict['alpha'] = []
    if method == 'relax':
        data_dict['grad_logq'] = []
        data_dict['surr_grads'] = {}
        data_dict['surr_grads']['grad_surr_z'] = []
        data_dict['surr_grads']['grad_surr_z_tilde'] = []
        data_dict['grad_split'] = {}
        data_dict['grad_split']['score'] = []
        data_dict['grad_split']['path'] = []
        data_dict['var'] = {}
        data_dict['var']['reinforce'] = []
        data_dict['var']['relax'] = []
        data_dict['bias'] = []
        data_dict['SNR'] = {}
        data_dict['SNR']['reinforce'] = []
        data_dict['SNR']['relax'] = []
    if method == 'reinforce_baseline':
        data_dict['surr_loss'] = []

    cur_time = time.time()
    for step in range(0, n_steps + 1):

        mixtureweights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()  # [C]

        # Data are drawn from a *uniform* mixture, not from true_mixture_weights.
        x = sample_gmm(batch_size, mixture_weights=torch.ones(C).cuda() / C)
        logits = encoder.net(x)
        logits = logits / 100.

        if method == 'reinforce':
            outputs = reinforce(x, logits, mixtureweights, k=1)
        elif method == 'reinforce_pz':
            net_loss, f, logpx_given_z, logpz, logq = reinforce_pz(x, logits, mixtureweights, k=1)
            # Package the tuple so the rest of the loop can index `outputs`.
            outputs = {'net_loss': net_loss, 'f': f, 'logpx_given_z': logpx_given_z,
                       'logpz': logpz, 'logq': logq}
        elif method == 'simplax':
            outputs = simplax(surrogate, x, logits, mixtureweights, k=1)
        elif method == 'HLAX':
            (net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif,
             grad_path, grad_score, alpha) = HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1)
            outputs = {'net_loss': net_loss, 'f': f, 'logpx_given_z': logpx_given_z,
                       'logpz': logpz, 'logq': logq, 'surr_loss': surr_loss,
                       'surr_dif': surr_dif, 'grad_path': grad_path,
                       'grad_score': grad_score}
        elif method == 'relax':
            outputs = relax(step, surrogate, x, logits, mixtureweights, k=1)
        else:  # 'reinforce_baseline'
            outputs = reinforce_baseline(surrogate, x, logits, mixtureweights, k=1)

        if step > 300000:
            # Update generator
            loss = - torch.mean(outputs['f'])
            optim.zero_grad()
            loss.backward(retain_graph=True)
            optim.step()

        if step > 100000:
            # Update encoder
            optim_net.zero_grad()
            outputs['net_loss'].backward(retain_graph=True)
            optim_net.step()

        # Update surrogate
        if method in ['simplax', 'HLAX', 'relax', 'reinforce_baseline']:
            optim_surr.zero_grad()
            outputs['surr_loss'].backward(retain_graph=True)
            optim_surr.step()

        if method == 'HLAX':
            optim_surr2.zero_grad()
            outputs['surr_loss'].backward(retain_graph=True)
            optim_surr2.step()

        if step % print_steps == 0:
            current_theta = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
            theta_loss = L2_error(to_print2(true_mixture_weights),
                                  to_print2(current_theta))

            pz_give_x = true_posterior(x, mixture_weights=mixtureweights)
            probs = torch.softmax(logits, dim=1)
            inference_L2 = L2_batch(pz_give_x, probs)

            if method == 'relax':
                # Gradient statistics on the first datapoint (k=100 samples) to
                # compare REINFORCE and RELAX; only recorded for 'relax', and
                # only 'relax' has the surrogate this check requires.
                outputs11 = reinforce(x[0].view(1, 1), logits[0].view(1, C), mixtureweights, k=100, get_grad=True)
                outputs22 = relax(step=step, surrogate=surrogate, x=x[0].view(1, 1),
                                  logits=logits[0].view(1, C), mixtureweights=mixtureweights, k=100, get_grad=True)

            print(
                'S:{:5d}'.format(step),
                'T:{:.2f}'.format(time.time() - cur_time),
                'Theta_loss:{:.3f}'.format(theta_loss),
                'Loss:{:.3f}'.format(to_print1(outputs['net_loss'])),
                'ELBO:{:.3f}'.format(to_print1(outputs['f'])),
                'lpx|z:{:.3f}'.format(to_print1(outputs['logpx_given_z'])),
                'lpz:{:.3f}'.format(to_print1(outputs['logpz'])),
                'lqz:{:.3f}'.format(to_print1(outputs['logq'])),
                )
            cur_time = time.time()

            if step > 0:
                data_dict['steps'].append(step)
                data_dict['theta_losses'].append(theta_loss)
                data_dict['f'].append(to_print1(outputs['f']))
                data_dict['lpx_given_z'].append(to_print1(outputs['logpx_given_z']))
                data_dict['lpz'].append(to_print1(outputs['logpz']))
                data_dict['lqz'].append(to_print1(outputs['logq']))
                data_dict['inference_L2'].append(to_print2(inference_L2))
                if method in ['simplax', 'HLAX', 'relax']:
                    data_dict['surr_loss'].append(to_print2(outputs['surr_loss']))
                    data_dict['surr_dif'].append(to_print2(outputs['surr_dif']))
                if method in ['simplax', 'HLAX']:
                    data_dict['grad_split']['score'].append(to_print2(outputs['grad_score']))
                    data_dict['grad_split']['path'].append(to_print2(outputs['grad_path']))
                if method == 'HLAX':
                    data_dict['alpha'].append(to_print2(alpha))
                if method == 'relax':
                    data_dict['grad_logq'].append(to_print1(outputs['grad_logq']))
                    data_dict['surr_grads']['grad_surr_z'].append(to_print1(outputs['grad_surr_z']))
                    data_dict['surr_grads']['grad_surr_z_tilde'].append(to_print1(outputs['grad_surr_z_tilde']))
                    data_dict['grad_split']['score'].append(to_print1(outputs['grad_score']))
                    data_dict['grad_split']['path'].append(to_print1(outputs['grad_path']))
                    data_dict['var']['reinforce'].append(to_print1(outputs11['grad_std']))
                    data_dict['var']['relax'].append(to_print1(outputs22['grad_std']))

                    data_dict['bias'].append(to_print1(torch.abs(outputs11['grad_avg'] - outputs22['grad_avg'])))

                    data_dict['SNR']['reinforce'].append(to_print1(torch.abs(outputs11['grad_avg'] / outputs11['grad_std'])))
                    data_dict['SNR']['relax'].append(to_print1(torch.abs(outputs22['grad_avg'] / outputs22['grad_std'])))

                if method in ['reinforce_baseline']:
                    data_dict['surr_loss'].append(to_print2(outputs['surr_loss']))

            check_nan(outputs['net_loss'])

        if step % plot_steps == 0 and step != 0:

            plot_curve2(data_dict, exp_dir)

            if step % (plot_steps * 2) == 0:
                # Recompute locally: the values from the print block only exist
                # when this step is also a print step.
                plot_theta = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
                plot_pz_give_x = true_posterior(x, mixture_weights=mixtureweights)
                plot_probs = torch.softmax(logits, dim=1)
                plot_posteriors2(n_components, trueposteriors=to_print2(plot_pz_give_x), qs=to_print2(plot_probs), exp_dir=images_dir, name=str(step))

                plot_dist2(n_components, mixture_weights=to_print2(plot_theta), true_mixture_weights=to_print2(true_mixture_weights), exp_dir=images_dir, name=str(step))

        if step % params_step == 0 and step > 0:

            with open(exp_dir + "data.p", "wb") as fh:
                pickle.dump(data_dict, fh)
            print('saved data')

            if surrogate is not None:
                surrogate.save_params_v3(save_dir=params_dir, name='surrogate', step=step)
Esempio n. 33
0
 def updateOutput(self, input):
     """Softmax-normalise ``input`` along the dimension chosen by ``self._get_dim``.

     The result is cached on ``self.output`` and returned.
     """
     softmax_dim = self._get_dim(input)
     self.output = torch.softmax(input, softmax_dim)
     return self.output
Esempio n. 34
0
    # logprob_cluster_list = []

    x = sample_true(batch_size).cuda() #.view(1,1)

    counts = np.zeros(n_components)
    for step in range(n_steps):

        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        # x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)

        probs2 = probs.cpu().data.numpy()[0]
        print()
        for iii in range(len(probs2)):
            print(str(iii)+':'+str(probs2[iii]), end =" ")
        print ()

        # print (probs.shape)
        print (probs)
        # fsdf
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        net_loss = 0
        loss = 0
        surr_loss = 0
Esempio n. 35
0
def softmax(input, dim=-1):
    """Return ``input`` normalised into probabilities along ``dim`` (last axis by default)."""
    normalised = th.softmax(input, dim)
    return normalised
def train(config,
          model,
          criterion,
          train_loader,
          val_loader,
          parameters,
          optimizer,
          sched,
          path,
          rnn=False,
          kernel_learning=False,
          model_path=None):
    """Train a sequence classifier (optionally a GP / kernel-learning model) for 500 epochs.

    Args:
        config: dict with 'ATTENTION', 'SEQ_LEN' and 'NUM_CLASSES' entries.
        model: the network, or, when ``kernel_learning``, a dict with 'model'
            and 'likelihood' entries.
        criterion: loss function. With ``kernel_learning`` its value is negated
            before back-propagation (presumably a marginal log likelihood —
            confirm against the caller).
        train_loader, val_loader: iterables of ``(sequences, labels)`` batches.
        parameters: unused; kept for interface compatibility.
        optimizer: optimizer stepped every batch.
        sched: LR scheduler stepped every 50 epochs.
        path: folder (relative to the cwd) receiving the checkpoints.
        rnn: clip the gradient norm of all model parameters to 5 when True.
        kernel_learning: enable the likelihood branch.
        model_path: optional checkpoint to resume a kernel-learning model from.

    Writes ``_loss.pth``, ``_acc.pth``, ``save_<epoch>.pth`` and ``_last.pth``
    into ``./<path>/`` and ``train_acc.png`` / ``train_loss.png`` into the cwd.
    """
    folder_ = path

    attn = config['ATTENTION'] == 'attn'
    seq_len = config['SEQ_LEN']
    num_classes = config['NUM_CLASSES']

    if kernel_learning:
        likelihood = model['likelihood']
        model = model['model']

        if model_path:
            state_dict = torch.load(model_path)
            model.load_state_dict(state_dict['model'])
            likelihood.load_state_dict(state_dict['likelihood'])

    def _save_checkpoint(filename):
        """Save the model (plus likelihood when kernel learning) to ./<path>/<filename>."""
        target = './' + folder_ + '/' + filename
        if kernel_learning:
            torch.save({
                'model': model.state_dict(),
                'likelihood': likelihood.state_dict()
            }, target)
        else:
            torch.save(model.state_dict(), target)

    def _to_device(sequences, labels):
        """Squeeze the labels and move one batch to the GPU."""
        return Variable(sequences.cuda()), Variable(labels.squeeze().cuda())

    def _forward(sequences, labels):
        """Run the model on a batch and return (raw_outputs, attn_weights, loss)."""
        attn_weights = None
        if kernel_learning and attn:
            outputs, attn_weights = model(sequences)
            # repeat label for every element in sequence
            ext_labels = labels.repeat_interleave(seq_len)
        else:
            outputs = model(sequences)
            ext_labels = labels

        if kernel_learning:
            loss = -criterion(outputs, ext_labels)
        else:
            loss = criterion(outputs, labels)
        return outputs, attn_weights, loss

    def _to_probs(outputs, attn_weights, batch_size):
        """Convert raw model outputs into per-sequence class probabilities."""
        if not kernel_learning:
            return torch.softmax(outputs, dim=-1)
        # Average over the likelihood samples (count set by num_likelihood_samples).
        probs = likelihood(outputs).probs.mean(0)
        if attn:
            # Attention-weighted pooling of the per-timestep predictions.
            probs = probs.reshape((batch_size, seq_len, num_classes))
            probs = torch.bmm(attn_weights, probs).squeeze()
        return probs

    mean_train_losses = []
    mean_val_losses = []

    mean_train_acc = []
    mean_val_acc = []

    minLoss = np.inf
    maxValacc = -np.inf
    for epoch in range(500):
        print('EPOCH: ', epoch + 1)
        train_acc = []
        val_acc = []

        running_loss = 0.0

        model.train()
        if kernel_learning:
            likelihood.train()

        count = 0
        num_samples = 8
        with gpytorch.settings.num_likelihood_samples(
                num_samples) if kernel_learning else suppress():
            for sequences, labels in train_loader:
                sequences, labels = _to_device(sequences, labels)

                optimizer.zero_grad()
                outputs, attn_weights, loss = _forward(sequences, labels)
                probs = _to_probs(outputs, attn_weights, sequences.size(0))
                train_acc.append(accuracy(probs, labels))

                loss.backward()

                if rnn or kernel_learning:
                    if kernel_learning:
                        params = model.feature_extractor.parameters()
                    else:
                        params = model.parameters()
                    torch.nn.utils.clip_grad_norm_(params, 5)

                optimizer.step()

                running_loss += loss.item()
                count += 1

        if epoch % 50 == 0:
            sched.step()

        print('Training loss:.......', running_loss / count)
        mean_train_losses.append(running_loss / count)

        model.eval()
        if kernel_learning:
            likelihood.eval()

        count = 0
        val_running_loss = 0.0
        with torch.no_grad(), gpytorch.settings.num_likelihood_samples(
                16) if kernel_learning else suppress():
            for sequences, labels in val_loader:
                sequences, labels = _to_device(sequences, labels)

                # The loss is computed on the validation batch itself; it used
                # to re-add the last *training* batch's loss.
                outputs, attn_weights, loss = _forward(sequences, labels)
                probs = _to_probs(outputs, attn_weights, sequences.size(0))

                val_acc.append(accuracy(probs, labels))
                val_running_loss += loss.item()

                count += 1

        mean_val_loss = val_running_loss / count
        print('Validation loss:.....', mean_val_loss)

        print('Training accuracy:...', np.mean(train_acc))
        print('Validation accuracy..', np.mean(val_acc))

        mean_val_losses.append(mean_val_loss)

        mean_train_acc.append(np.mean(train_acc))

        val_acc_ = np.mean(val_acc)
        mean_val_acc.append(val_acc_)

        if mean_val_loss < minLoss:
            _save_checkpoint('_loss.pth')
            print(
                f'NEW BEST LOSS_: {mean_val_loss} ........old best:{minLoss}')
            minLoss = mean_val_loss
            print('')

        if val_acc_ > maxValacc:
            _save_checkpoint('_acc.pth')
            print(f'NEW BEST ACC_: {val_acc_} ........old best:{maxValacc}')
            maxValacc = val_acc_

        if epoch % 500 == 0:
            _save_checkpoint('save_' + str(epoch) + '.pth')
            print(f'DIV 200: Val_acc: {val_acc_} ..Val_loss:{mean_val_loss}')

    _save_checkpoint('_last.pth')

    train_acc_series = pd.Series(mean_train_acc)
    val_acc_series = pd.Series(mean_val_acc)
    train_acc_series.plot(label="train")
    val_acc_series.plot(label="validation")
    plt.legend()
    plt.savefig('./train_acc.png')

    plt.clf()

    train_loss_series = pd.Series(mean_train_losses)
    val_loss_series = pd.Series(mean_val_losses)
    train_loss_series.plot(label="train")
    val_loss_series.plot(label="validation")
    plt.legend()
    plt.savefig('./train_loss.png')
Esempio n. 37
0
    def forward(self, x, t_feats):
        """Run the student backbone and compute feature-distillation losses.

        ``x`` is passed through ``super().forward`` to get a sequence of student
        feature maps. Each map is compared with the teacher map at the same
        index in ``t_feats`` (indexing below assumes 4-D (N, C, H, W) tensors —
        confirm with the caller).

        Returns:
            dict with 'kd_feat_loss', 'kd_channel_loss', 'kd_spatial_loss' and
            'kd_nonlocal_loss'.
        """
        t = 0.1          # temperature of the spatial-attention softmax
        s_ratio = 1.0    # weight of the student spatial mask in the blend
        kd_feat_loss = 0
        kd_channel_loss = 0
        kd_spatial_loss = 0
        losses = {}
        #   for channel attention
        c_t = 0.1        # temperature of the channel-attention softmax
        c_s_ratio = 1.0  # weight of the student channel mask in the blend

        x = super().forward(x)

        for _i in range(len(x)):
            # Spatial attention: softmax over H*W, rescaled so the mean weight is 1.
            t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [1],
                                          keepdim=True)
            size = t_attention_mask.size()
            t_attention_mask = t_attention_mask.view(x[0].size(0), -1)
            t_attention_mask = torch.softmax(t_attention_mask / t,
                                             dim=1) * size[-1] * size[-2]
            t_attention_mask = t_attention_mask.view(size)

            s_attention_mask = torch.mean(torch.abs(x[_i]), [1], keepdim=True)
            size = s_attention_mask.size()
            s_attention_mask = s_attention_mask.view(x[0].size(0), -1)
            s_attention_mask = torch.softmax(s_attention_mask / t,
                                             dim=1) * size[-1] * size[-2]
            s_attention_mask = s_attention_mask.view(size)

            # Channel attention: softmax over C, rescaled by the actual channel
            # count (previously a hard-coded 256) so the mean weight is 1.
            c_t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [2, 3],
                                            keepdim=True)  # N x C x 1 x 1
            c_size = c_t_attention_mask.size()
            c_t_attention_mask = c_t_attention_mask.view(x[0].size(0),
                                                         -1)  # N x C
            c_t_attention_mask = torch.softmax(c_t_attention_mask / c_t,
                                               dim=1) * c_size[1]
            c_t_attention_mask = c_t_attention_mask.view(
                c_size)  # N x C -> N x C x 1 x 1

            c_s_attention_adapted = self.channel_wise_adaptation[_i](
                torch.mean(
                    torch.abs(x[_i]),
                    [2, 3],
                ))
            c_s_attention_adapted = c_s_attention_adapted.view(
                x[0].size(0), -1, 1, 1)
            c_size = c_s_attention_adapted.size()
            c_s_attention_adapted = c_s_attention_adapted.view(
                x[0].size(0), -1)  # N x C
            c_s_attention_mask = torch.softmax(c_s_attention_adapted / c_t,
                                               dim=1) * c_size[1]
            c_s_attention_mask = c_s_attention_mask.view(
                c_size)  # N x C -> N x C x 1 x 1

            # Blended masks act as fixed weights (no gradient through them).
            sum_attention_mask = (t_attention_mask +
                                  s_attention_mask * s_ratio) / (1 + s_ratio)
            sum_attention_mask = sum_attention_mask.detach()

            c_sum_attention_mask = (c_t_attention_mask + c_s_attention_mask *
                                    c_s_ratio) / (1 + c_s_ratio)
            c_sum_attention_mask = c_sum_attention_mask.detach()

            kd_feat_loss += dist2(
                t_feats[_i],
                self.adaptation_layers[_i](x[_i]),
                attention_mask=sum_attention_mask,
                channel_attention_mask=c_sum_attention_mask) * 7e-5 * 6

            kd_channel_loss += torch.dist(
                torch.mean(torch.abs(t_feats[_i]), [2, 3]),
                c_s_attention_adapted) * 4e-3 * 6

            t_spatial_pool = torch.mean(t_feats[_i],
                                        [1]).view(t_feats[_i].size(0), 1,
                                                  t_feats[_i].size(2),
                                                  t_feats[_i].size(3))
            s_spatial_pool = torch.mean(x[_i],
                                        [1]).view(x[_i].size(0), 1,
                                                  x[_i].size(2), x[_i].size(3))
            kd_spatial_loss += torch.dist(
                t_spatial_pool,
                self.spatial_wise_adaptation[_i](s_spatial_pool)) * 4e-3 * 6

        losses.update({'kd_feat_loss': kd_feat_loss})
        losses.update({'kd_channel_loss': kd_channel_loss})
        losses.update({'kd_spatial_loss': kd_spatial_loss})

        kd_nonlocal_loss = 0
        for _i in range(len(x)):
            s_relation = self.student_non_local[_i](x[_i])
            t_relation = self.teacher_non_local[_i](t_feats[_i])
            kd_nonlocal_loss += torch.dist(
                self.non_local_adaptation[_i](s_relation), t_relation, p=2)
        losses.update(kd_nonlocal_loss=kd_nonlocal_loss * 7e-5 * 6)
        return losses
Esempio n. 38
0
    def evaluate(self,
                 add_threshold: float = 0.1,
                 proj_threshold: int = 5,
                 angle_threshold: float = 5.,
                 trans_threshold: float = 5.) -> Dict:
        """
        Evaluate the whole network on the Linemod test split and compute pose metrics.
        :param add_threshold: the ADD(-S) metrics threshold. ADD(-S) error smaller than 'threshold * diameter' is considered correct, default=0.1
        :param proj_threshold: the projection error [px] threshold. Smaller than threshold is considered correct
        :param angle_threshold: rotation error threshold in 5cm5°
        :param trans_threshold: translation error threshold in 5cm5°
        :return: accuracy dict from self.summary (keys presumably 'add', 'add-s', '5cm5degree', 'projection', 'miou'),
                 plus the thresholds under 'add_thres', 'proj_thres', 'rot_thres', 'tra_thres'.
                 When self.refinement is set, refined accuracies are printed but not returned.
        """

        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=constants.IMAGE_MEAN,
                                 std=constants.IMAGE_STD)
        ])
        linemod_dataset = Linemod(train=False, transform=image_transform)
        dataloader = Data.DataLoader(linemod_dataset,
                                     batch_size=2,
                                     pin_memory=True)

        def _pose_correctness(pose: np.ndarray, gt_pose: np.ndarray) -> Tuple:
            """Return (add_ok, adds_ok, rot_tra_ok, proj_ok) for a predicted (3, 4) pose."""
            add_error: float = metrics.calculate_add(pose, gt_pose,
                                                     self.model_points)
            adds_error: float = metrics.calculate_add_s(
                pose, gt_pose, self.model_points)
            rot_error: float = metrics.rotation_error(pose[:, :3],
                                                      gt_pose[:, :3])
            tra_error: float = metrics.translation_error(
                pose[:, -1], gt_pose[:, -1])
            proj_error: float = metrics.projection_error(
                pts_3d=self.model_points,
                camera_k=regular_config.camera,
                pred_pose=pose,
                gt_pose=gt_pose)
            return (add_error < add_threshold * self.diameter,
                    adds_error < add_threshold * self.diameter,
                    (rot_error < angle_threshold)
                    and (tra_error < trans_threshold),
                    proj_error < proj_threshold)

        # init metrics storing space
        miou_list = list()
        add_acc_list = list()
        adds_acc_list = list()
        rot_tra_arr_list = list()
        projection_acc_list = list()

        add_acc_list_refined = list()
        adds_acc_list_refined = list()
        rot_tra_arr_list_refined = list()
        projection_acc_list_refined = list()

        for i, data in enumerate(dataloader):
            mask_label: torch.Tensor
            img, mask_label, vmap_label, label, img_path = data  # all torch.Tensor
            img, label = img.to(self.device), label.to(self.device)
            batch_size = img.shape[0]
            out: Tuple = self.network(img)
            pred_masks: torch.Tensor = out[
                0]  # shape (n, 13+1(or 1+1=2), h, w)
            pred_vmap: np.ndarray = out[1].cpu().detach().numpy(
            )  # shape (n, 234(or 18), h, w)
            # use softmax and argmax to find the most probable class per pixel
            pred_masks: torch.Tensor = torch.softmax(pred_masks, dim=1)
            binary_masks: torch.Tensor = pred_masks.argmax(
                dim=1)  # shape (n, h, w)
            # convert the pred_mask to binary mask, shape (n, h, w)
            # NOTE(review): label[0] is used for the whole batch; this assumes all
            # images in a batch share one object class — confirm.
            binary_masks: np.ndarray = torch.where(
                binary_masks == label[0],
                torch.tensor(1).to(self.device),
                torch.tensor(0).to(self.device)).cpu().detach().numpy()
            # convert tensor non-binary mask gt to numpy binary mask gt, shape (n, h, w)
            binary_mask_gt: np.ndarray = torch.where(
                mask_label.cpu() == label[0].cpu().item(), torch.tensor(1.),
                torch.tensor(0.)).cpu().detach().numpy()
            # metrics:
            # mask iou
            miou = metrics.mask_miou(binary_masks, binary_mask_gt)
            miou_list.append(miou)
            # process every image in batch one by one
            for j in range(batch_size):
                if not self.simple:
                    obj_vmap: np.ndarray = LinemodOutputExtractor.extract_vector_field_by_name(
                        pred_vmap[j], name=self.category)
                else:
                    obj_vmap: np.ndarray = pred_vmap[j]
                # shape (9, 2), the first point is the location of object center
                pred_keypoints: np.ndarray = self.voting_procedure.provide_keypoints(
                    binary_masks[j], obj_vmap)
                # predicted pose, (3, 4)
                pred_pose: np.ndarray = geometry_utils.solve_pnp(
                    object_pts=self.model_keypoints,
                    image_pts=pred_keypoints
                    if self.need_model_origin else pred_keypoints[1:],
                    camera_k=regular_config.camera)  # shape (3, 4)
                # Load the GT rotation and translation
                gt_pose: np.ndarray = LinemodDatasetProvider.provide_pose(
                    img_path[j])  # shape (3, 4)
                image_label_path = img_path[j].replace('JPEGImages', 'labels')
                image_label_path = image_label_path.replace('jpg', 'txt')
                gt_keypoints = LinemodDatasetProvider.provide_keypoints_coordinates(
                    image_label_path)[1].numpy()
                print('GT Keypoints: \n', gt_keypoints)
                print('Pred Keypoints: \n', pred_keypoints)

                if self.refinement is not None:
                    # do the icp process
                    depth_arr: np.ndarray = LinemodDatasetProvider.provide_depth(
                        img_path[j])
                    pred_pose_refined: np.ndarray = self.refinement.refine(
                        depth=depth_arr,
                        mask=binary_masks[j],
                        pose=pred_pose,
                        model_pts=self.model_points)
                    add_ok, adds_ok, rot_tra_ok, proj_ok = _pose_correctness(
                        pred_pose_refined, gt_pose)
                    add_acc_list_refined.append(add_ok)
                    adds_acc_list_refined.append(adds_ok)
                    rot_tra_arr_list_refined.append(rot_tra_ok)
                    projection_acc_list_refined.append(proj_ok)

                add_ok, adds_ok, rot_tra_ok, proj_ok = _pose_correctness(
                    pred_pose, gt_pose)
                add_acc_list.append(add_ok)
                adds_acc_list.append(adds_ok)
                rot_tra_arr_list.append(rot_tra_ok)
                projection_acc_list.append(proj_ok)

        # summary all metrics
        accuracies: Dict = self.summary(add_acc_list, adds_acc_list,
                                        rot_tra_arr_list, projection_acc_list,
                                        miou_list)

        if self.refinement is not None:
            accuracies_refined: Dict = self.summary(
                add_acc_list_refined, adds_acc_list_refined,
                rot_tra_arr_list_refined, projection_acc_list_refined,
                miou_list)
            print('Accuracy with refinement: \n', accuracies_refined)

        # add threshold info
        accuracies['add_thres'] = add_threshold
        accuracies['proj_thres'] = proj_threshold
        accuracies['rot_thres'] = angle_threshold
        accuracies['tra_thres'] = trans_threshold

        return accuracies
Esempio n. 39
0
def train(args, model_teacher, model_student, classifier_teacher, classifier_student, train_labeled_loader, train_unlabeled_loader, optimizer, epoch):
    """Run one epoch of teacher/student semi-supervised training.

    The teacher (``model_teacher`` + ``classifier_teacher``) is kept in eval
    mode and only produces hard pseudo-labels for the weakly augmented
    unlabeled images. The student (``model_student`` + ``classifier_student``)
    is optimized on cross-entropy over the labeled batch plus
    ``args.lambda_u`` times cross-entropy between its predictions on the
    strongly augmented images and those pseudo-labels.

    Args:
        args: namespace providing ``lambda_u`` and ``print_freq``.
        model_teacher: teacher feature extractor (not updated here).
        model_student: student feature extractor (trained).
        classifier_teacher: teacher classification head (not updated here).
        classifier_student: student classification head (trained).
        train_labeled_loader: yields ``(inputs_x, targets_x)`` batches.
        train_unlabeled_loader: yields ``(inputs_u_w, inputs_u_s)`` batches.
        optimizer: optimizer over the student's parameters.
        epoch: epoch index, used only in the progress printout.

    Returns:
        ``(loss_avg, supervised_loss_avg, consistency_loss_avg, acc_avg)``.
    """
    model_teacher.eval()
    classifier_teacher.eval()

    model_student.train()
    classifier_student.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    acc = AverageMeter()

    end = time.time()

    # zip() stops at the shorter of the two loaders.
    train_loader = zip(train_labeled_loader, train_unlabeled_loader)

    for batch_idx, (data_x, data_u) in enumerate(tqdm(train_loader, disable=False)):
        # Time spent waiting on the data loaders since the previous step ended.
        data_time.update(time.time() - end)

        # Get inputs and target
        inputs_x, targets_x = data_x
        inputs_u_w, inputs_u_s = data_u

        # Cast to the dtypes the losses expect and move everything to the GPU.
        inputs_x = inputs_x.float().cuda()
        inputs_u_w = inputs_u_w.float().cuda()
        inputs_u_s = inputs_u_s.float().cuda()
        targets_x = targets_x.long().cuda()

        # Fold any extra grouping dimension into the batch.
        # NOTE(review): hard-codes 3x256x256 labeled images.
        inputs_x = inputs_x.reshape(-1, 3, 256, 256)
        targets_x = targets_x.reshape(-1)

        # Pseudo-label the weakly augmented images with the frozen teacher.
        with torch.no_grad():
            feat_u_w = model_teacher(inputs_u_w)
            logits_u_w = classifier_teacher(feat_u_w)

        # One student forward pass over labeled + strongly augmented images.
        inputs = torch.cat((inputs_x, inputs_u_s))
        feats = model_student(inputs)
        logits = classifier_student(feats)

        batch_size = inputs_x.shape[0]
        logits_x = logits[:batch_size]  # labeled data
        logits_u_s = logits[batch_size:]  # unlabeled data
        del logits

        # Compute loss
        Supervised_loss = F.cross_entropy(logits_x, targets_x, reduction='mean')

        # logits_u_w was produced under no_grad, so it needs no detaching;
        # the pseudo-label is the teacher's most probable class.
        pseudo_label = torch.softmax(logits_u_w, dim=-1)
        _, targets_u = torch.max(pseudo_label, dim=-1)
        Consistency_loss = F.cross_entropy(logits_u_s, targets_u, reduction='mean')

        final_loss = Supervised_loss + args.lambda_u * Consistency_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()

        # Running statistics. The consistency loss is a mean over the unlabeled
        # samples, so it is weighted by the unlabeled batch size.
        losses_x.update(Supervised_loss.item(), batch_size)
        losses_u.update(Consistency_loss.item(), inputs_u_s.shape[0])
        losses.update(final_loss.item(), batch_size)
        pred = torch.argmax(logits_x, dim=1)
        acc.update(torch.sum(targets_x == pred).item() / batch_size, batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print statistics every N batches
        if (batch_idx + 1) % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'acc {acc.val:.3f} ({acc.avg:.3f})\t'
                  'final_loss {final_loss.val:.3f} ({final_loss.avg:.3f})\t'
                  'Supervised_loss {Supervised_loss.val:.3f} ({Supervised_loss.avg:.3f})\t'
                  'Consistency_loss {Consistency_loss.val:.3f} ({Consistency_loss.avg:.3f})'.format(epoch, batch_idx + 1, len(train_labeled_loader),
                                                                                                    batch_time=batch_time,
                                                                                                    data_time=data_time,
                                                                                                    acc=acc,
                                                                                                    final_loss=losses,
                                                                                                    Supervised_loss=losses_x,
                                                                                                    Consistency_loss=losses_u))

    return losses.avg, losses_x.avg, losses_u.avg, acc.avg
Esempio n. 40
0
        self.encoder = nn.Linear(28 * 28, encoding_dim)
        self.decoder = nn.Linear(encoding_dim, 28 * 28)

    def forward(self, x):
        """Encode and reconstruct a batch of 28x28 single-channel images.

        The batch size is read from ``x`` itself (not from a global), so any
        batch size works, including a final partial batch.

        Args:
            x: tensor whose first dimension is the batch and whose remaining
                elements total 28 * 28 per sample.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in ``(0, 1)``.
        """
        n = x.size(0)
        x = x.view(n, -1)
        encoded = torch.relu(self.encoder(x))
        out = torch.sigmoid(self.decoder(encoded).view(n, 1, 28, 28))
        return out


## Loss function & Optimizer
# Build the autoencoder and print its layer summary.
net = AutoEncoder()
print(net)

# Reconstruction objective: mean-squared error between output and target.
fn_loss = nn.MSELoss()
# NOTE(review): fn_pred / fn_acc look like leftovers from a classification
# script. The model's output is shaped (N, 1, 28, 28), so softmax over dim=1
# is identically 1 and argmax over dim=1 is always 0. Confirm whether they are
# used anywhere for this model.
fn_pred = lambda output: torch.softmax(output, dim=1)
print(fn_pred)
# Fraction of samples whose argmax over dim=1 equals `label`.
fn_acc = lambda pred, label: ((pred.max(dim=1)[1] == label).type(torch.float)).mean()

# `learning_rate` is defined outside this view.
optim = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Log
# presumably torch.utils.tensorboard.SummaryWriter; `log_dir` is defined outside this view.
writer = SummaryWriter(log_dir=log_dir)

print(optim)

## Training
num_epoch = 1

# NOTE(review): the visible loop body only switches the network to training
# mode; no forward/backward step follows here.
for epoch in range(1, num_epoch + 1):
    net.train()
Esempio n. 41
0
def train(args, labeled_trainloader, unlabeled_trainloader, model, optimizer,
          ema_model, scheduler, epoch):
    """Run one FixMatch training epoch.

    Labeled images and the weak/strong views of the unlabeled images go
    through ``model`` in a single forward pass. Hard pseudo-labels are taken
    from the weak view and kept only where their softmax confidence is at
    least ``args.threshold``. The strong view is trained towards them with
    weight ``args.lambda_u``.

    Args:
        args: namespace with ``amp``, ``no_progress``, ``iteration``,
            ``local_rank``, ``device``, ``threshold``, ``lambda_u``,
            ``use_ema`` and ``epochs``.
        labeled_trainloader: yields ``(inputs_x, targets_x)``.
        unlabeled_trainloader: yields ``((inputs_u_w, inputs_u_s), _)``.
        model: network being trained.
        optimizer: optimizer over ``model``'s parameters.
        ema_model: EMA wrapper, updated after every step when ``args.use_ema``.
        scheduler: learning-rate scheduler, stepped once per batch.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).

    Returns:
        ``(loss_avg, loss_x_avg, loss_u_avg, mask_prob)``, where ``mask_prob``
        is the fraction of unlabeled samples that passed the confidence
        threshold in the last batch (0.0 if no batch was processed).
    """
    if args.amp:
        from apex import amp
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    end = time.time()
    # Bound up front so the return statement is valid even for an empty loader.
    mask_prob = 0.0

    if not args.no_progress:
        p_bar = tqdm(range(args.iteration),
                     disable=args.local_rank not in [-1, 0])

    train_loader = zip(labeled_trainloader, unlabeled_trainloader)
    model.train()
    for batch_idx, (data_x, data_u) in enumerate(train_loader):
        inputs_x, targets_x = data_x
        (inputs_u_w, inputs_u_s), _ = data_u
        data_time.update(time.time() - end)
        batch_size = inputs_x.shape[0]
        # Single forward pass over labeled + weak + strong views.
        inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s)).to(args.device)
        targets_x = targets_x.to(args.device)
        logits = model(inputs)
        logits_x = logits[:batch_size]
        logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
        del logits

        Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

        # chunk() returns views, and PyTorch refuses in-place detach_() on a
        # view; detach() returns a new tensor cut off from the graph instead.
        pseudo_label = torch.softmax(logits_u_w.detach(), dim=-1)
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        # Only confident pseudo-labels contribute to the unlabeled loss.
        mask = max_probs.ge(args.threshold).float()

        Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none') *
              mask).mean()

        loss = Lx + args.lambda_u * Lu

        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        losses.update(loss.item())
        losses_x.update(Lx.item())
        losses_u.update(Lu.item())

        optimizer.step()
        scheduler.step()
        if args.use_ema:
            ema_model.update(model)
        model.zero_grad()

        batch_time.update(time.time() - end)
        end = time.time()
        mask_prob = mask.mean().item()
        if not args.no_progress:
            p_bar.set_description(
                "Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.6f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Mask: {mask:.4f}. "
                .format(epoch=epoch + 1,
                        epochs=args.epochs,
                        batch=batch_idx + 1,
                        iter=args.iteration,
                        lr=scheduler.get_last_lr()[0],
                        data=data_time.avg,
                        bt=batch_time.avg,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        mask=mask_prob))
            p_bar.update()
    if not args.no_progress:
        p_bar.close()
    return losses.avg, losses_x.avg, losses_u.avg, mask_prob
Esempio n. 42
0
 def updateOutput(self, input):
     """Softmax-normalize ``input`` and cache the result on ``self.output``.

     1-D and 3-D inputs are normalized along dim 0; inputs of any other rank
     are normalized along dim 1 (presumably dim 0 is the batch there).
     """
     feature_dim = 1
     if input.dim() in (1, 3):
         feature_dim = 0
     result = torch.softmax(input, feature_dim)
     self.output = result
     return result
Esempio n. 43
0
def train(method, n_components, true_mixture_weights, exp_dir, needsoftmax_mixtureweight=None):
    """Fit GMM mixture weights and an amortized posterior with a chosen gradient estimator.

    Each step draws a batch ``x`` from a uniform-weight GMM, computes the
    encoder's clamped categorical log-probabilities over ``n_components``
    clusters, and evaluates the estimator selected by ``method``. The
    supported estimators are ``'reinforce'``, ``'reinforce_pz'``,
    ``'simplax'``, ``'HLAX'``, ``'relax'`` and ``'reinforce_baseline'``.

    Surrogate networks (when the method has one) are updated every step. The
    mixture weights and the encoder are only updated after
    ``update_start_step``. Diagnostics are printed every ``print_steps``,
    curves are plotted every ``plot_steps``, and ``data.p`` plus network
    parameters are saved every ``params_step``.

    Args:
        method: name of the gradient estimator to use.
        n_components: number of mixture components ``C``.
        true_mixture_weights: reference mixture weights used to compute
            ``theta_loss``.
        exp_dir: output path prefix; ``data.p`` is written to
            ``exp_dir + "data.p"``.
        needsoftmax_mixtureweight: optional initial pre-softmax mixture
            weights; drawn from a standard normal when ``None``.

    Relies on module-level ``n_steps``, ``batch_size``, ``print_steps``,
    ``plot_steps``, ``params_step``, ``images_dir``, ``params_dir`` and,
    for ``'reinforce_baseline'`` / ``'HLAX'``, ``lr``.
    """
    print('Method:', method)
    C = n_components

    true_mixture_weights = torch.tensor(true_mixture_weights,
                                        requires_grad=True, device="cuda")

    if needsoftmax_mixtureweight is None:
        needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")
    else:
        needsoftmax_mixtureweight = torch.tensor(needsoftmax_mixtureweight,
                                                 requires_grad=True, device="cuda")

    load_step = 0  # set to a saved step (e.g. 95000) when resuming
    # Mixture-weight and encoder updates only happen after this step.
    # NOTE(review): with 300000 they never run unless n_steps > 300000 — confirm intended.
    update_start_step = 300000

    optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=1e-4, weight_decay=1e-7)

    encoder = NN3(input_size=1, output_size=n_components, n_residual_blocks=3).cuda()
    optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-4, weight_decay=1e-7)

    # Only some methods build a surrogate. Keep the name bound so the
    # checkpointing code can tell whether there is one to save.
    surrogate = None

    if method in ['simplax', 'relax']:
        surrogate = NN3(input_size=1+n_components+n_components, output_size=1, n_residual_blocks=4).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=1e-3)

    if method in ['reinforce_baseline']:
        # NOTE(review): `lr` is not defined in this function; it must exist at module level.
        surrogate = NN3(input_size=1, output_size=1, n_residual_blocks=4).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=lr)

    if method == 'HLAX':
        surrogate = NN3(input_size=1+n_components, output_size=1, n_residual_blocks=4).cuda()
        surrogate2 = NN3(input_size=1, output_size=1, n_residual_blocks=2).cuda()
        optim_surr = torch.optim.Adam(surrogate.parameters(), lr=lr)
        optim_surr2 = torch.optim.Adam(surrogate2.parameters(), lr=lr)

    # Logged diagnostics; pickled to exp_dir + "data.p".
    data_dict = {}
    data_dict['steps'] = []
    data_dict['theta_losses'] = []
    data_dict['inference_L2'] = []
    data_dict['x'] = {}
    data_dict['x']['f'] = []
    data_dict['x']['lpx_given_z'] = []
    data_dict['z'] = {}
    data_dict['z']['lpz'] = []
    data_dict['z']['lqz'] = []

    if method in ['simplax', 'HLAX', 'relax']:
        data_dict['surr_loss'] = {}
        data_dict['surr_loss']['single_sample'] = []
        data_dict['surr_dif'] = []
        data_dict['surr_dif2'] = []
    if method in ['simplax', 'HLAX']:
        data_dict['grad_split'] = {}
        data_dict['grad_split']['score'] = []
        data_dict['grad_split']['path'] = []
    if method == 'HLAX':
        data_dict['alpha'] = []
    if method == 'relax':
        data_dict['surr_loss']['actual_var'] = []
        data_dict['grad_logq'] = []

        data_dict['surr_grads'] = {}
        data_dict['surr_grads']['grad_surr_z'] = []
        data_dict['surr_grads']['grad_surr_z_tilde'] = []

        data_dict['grad_split'] = {}
        data_dict['grad_split']['score'] = []
        data_dict['grad_split']['path'] = []

        data_dict['var'] = {}
        data_dict['var']['reinforce'] = []
        data_dict['var']['relax'] = []
        data_dict['grad_abs'] = []

        data_dict['bias'] = {}
        data_dict['bias']['reinforce_k1'] = []
        data_dict['bias']['relax_k1'] = []
        data_dict['bias']['reinforce_k100'] = []
        data_dict['bias']['relax_k100'] = []

        data_dict['SNR'] = {}
        data_dict['SNR']['reinforce'] = []
        data_dict['SNR']['relax'] = []

    if method == 'reinforce_baseline':
        # A flat list for this method (not a dict as for the methods above).
        data_dict['surr_loss'] = []

    cur_time = time.time()
    for step in range(load_step, n_steps+1):

        mixtureweights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()  # [C]
        # Data is always drawn from a uniform mixture, not from true_mixture_weights.
        x = sample_gmm(batch_size, mixture_weights=torch.ones(C).cuda()/C)
        prelogits = encoder.net(x)
        prelogits = prelogits * .001  # pre-condition so that it starts out with mass on most classes
        logits = prelogits - logsumexp(prelogits)  # log of probs, probs = softmax of prelogits
        logits = logits.clamp(-5, 1)

        # RUN: every branch leaves a dict of results in `outputs`.
        if method == 'reinforce':
            outputs = reinforce(x, logits, mixtureweights, k=1)
        elif method == 'reinforce_pz':
            net_loss, f, logpx_given_z, logpz, logq = reinforce_pz(x, logits, mixtureweights, k=1)
            outputs = {'net_loss': net_loss, 'f': f, 'logpx_given_z': logpx_given_z,
                       'logpz': logpz, 'logq': logq}
        elif method == 'simplax':
            outputs = simplax(surrogate, x, logits, mixtureweights, k=1)
        elif method == 'HLAX':
            (net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif,
             grad_path, grad_score, alpha) = HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1)
            outputs = {'net_loss': net_loss, 'f': f, 'logpx_given_z': logpx_given_z,
                       'logpz': logpz, 'logq': logq, 'surr_loss': surr_loss,
                       'surr_dif': surr_dif, 'grad_path': grad_path,
                       'grad_score': grad_score}
        elif method == 'relax':
            outputs = relax(step, surrogate, x, logits, mixtureweights, k=1)
        elif method == 'reinforce_baseline':
            outputs = reinforce_baseline(surrogate, x, logits, mixtureweights, k=1)

        # UPDATES
        if step > update_start_step:
            # Update generator (mixture weights): maximize mean f.
            loss = - torch.mean(outputs['f'])
            optim.zero_grad()
            loss.backward(retain_graph=True)
            optim.step()

            # Update encoder
            optim_net.zero_grad()
            outputs['net_loss'].backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), .25)
            optim_net.step()

        # Update surrogate
        if method in ['simplax', 'HLAX', 'relax', 'reinforce_baseline']:
            optim_surr.zero_grad()
            outputs['surr_loss'].backward(retain_graph=True)
            optim_surr.step()

        if method == 'HLAX':
            optim_surr2.zero_grad()
            surr_loss.backward(retain_graph=True)
            optim_surr2.step()

        # Re-run the encoder after the updates and stop on NaNs (NaN != NaN).
        with torch.no_grad():
            prelogits2 = encoder.net(x)

            if (prelogits2 != prelogits2).any():
                print(prelogits)
                print(logits)
                print(torch.max(logits))
                print(torch.min(logits))
                print(prelogits2)

                prelogits2 = prelogits2 * .001  # pre-condition so that it starts out with mass on most classes
                print(prelogits2)

                logits2 = prelogits2 - logsumexp(prelogits2)  # log of probs, probs = softmax of prelogits

                print(logits2)
                raise RuntimeError('NaN in encoder output at step {}'.format(step))

        if step % print_steps == 0:
            current_theta = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
            theta_loss = L2_error(to_print(true_mixture_weights),
                                  to_print(current_theta))

            pz_give_x = true_posterior(x, mixture_weights=mixtureweights)
            probs = torch.softmax(logits, dim=1)
            inference_L2 = L2_batch(pz_give_x, probs)

            if method == 'relax':
                # Expected squared gradient: enumerate every class b for the
                # whole batch and weight relax_grad's result by pb.
                var = torch.zeros(batch_size, C).cuda()
                for b_i in range(C):
                    b = torch.tensor([b_i]*batch_size).cuda()
                    grad_squared, pb = relax_grad(x, logits, b, surrogate, mixtureweights)
                    var += grad_squared * pb

                # Expected gradient, by the same enumeration.
                actual_grad = torch.zeros(batch_size, C).cuda()
                for b_i in range(C):
                    b = torch.tensor([b_i]*batch_size).cuda()
                    grad, pb = relax_grad2(x, logits, b, surrogate, mixtureweights)
                    actual_grad += grad * pb

                grad_abs = to_print_mean(torch.abs(actual_grad))

                # Bias of each estimator: |expected grad - estimator's averaged grad|.
                # RELAX GRAD k=100
                outputs22 = relax(step=step, surrogate=surrogate, x=x,
                                  logits=logits, mixtureweights=mixtureweights, k=100, get_grad=True)
                bias_relax_k100 = to_print_mean(torch.abs(actual_grad - outputs22['grad_avg']))

                # RELAX GRAD k=1
                outputs_k1 = relax(step=step, surrogate=surrogate, x=x,
                                   logits=logits, mixtureweights=mixtureweights, k=1, get_grad=True)
                bias_relax_k1 = to_print_mean(torch.abs(actual_grad - outputs_k1['grad_avg']))

                # REINFORCE k=100
                outputs11 = reinforce(x, logits, mixtureweights, k=100, get_grad=True)
                bias_reinforce_k100 = to_print_mean(torch.abs(actual_grad - outputs11['grad_avg']))

                # REINFORCE k=1
                outputs111 = reinforce(x, logits, mixtureweights, k=1, get_grad=True)
                bias_reinforce_k1 = to_print_mean(torch.abs(actual_grad - outputs111['grad_avg']))

            print(
                'S:{:5d}'.format(step),
                'T:{:.2f}'.format(time.time() - cur_time),
                'Theta_loss:{:.3f}'.format(theta_loss),
                'Loss:{:.3f}'.format(to_print_mean(outputs['net_loss'])),
                'ELBO:{:.3f}'.format(to_print_mean(outputs['f'])),
                'lpx|z:{:.3f}'.format(to_print_mean(outputs['logpx_given_z'])),
                'lpz:{:.3f}'.format(to_print_mean(outputs['logpz'])),
                'lqz:{:.3f}'.format(to_print_mean(outputs['logq'])),
                )
            cur_time = time.time()

            if step > 0:
                data_dict['steps'].append(step)
                data_dict['theta_losses'].append(theta_loss)
                data_dict['x']['f'].append(to_print_mean(outputs['f']))
                data_dict['x']['lpx_given_z'].append(to_print_mean(outputs['logpx_given_z']))
                data_dict['z']['lpz'].append(to_print_mean(outputs['logpz']))
                data_dict['z']['lqz'].append(to_print_mean(outputs['logq']))
                data_dict['inference_L2'].append(to_print(inference_L2))
                if method in ['simplax', 'HLAX', 'relax']:
                    data_dict['surr_loss']['single_sample'].append(to_print(outputs['surr_loss']))
                    data_dict['surr_dif'].append(to_print(outputs['surr_dif']))
                    # HLAX's return value has no second surrogate difference.
                    if 'surr_dif2' in outputs:
                        data_dict['surr_dif2'].append(to_print(outputs['surr_dif2']))
                if method in ['simplax', 'HLAX']:
                    data_dict['grad_split']['score'].append(to_print(outputs['grad_score']))
                    data_dict['grad_split']['path'].append(to_print(outputs['grad_path']))
                if method == 'HLAX':
                    data_dict['alpha'].append(to_print(alpha))
                if method == 'relax':
                    data_dict['surr_loss']['actual_var'].append(to_print_mean(var))
                    data_dict['grad_logq'].append(to_print_mean(outputs['grad_logq']))
                    data_dict['surr_grads']['grad_surr_z'].append(to_print_mean(outputs['grad_surr_z']))
                    data_dict['surr_grads']['grad_surr_z_tilde'].append(to_print_mean(outputs['grad_surr_z_tilde']))
                    data_dict['grad_split']['score'].append(to_print_mean(outputs['grad_score']))
                    data_dict['grad_split']['path'].append(to_print_mean(outputs['grad_path']))
                    data_dict['var']['reinforce'].append(to_print_mean(outputs11['grad_std']))
                    data_dict['var']['relax'].append(to_print_mean(outputs22['grad_std']))

                    data_dict['grad_abs'].append(grad_abs)

                    data_dict['bias']['relax_k1'].append(bias_relax_k1)
                    data_dict['bias']['relax_k100'].append(bias_relax_k100)
                    data_dict['bias']['reinforce_k100'].append(bias_reinforce_k100)
                    data_dict['bias']['reinforce_k1'].append(bias_reinforce_k1)

                    data_dict['SNR']['reinforce'].append(to_print_mean(torch.abs(outputs11['grad_avg'] / outputs11['grad_std'])))
                    data_dict['SNR']['relax'].append(to_print_mean(torch.abs(outputs22['grad_avg'] / outputs22['grad_std'])))

                if method in ['reinforce_baseline']:
                    # data_dict['surr_loss'] is a plain list for this method.
                    data_dict['surr_loss'].append(to_print(outputs['surr_loss']))

            check_nan(outputs['net_loss'])

        if step % plot_steps == 0 and step != 0:

            plot_curve2(data_dict, exp_dir)

            # NOTE(review): pz_give_x, probs and current_theta come from the
            # print_steps branch; plot_steps must be a multiple of print_steps.
            if step % (plot_steps*2) == 0:
                plot_posteriors2(n_components, trueposteriors=to_print(pz_give_x), qs=to_print(probs), exp_dir=images_dir, name=str(step))

                plot_dist2(n_components, mixture_weights=to_print(current_theta), true_mixture_weights=to_print(true_mixture_weights), exp_dir=images_dir, name=str(step))

        if step % params_step == 0 and step > 0:

            with open(exp_dir+"data.p", "wb") as f:
                pickle.dump(data_dict, f)
            print('saved data')

            # 'reinforce' and 'reinforce_pz' have no surrogate to save.
            if surrogate is not None:
                surrogate.save_params_v3(save_dir=params_dir, name='surrogate', step=step)
            encoder.save_params_v3(save_dir=params_dir, name='encoder', step=step)
Esempio n. 44
0
    def pipeline(self, img_path: str) -> np.ndarray:
        """
        Run the pose-estimation pipeline on a single image.

        The image is normalized and passed through ``self.network``. The
        predicted segmentation is saved as a binary mask, keypoints are voted
        from the vector-field output, the pose is recovered with PnP against
        ``self.model_keypoints``, and the image with the projected 3D box is
        saved. Both images are written to ``regular_config.result_path``, or to
        ``log_info/results`` when that setting is empty.

        :param img_path: the path of the image to run detection on
        :return: the predicted pose, shape (3, 4)
        """

        test_image = Image.open(img_path)
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=constants.IMAGE_MEAN,
                                 std=constants.IMAGE_STD)
        ])
        img_arr: torch.Tensor = image_transform(test_image)[None].to(
            self.device)  # shape (1, 3, h, w)
        print('=======================================')
        print('Model loaded into {}, evaluation starts...'.format(self.device))

        inference_start_time = time.time()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            net_out = self.network(img_arr)
        inference_took_time = time.time() - inference_start_time
        print('Model inference took time: {:.6f}'.format(inference_took_time))

        pred_mask: torch.Tensor = net_out[0]
        pred_vector_map: np.ndarray = net_out[1].cpu().detach().numpy()
        # probability of the category at every pixel location
        pred_mask = torch.softmax(pred_mask, dim=1)
        # most likely class index at every pixel of the first image, shape (h, w)
        binary_mask = pred_mask.argmax(dim=1, keepdim=True)[0, 0]

        # One output directory for every artifact this method saves; os.path.join
        # avoids writing to the filesystem root when result_path is empty.
        result_dir = regular_config.result_path or 'log_info/results'
        mask_save_path = os.path.join(
            result_dir, 'predicted_{}_mask.png'.format(self.category))
        box_save_path = os.path.join(
            result_dir, 'predicted_{}_3dbox.png'.format(self.category))

        # voting procedure
        # extract the correspondence from vector map, shape (18, h, w)
        if not self.simple:
            object_vector_map: np.ndarray = LinemodOutputExtractor.extract_vector_field_by_name(
                pred_vector_map[0], self.category)
            cls_label: int = constants.LINEMOD_OBJECTS_NAME.index(
                self.category)
        else:
            object_vector_map: np.ndarray = pred_vector_map[0]
            cls_label: int = 0
        # binary mask image: 255 where the object class won, 0 elsewhere
        binary_mask = torch.where(binary_mask == cls_label,
                                  torch.tensor(255).to(self.device),
                                  torch.tensor(0).to(self.device))
        binary_mask_np: np.ndarray = binary_mask.cpu().detach().numpy().astype(
            np.uint8)
        # save the mask result
        Image.fromarray(binary_mask_np, 'L').save(mask_save_path)

        # from image mask to 0-1 mask
        object_binary_mask = np.where(binary_mask_np == 255, 1,
                                      0)  # shape (h, w)
        # we can get the keypoints now
        voting_start_time = time.time()
        pred_keypoints: np.ndarray = self.voting_procedure.provide_keypoints(
            object_binary_mask, object_vector_map)

        voting_took_time = time.time() - voting_start_time

        # Keypoint row 0 is excluded from PnP and box drawing.
        pred_pose: np.ndarray = geometry_utils.solve_pnp(
            object_pts=self.model_keypoints,
            image_pts=pred_keypoints[1:],
            camera_k=regular_config.camera)  # shape (3, 4)
        print(
            'The voting procedure took time: {:.6f}'.format(voting_took_time))
        print('Predicted Keypoints are listed below:\n', pred_keypoints)

        test_image_with_box = draw_3d_bbox(test_image, pred_keypoints[1:],
                                           'blue')

        # save the result
        test_image_with_box.save(box_save_path)
        total_time = inference_took_time + voting_took_time
        print('Pipeline took total time: {:.6f}'.format(total_time))
        print('=======================================')

        return pred_pose
Esempio n. 45
0
def get_confidence(model, model_input, class_num=None):
    """Sum the softmax confidences produced by ``model`` over a batch.

    Args:
        model: callable mapping ``model_input`` to logits; the softmax is taken
            over the last dimension.
        model_input: input passed straight to ``model``.
        class_num: optional per-sample class indices (list or tensor). When
            given, the probability ``y[i, class_num[i]]`` is summed over
            ``i = 0 .. len(class_num) - 1``.

    Returns:
        ``(total, None)`` when ``class_num`` is given; otherwise
        ``(sum of per-sample max probabilities, per-sample argmax indices)``.
        All returned tensors are detached from the autograd graph.
    """
    y = torch.softmax(model(model_input), -1).detach()
    if class_num is not None:
        # Vectorized gather of y[i, class_num[i]]. Unlike stacking a Python
        # list, this also handles an empty ``class_num`` (the sum is then 0).
        cls_idx = torch.as_tensor(class_num, dtype=torch.long, device=y.device)
        row_idx = torch.arange(cls_idx.shape[0], device=y.device)
        return torch.sum(y[row_idx, cls_idx]), None
    # A single reduction yields both the max probabilities and their indices.
    max_probs, max_idx = torch.max(y, dim=-1)
    return torch.sum(max_probs), max_idx
def simplax():
    """Train an amortized encoder q(z|x) for the mixture model and log diagnostics.

    The active training loop draws hard cluster samples from a ``Categorical``
    built on the encoder's softmax. It updates the encoder (SGD) with the
    REINFORCE-style loss ``-mean((logpxz.detach() - 1) * log q(cluster|x))``,
    where ``logpxz`` comes from ``logprob_undercomponent``.

    A surrogate network is evaluated on ``[probs, x]``, and its L1 error against
    ``logpxz`` is logged. It is never optimized here, because the surrogate
    pre-training loop runs ``surrugate_steps`` (= 0) times and its in-loop
    update is commented out.

    Every 500 steps, diagnostics are printed and appended to history lists.
    Every 1000 steps, once more than three entries exist, curves, posteriors,
    the mixture density and surrogate predictions are plotted into ``exp_dir``.
    At the end, the recorded steps and theta-L2 losses are pickled to
    ``exp_dir + "data_simplax.p"``.

    Relies on names defined elsewhere in the file (``n_components``,
    ``sample_true``, ``NN3``, ``H``, ``check_nan``, ``logprob_undercomponent``,
    ``true_posterior``, ``true_mixture_weights``, ``exp_dir``, ``plot_curve``,
    ...) and uses ``.cuda()`` / ``device="cuda"``, so it needs a CUDA device.
    """

    def show_surr_preds():
        """Plot surrogate predictions against ``logpxz`` along an x grid.

        For each of ``rows`` freshly sampled data points, one relaxed one-hot
        cluster sample is drawn from the encoder. It is held fixed while x sweeps
        ``linspace(-9, 205, 40)``. The surrogate output and
        ``logprob_undercomponent`` for ``H(cluster_S)`` are then plotted. The
        figure is saved to ``exp_dir + 'gmm_surr.png'``.
        """

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            # Only used as a NaN sanity check; the plotted f below omits it.
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            # NOTE(review): the surrogate is fed [relaxed sample, x] here, while the
            # training loop feeds it [probs, x]; confirm which input is intended.
            xz = torch.cat([z,x], dim=1)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()

        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def plot_dist():
        """Plot each weighted mixture component and their sum.

        Component ``c`` is ``Normal(10 * c, 5)`` scaled by
        ``softmax(needsoftmax_mixtureweight)[c]``. The figure is saved to
        ``exp_dir + 'gmm_plot_dist.png'``.
        """

        # Bring the weights to the CPU once. The Normals below live on the CPU,
        # and multiplying their densities by a CUDA weight mixes devices.
        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0).detach().cpu()

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        xs = np.linspace(-9,205, 300)
        # Evaluate the whole grid in one call. Distribution.log_prob requires a
        # Tensor argument and rejects NumPy scalars when argument validation is
        # enabled, which is the default in recent PyTorch.
        xs_t = torch.from_numpy(xs).float()
        sum_ = np.zeros(len(xs))

        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = (torch.exp(m.log_prob(xs_t)) * mixture_weights[c]).numpy()
            sum_ += ys
            ax.plot(xs, ys, label='')

        ax.plot(xs, sum_, label='')

        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def get_loss():
        """Return the surrogate's mean L1 error against ``logpxz`` on a fresh batch.

        Draws ``batch_size`` points and a relaxed one-hot cluster sample at
        temperature ``temp``, then feeds ``[cluster_S, x]`` to the surrogate.
        Only called by the surrogate pre-training loop.
        """

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)

        # NOTE(review): the training loop feeds the surrogate [probs, x] instead.
        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))

        return surr_loss

    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        """Bar-plot the true posterior against the encoder's probabilities for one x.

        The title shows the L2 error and the KL between the two. The figure is
        saved to ``exp_dir + 'posteriors' + name + '.png'``.
        """

        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)

        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()

        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)

        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def inference_error(needsoftmax_mixtureweight):
        """Average posterior-approximation error over 10 freshly sampled points.

        Returns:
            ``(mean L2 error, mean KL)`` between the true posterior and the
            encoder's probabilities, as computed by ``L2_mixtureweights`` and
            ``KL_mixutreweights``.
        """

        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):

            x = sample_true(1).cuda()
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)

            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

            error_sum+=error
            kl_sum += kl

        return error_sum/n, kl_sum/n

    #SIMPLAX
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")#.cuda()

    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()

    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.0001)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0001)
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    # Always defined. The surrogate pre-training loop below uses it, and leaving
    # it commented out raised NameError as soon as surrugate_steps > 0.
    optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.005)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    surrugate_steps = 0
    k = 1
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    for step in range(n_steps):

        # Optional surrogate pre-training; disabled while surrugate_steps == 0.
        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        # cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cat = Categorical(probs=probs)

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            # Hard (discrete) cluster sample. The encoder only receives gradient
            # through log q(cluster | x) in the score-function loss below.
            cluster_H = cat.sample()
            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster  # only logged, not optimized directly

            surr_input = torch.cat([probs, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)

            # Alternative encoder losses that were tried:
            # net_loss += - torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster + surr_pred - logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach() - surr_pred.detach() - 1.) * logprob_cluster + surr_pred)
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # Only consumed by the disabled theta update below.
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))

        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k

        # Theta (mixture-weight) update, currently disabled:
        # optim.zero_grad()
        # loss.backward(retain_graph=True)
        # optim.step()

        optim_net.zero_grad()
        # retain_graph: the diagnostics below call torch.autograd.grad through
        # logprob_cluster and surr_pred on this same graph.
        net_loss.backward(retain_graph=True)
        optim_net.step()

        # Surrogate update, currently disabled:
        # optim_surr.zero_grad()
        # surr_loss.backward(retain_graph=True)
        # optim_surr.step()

        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(),
                            'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax(
                                        needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))

            if step > 0:
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax(
                                            needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())

                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)

                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)

                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)

                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())

                # inf_losses was appended just above, so the former
                # `if len(inf_losses) > 0` guard was always true and its else
                # branch was dead; the diagnostics now run unconditionally here.
                print ('probs', probs[0])
                print('logpxz', logpxz[0])
                print('pred', surr_pred[0])
                print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                print ('logq', logprob_cluster[0])
                print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])

                # Mean |d output / d probs| of the score-function term, the
                # surrogate prediction and log q, respectively.
                output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0]
                output2 = torch.mean(surr_pred, dim=0)[0]
                output3 = torch.mean(logprob_cluster, dim=0)[0]
                grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]
                grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                grad_reparam = torch.mean(torch.abs(grad_reparam))
                grad3 = torch.mean(torch.abs(grad3))
                grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())

                print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                print ('logqz grad:', grad3.cpu().data.detach().numpy())

                print ('current mixuture weights')
                print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                print()

            if len(surr_losses) > 3  and step %1000==0:
                plot_curve(steps=steps_list,  thetaloss=L2_losses,
                            infloss=inf_losses, surrloss=surr_losses,
                            grad_reinforce_list=grad_reinforce_list,
                            grad_reparam_list=grad_reparam_list,
                            f_list=f_list, logpxz_list=logpxz_list,
                            logprob_cluster_list=logprob_cluster_list,
                            inf_losses_kl=inf_losses_kl,
                            kl_losses_2=kl_losses_2,
                            logpx_list=logpx_list)

                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()

    data_dict = {}

    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses

    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
Esempio n. 47
0
                inp = []
                inp.append(img)
                inp.append(img[::-1, ...])
                inp.append(img[:, ::-1, ...])
                inp.append(img[::-1, ::-1, ...])
                # TODO:(sujinhua) there is a trick using the transpose the picy
                inp = np.asarray(inp, dtype="float")
                inp = torch.from_numpy(inp.transpose((0, 3, 1, 2))).float()
                inp = Variable(inp).cuda()

                pred = []
                for model in models:
                    msk = model(
                        inp
                    )  # There may be some thing different from the baseline models
                    msk = torch.softmax(msk[:, :, ...], dim=1)
                    msk = msk.cpu().numpy()

                    msk[:, 0, ...] = 1 - msk[:, 0, ...]

                    pred.append(msk[0, ...])
                    pred.append(msk[1, :, ::-1, :])
                    pred.append(msk[2, :, :, ::-1])
                    pred.append(msk[3, :, ::-1, ::-1])

                pred_full = np.asarray(pred).mean(axis=0)

                msk = pred_full * 255
                msk = msk.astype("uint8").transpose(1, 2, 0)
                cv2.imwrite(
                    path.join(











# Standalone experiment: repeatedly draw relaxed one-hot samples from a
# categorical whose weights are softmax(needsoftmax_mixtureweight).
# NOTE(review): this snippet looks truncated. avg_samp, grads, logprobgrads,
# momem, logprob and cluster_H are set but not used in the visible lines.
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
# Uniform initial logits; np.ones gives float64, so the tensor is float64 too.
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)



# cat = Categorical(probs= weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):

    # Recomputed every step and cast to float32 (the logits are float64).
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    # sample() (not rsample()) gives no reparameterized gradient through cluster_S.
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S) # H is defined elsewhere; presumably a hard one-hot/argmax of cluster_S -- confirm
Esempio n. 49
0
def answer(input_file, question_file):
    """Answer each question in ``question_file`` from the text of ``input_file``.

    Steps:
        1. Split the document into sentences.
        2. For every question, rank the sentences by the sum of the
           L1-normalized BM25 score and the L1-normalized InferSent cosine
           similarity.
        3. Take the four best-ranked sentences, restored to document order, as
           the context.
        4. If ``is_binary_question`` is true, print "Yes." or "No." from the
           boolean classifier. Otherwise, print an extractive span from the WH
           question-answering model.

    Args:
        input_file: path of the UTF-8 document to answer from.
        question_file: path of a text file with one question per line.
    """
    # Initialize Model
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    warnings.filterwarnings("ignore", category=UserWarning)
    V = 1
    MODEL_PATH = 'encoder/infersent%s.pkl' % V
    params_model = {
        'bsize': 64,
        'word_emb_dim': 300,
        'enc_lstm_dim': 2048,
        'pool_type': 'max',
        'dpout_model': 0.0,
        'version': V
    }
    infersent = answer_model.InferSent(params_model)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    infersent.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

    infersent.set_w2v_path("GloVe/glove.840B.300d.txt")

    # Model paths are resolved relative to the current working directory.
    # These are large models (they need gigabytes of memory at runtime).
    wh_tokenizer = AutoTokenizer.from_pretrained("./ans_engine/wh_model")
    wh_model = AutoModelForQuestionAnswering.from_pretrained(
        "./ans_engine/wh_model").to(device)

    boolean_tokenizer = AutoTokenizer.from_pretrained(
        "./ans_engine/boolean_model")
    boolean_model = AutoModelForSequenceClassification.from_pretrained(
        "./ans_engine/boolean_model").to(device)

    sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    # Load the document. The context manager closes the file handle, which the
    # previous bare codecs.open(...).read() leaked.
    with codecs.open(input_file, 'r', 'utf-8') as doc_file:
        text = doc_file.read()
    # Treat runs of newlines as sentence boundaries before sentence splitting.
    text = re.sub('\n+', '. ', text)
    sentences = sentence_tokenizer.tokenize(text)

    # setup measures
    score = {}  # {idx: (BM25, InferSent) }
    # InferSent for sentence semantic similarity
    infersent.build_vocab(sentences, tokenize=True)

    with open(question_file, 'r') as f:
        # Strip the line terminator; otherwise the last token of every question
        # carries a "\n" that never matches any BM25 sentence token.
        questions = [line.rstrip('\n') for line in f]
    sentences_embedding = infersent.encode(sentences, tokenize=True)
    tokenized_sentences = [s.split(" ") for s in sentences]
    bm25 = BM25Okapi(tokenized_sentences)

    for question in questions:
        question_embedding = infersent.encode([question], tokenize=True)[0]
        tokenized_question = question.split(" ")
        # BM25 scores, L1-normalized so they can be added to the InferSent scores.
        bm_scores = preprocessing.normalize(
            bm25.get_scores(tokenized_question).reshape(1, -1), norm='l1')[0]

        # InferSent matching scores (cosine similarity), also L1-normalized.
        infersent_scores = [
            cosine_similarity(question_embedding.reshape(1, -1),
                              s.reshape(1, -1))[0][0]
            for s in sentences_embedding
        ]
        infersent_scores = preprocessing.normalize(
            np.array(infersent_scores).reshape(1, -1), norm='l1')[0]

        for idx in range(len(bm_scores)):
            score[idx] = (bm_scores[idx], infersent_scores[idx])

        # Re-insert entries in descending combined-score order (dicts keep
        # insertion order), so the first keys are the best sentences.
        score = {
            k: v
            for k, v in sorted(score.items(),
                               reverse=True,
                               key=lambda item: item[1][0] + item[1][1])
        }
        # Top-4 sentences, restored to document order. Join them with a space:
        # the sentence tokenizer drops inter-sentence whitespace, so plain
        # concatenation glued one sentence's final punctuation to the next
        # sentence's first word.
        context = ' '.join(sentences[k] for k in sorted(list(score.keys())[0:4]))

        if is_binary_question(question):
            sequence = boolean_tokenizer.encode_plus(
                question, context, return_tensors="pt")['input_ids'].to(device)
            logits = boolean_model(sequence).logits
            probabilities = torch.softmax(logits,
                                          dim=1).detach().cpu().tolist()[0]
            # Label index 1 is read as "yes" and 0 as "no".
            proba_yes = round(probabilities[1], 2)
            proba_no = round(probabilities[0], 2)
            if proba_yes >= proba_no:
                print("Yes.")
            else:
                print("No.")
        else:
            inputs = wh_tokenizer.encode_plus(question,
                                              context,
                                              return_tensors="pt").to(device)
            # Keep only the first 512 tokens of every encoded field.
            for k in inputs:
                inputs[k] = torch.unsqueeze(inputs[k][0][:512], 0)
            output = wh_model(**inputs)
            answer_start_scores, answer_end_scores = output.start_logits, output.end_logits
            answer_starts = torch.argsort(answer_start_scores,
                                          descending=True)[0][0:5]
            answer_ends = torch.argsort(answer_end_scores,
                                        descending=True)[0][0:5]

            # Only the top-ranked candidate (i == 0) is printed.
            for i in range(len(answer_starts)):
                if i == 1:
                    break
                answer_start = answer_starts[i]
                answer_end = answer_ends[i] + 1
                if answer_end <= answer_start:
                    # Start and end are ranked independently, so the best end
                    # can precede the best start, which gives an empty span.
                    # Fall back to the best-scoring end at or after the start.
                    answer_end = answer_start + torch.argmax(
                        answer_end_scores[0][answer_start:]) + 1
                answer = wh_tokenizer.convert_tokens_to_string(
                    wh_tokenizer.convert_ids_to_tokens(
                        inputs["input_ids"][0][answer_start:answer_end]))
                candidates = answer.split(" ")
                candidates[0] = candidates[0].capitalize()
                final = " ".join(candidates) + "."
                print(final)