def sample_simplax(probs):
    dist = RelaxedOneHotCategorical(probs=probs,
                                    temperature=torch.Tensor([1.]))
    z = dist.rsample()
    logprob = dist.log_prob(z)
    b = torch.argmax(z, dim=1)
    return z, b, logprob
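
A minimal usage sketch for the helper above (imports are repeated so the fragment stands alone; the batch of probabilities is made up for illustration):

import torch
from torch.distributions import RelaxedOneHotCategorical

probs = torch.softmax(torch.randn(4, 3), dim=1)   # hypothetical [batch, n_classes] probabilities
z, b, logprob = sample_simplax(probs)
print(z.shape, b.shape, logprob.shape)            # torch.Size([4, 3]) torch.Size([4]) torch.Size([4])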
Example #2
    def get_loss():

        x = sample_true(batch_size).cuda()  #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits / 100., dim=1)
        cat = RelaxedOneHotCategorical(probs=probs,
                                       temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(
            x,
            component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        f = logpxz  # - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        surr_pred = surrogate.net(surr_input)

        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        return surr_loss
Example #3
def show_surr_preds():

    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10 + cols, 4 + rows),
                     facecolor='white')  #, dpi=150)

    for i in range(rows):

        x = sample_true(1).cuda()  #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits / 100., dim=1)
        cat = RelaxedOneHotCategorical(probs=probs,
                                       temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        z = cluster_S

        n_evals = 40
        x1 = np.linspace(-9, 205, n_evals)
        x = torch.from_numpy(x1).view(n_evals, 1).float().cuda()
        z = z.repeat(n_evals, 1)
        cluster_H = cluster_H.repeat(n_evals, 1)
        xz = torch.cat([z, x], dim=1)

        logpxz = logprob_undercomponent(
            x,
            component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        f = logpxz  #- logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col = 0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows, cols), (row, col),
                              frameon=False,
                              colspan=1,
                              rowspan=1)

        ax.plot(x1, surr_pred, label='Surr')
        ax.plot(x1, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Example #4
def simplax(surrogate, x, logits, mixtureweights, k=1):

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B,1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += - torch.mean((f.detach() - surr_pred.detach()) * logq  + surr_pred)


        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq =  torch.mean( torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)

        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr =  torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        # Accumulate across the k samples (averaged below), as in Example #6.
        surr_loss += torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))
   
    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif   
    outputs['grad_path'] = grad_path   
    outputs['grad_score'] = grad_score   

    return outputs #net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score
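
Example #4 above depends on repo-specific pieces (the encoder/surrogate networks, logprob_undercomponent, H). A self-contained toy sketch of the same SIMPLAX-style estimator, with a made-up reward table, a tiny surrogate MLP, and an illustrative training loop (every name below is hypothetical):

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
reward_table = torch.tensor([0., 1., 2., 3.])                      # toy reward per category
f = lambda b: (b * reward_table).sum(dim=1, keepdim=True)          # reward of a hard one-hot sample
H = lambda z: F.one_hot(z.argmax(dim=1), 4).float()

logits = torch.zeros(8, 4, requires_grad=True)                     # the parameters being optimised
surrogate = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))
opt_theta = torch.optim.Adam([logits], lr=0.05)
opt_phi = torch.optim.Adam(surrogate.parameters(), lr=1e-2)

for _ in range(300):
    cat = RelaxedOneHotCategorical(probs=torch.softmax(logits, dim=1),
                                   temperature=torch.tensor([1.]))
    z = cat.rsample()                              # relaxed, reparameterised sample
    b = H(z)                                       # hard sample fed to the reward
    logq = cat.log_prob(z.detach()).view(-1, 1)
    c = surrogate(z)                               # control variate at the relaxed sample
    reward = f(b).detach()

    # Score-function term with the surrogate as control variate, plus the pathwise correction.
    net_loss = -torch.mean((reward - c.detach()) * logq + c)
    opt_theta.zero_grad()
    net_loss.backward()
    opt_theta.step()

    # One simple way to fit the surrogate; Example #4 instead minimises the squared gradient estimate.
    surr_loss = torch.mean(torch.abs(reward - surrogate(z.detach())))
    opt_phi.zero_grad()
    surr_loss.backward()
    opt_phi.step()

print(torch.softmax(logits, dim=1)[0])             # mass should concentrate on the last (highest-reward) category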
Example #5
 def forward(self, x):
     x = self.featCompressor(x)
     x = self.fc1(x)
     x = self.fc2(x)
     logits = self.fc3(x)
     B, L, K = logits.shape
     RelaxedOneHotSampler = RelaxedOneHotCategorical(float(self.temper),
                                                     logits=logits)
     y = RelaxedOneHotSampler.rsample()
     return y, F.softmax(logits, dim=-1), logits
Example #6
def simplax(surrogate, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += -torch.mean((f.detach() - surr_pred.detach()) * logq +
                                surr_pred)

        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq =  torch.mean( torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        surr_loss += torch.mean(
            ((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_score = torch.autograd.grad(
            [torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits],
            create_graph=True,
            retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score
Example #7
 def gumbel_softmax_dist(self,
                         param,
                         name,
                         temperature=1e-1,
                         hard=True,
                         sample_size=()):
     """ST gumbel with pytorch distributions."""
     gumbel = RelaxedOneHotCategorical(temperature, logits=param)
     y = gumbel.rsample(sample_size)
     if hard:
         # One-hot the y
         y = self.st_op(y)
     return y
Example #8
def show_surr_preds():

    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    for i in range(rows):

        x = sample_true(1).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        z = cluster_S

        n_evals = 40
        x1 = np.linspace(-9,205, n_evals)
        x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
        z = z.repeat(n_evals,1)
        cluster_H = cluster_H.repeat(n_evals,1)
        xz = torch.cat([z,x], dim=1) 

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col =0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        ax.plot(x1,surr_pred, label='Surr')
        ax.plot(x1,f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()


    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir+'gmm_surr.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()
Example #9
    def reparameterize(self, p_i, tau, k, num_sample=1):

        ## sampling
        p_i_ = p_i.view(p_i.size(0), 1, 1, -1)
        p_i_ = p_i_.expand(p_i_.size(0), num_sample, k, p_i_.size(-1))
        C_dist = RelaxedOneHotCategorical(tau, p_i_)
        V = torch.max(C_dist.sample(), -2)[0]  # [batch-size, multi-shot, d]

        ## without sampling
        V_fixed_size = p_i.unsqueeze(1).size()
        _, V_fixed_idx = p_i.unsqueeze(1).topk(k, dim=-1)  # batch * 1 * k
        V_fixed = idxtobool(V_fixed_idx, V_fixed_size, is_cuda=self.args.cuda)
        V_fixed = V_fixed.type(torch.float)

        return V, V_fixed
Example #10
def _gumbel_softmax(probs, tau: float, hard: bool):
    """ Computes sampling from the Gumbel Softmax (GS) distribution
    Args:
        probs (torch.tensor): probabilities of shape [batch_size, n_classes] 
        tau (float): temperature parameter for the GS
        hard (bool): discretize if True
    """

    rohc = RelaxedOneHotCategorical(tau, probs)
    y = rohc.rsample()

    if hard:
        y_hard = torch.zeros_like(y)
        y_hard.scatter_(-1, torch.argmax(y, dim=-1, keepdim=True), 1.0)
        y = (y_hard - y).detach() + y

    return y
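
A quick check of the straight-through behaviour of _gumbel_softmax above (the logits and the loss are made up): the forward pass is exactly one-hot, yet gradients still reach the logits.

import torch

logits = torch.randn(2, 5, requires_grad=True)
probs = torch.softmax(logits, dim=-1)
y = _gumbel_softmax(probs, tau=0.5, hard=True)
print(y)                                      # exactly one-hot rows

loss = (y * torch.arange(5.0)).sum()          # arbitrary differentiable objective
loss.backward()
print(logits.grad.shape)                      # torch.Size([2, 5]): gradients flow despite the hard forward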
Example #11
 def forward(self, inputs, lengths, temp=None, y=None):
     enc_emb = self.lookup(inputs)
     dec_emb = self.lookup(inputs)
     hn = self.encoder(enc_emb, lengths)
     py = self.classifier(hn)
     if y is None:
         dist = RelaxedOneHotCategorical(temp, logits=py)
         y = dist.sample().max(1)[1]
     y_emb = self.y_lookup(y)
     h = torch.cat([hn, y_emb.unsqueeze(0)], dim=2)
     mu, logvar = self.fcmu(h), self.fclogvar(h)
     if self.training:
         z = self.reparameterize(mu, logvar)
     else:
         z = mu
     code = torch.cat([z, y_emb.unsqueeze(0)], dim=2)
     outputs, _ = self.decoder(dec_emb, code, lengths=lengths)
     outputs = self.fcout(outputs)
     bow = self.bow_predictor(code)
     return outputs, mu, logvar, bow, py
Example #12
    def forward(self, y, c, m, useGumbel=True, temp=0.5): # TODO: add m to all things calling this.
        # y is the one-hot input (goal is to predict next output)
        # c is one-hot representing context
        # useGumbel=True, then uses Gumbel instead of softmax

        # embed in N-dim
        x_context = torch.matmul(self.Wfix_context, c)
        x_state = torch.matmul(self.Wfix, y)

        # concatenate token and context
        x = torch.cat([x_context, x_state, m]) # TODO: confirm works

        # hidden node
        b = torch.tanh(self.fc1(x))

        # ---- STATE
        if useGumbel: # will not get gradients for fc3 if do this.
            h = torch.tanh(self.fc2(b)) # linear layer + nonlinearity
            z = Gumbel(torch.tensor([0.0]), torch.tensor([1.0])).sample(torch.Size((h.shape[0],))) # add gumbel noise
            yind = (h.view((-1,)) + z.view((-1,))).argmax() # take argmax
            yout = self.idx_to_onehot(yind, self.outdim) # convert to onehot

            # ---- CONTEXT
            h_context = torch.tanh(self.fc3(b))
            z_context = Gumbel(torch.tensor([0.0]), torch.tensor([1.0])).sample(torch.Size((h_context.shape[0],))) # add gumbel noise
            c_ind = (h_context.view((-1,)) + z_context.view((-1,))).argmax() # take argmax
            c_out = self.idx_to_onehot(c_ind, self.Kc)
        else:
            h = torch.tanh(self.fc2(b)) # linear layer + nonlinearity
            yout = RelaxedOneHotCategorical(temp, logits=h).rsample()
            # yout = self.idx_to_onehot(yind, self.outdim) # convert to onehot
            yind = []

            # ---- CONTEXT
            h_context = torch.tanh(self.fc3(b))
            c_out = RelaxedOneHotCategorical(temp, logits=h_context).rsample()
            c_ind = []

        m = b # rename as m
        return yout, h, yind, c_out, h_context, c_ind, m
Example #13
def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.
        net_loss += - torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss/ k

    return net_loss, f, logpx_given_z, logpz, logq
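
reinforce_pz above leans on the repo's H and logprob_undercomponent. For reference, a self-contained score-function (REINFORCE) sketch on a toy reward, without the relaxation (every name below is illustrative):

import torch
from torch.distributions import OneHotCategorical

logits = torch.zeros(1, 4, requires_grad=True)
reward = lambda b: (b * torch.tensor([[0., 1., 2., 3.]])).sum(dim=1, keepdim=True)

opt = torch.optim.Adam([logits], lr=0.1)
for _ in range(200):
    dist = OneHotCategorical(logits=logits)
    b = dist.sample()                           # hard, non-differentiable sample
    logq = dist.log_prob(b).view(-1, 1)
    loss = -(reward(b).detach() * logq).mean()  # REINFORCE / score-function estimator
    opt.zero_grad()
    loss.backward()
    opt.step()

print(torch.softmax(logits, dim=1))             # should put most mass on the highest-reward category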
Example #14
 def gumbel_softmax(self, logits, training, tau=1.0, msg_hard=None):
     device = torch.device("cuda" if logits.is_cuda else "cpu")
     
     if training:
         # Here, Gumbel sample is taken:
         msg_dists = RelaxedOneHotCategorical(tau, logits=logits)
         msg       = msg_dists.rsample()
         
         if msg_hard is None:
             msg_hard = torch.zeros_like(msg, device=device)
             msg_hard.scatter_(-1, torch.argmax(msg, dim=-1, keepdim=True), 1.0)
         
         # detach() detaches the output from the computation graph, so no gradient will be backprop'ed along this variable
         msg = (msg_hard - msg).detach() + msg
     
     else:
         if msg_hard is None:
              msg = torch.zeros_like(logits, device=device)  # use the locally derived device, as in the branch above
             msg.scatter_(-1, torch.argmax(logits, dim=-1, keepdim=True), 1.0)
         else:
             msg = msg_hard
     
     return msg
Example #15
def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f = logpxz - logq - 1.
        net_loss += -torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq
Example #16
    def get_loss():

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)
        
        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(f.detach()-surr_pred))

        return surr_loss
Example #17
def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, probs=probs)
        y_soft = cat_distr.rsample()

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(probs,
                                  device=DEVICE).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret, ret
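
cat_softmax above reads a module-level DEVICE and expects OneHotCategorical/RelaxedOneHotCategorical to be imported; a CPU-only usage sketch under those assumptions:

import torch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical

DEVICE = torch.device('cpu')                    # the snippet above reads this global

probs = torch.softmax(torch.randn(2, 5), dim=-1)
sample, entropy = cat_softmax(probs, mode='REINFORCE')
st_sample, _ = cat_softmax(probs, mode='GUMBEL', tau=0.5, hard=True)
print(sample.shape, entropy.shape, st_sample.shape)   # (2, 5), (2,), (2, 5)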
Example #18
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(logits=logits)
        return cat_distr.sample()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, logits=logits)
        y_soft = cat_distr.rsample()
    elif mode == 'SOFTMAX':
        y_soft = F.softmax(logits, dim=1)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, device=args.device).scatter_(
            dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft

    return ret
Example #19
    def reparameterize(self, p_pep, p_tcr, tau, k, num_sample):

        # sampling
        batch_size = p_pep.size(0)
        len_pep = p_pep.size(1) # batch_size * len_pep
        len_tcr = p_tcr.size(1) # batch_size * len_tcr
        p_pep_ = p_pep.view(batch_size, 1, 1, len_pep).expand(batch_size, num_sample, k, len_pep)
        p_tcr_ = p_tcr.view(batch_size, 1, 1, len_tcr).expand(batch_size, num_sample, k, len_tcr)
        C_pep = RelaxedOneHotCategorical(tau, p_pep_)
        C_tcr = RelaxedOneHotCategorical(tau, p_tcr_)
        Z_pep, _ = torch.max(C_pep.sample(), -2) # batch_size, num_sample, len_pep
        Z_tcr, _ = torch.max(C_tcr.sample(), -2) # batch_size, num_sample, len_tcr
        
        # without sampling
        _, Z_fixed_pep = p_pep.topk(k, dim = -1) # batch_size, k
        _, Z_fixed_tcr = p_tcr.topk(k, dim = -1) # batch_size, k
        size_pep = p_pep.size()
        size_tcr = p_tcr.size()
        Z_fixed_pep = idxtobool(Z_fixed_pep, size_pep, self.cuda)
        Z_fixed_tcr = idxtobool(Z_fixed_tcr, size_tcr, self.cuda)

        return Z_pep, Z_tcr, Z_fixed_pep, Z_fixed_tcr
Example #20
    def sample_lambda0_r(self,
                         d,
                         batch_size,
                         offset=0,
                         object_locations=None,
                         object_margin=None,
                         num_objects=None,
                         gau=None,
                         max_rejections=1000,
                         margin_offset=2):
        """Sample dataset parameters perturbed by r."""
        name = d['name']
        family = d['family']
        attr_name = '{}_{}'.format(name, 'center')
        if self.wn:
            lambda_r = self.normalize_weights(name=name, prop='center')
        elif family != 'half_normal':
            lambda_r = getattr(self, attr_name)
        parameters = []
        if family == 'gaussian':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(constraints.greater_than(
            #     1.))(lambda_r)
            # lambda_r_scale = transform_to(constraints.greater_than(
            #     self.minimum_spatial_scale))(lambda_r_scale)
            # TODO: Add constraint function here
            # w=module.weight.data
            # w=w.clamp(0.5,0.7)
            # module.weight.data=w

            if gau is None:
                gau = MultivariateNormal(loc=lambda_r,
                                         covariance_matrix=lambda_r_scale)
            if d['return_sampler']:
                return gau
            if name == 'object_location':
                if not len(object_locations):
                    return gau.rsample(), gau
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=gau)
            else:
                raise NotImplementedError(name)
        elif family == 'normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'cnormal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)

            # Explicitly clamp the scale!
            lambda_r_scale = torch.clamp(lambda_r_scale,
                                         self.minimum_spatial_scale, 999.)
            nor = CNormal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'abs_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(Normal.arg_constraints['loc'])(lambda_r)
            # lambda_r_scale = transform_to(Normal.arg_constraints['scale'])(lambda_r_scale)  # noqa
            # lambda_r = transforms.AbsTransform()(lambda_r)
            # lambda_r_scale = transforms.AbsTransform()(lambda_r_scale)
            # These kill grads!! # lambda_r = torch.abs(lambda_r)
            # These kill grads!! lambda_r_scale = torch.abs(lambda_r_scale)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'half_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = HalfNormal(scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'categorical':
            if d['return_sampler']:
                gum = RelaxedOneHotCategorical(1e-1, logits=lambda_r)
                return gum
                # return lambda sample_size: self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset   # noqa
            for _ in range(batch_size):
                parameters.append(
                    self.argmax(self.gumbel_fun(lambda_r, name=name)) +
                    offset)  # noqa Use default temperature -> max
        elif family == 'relaxed_bernoulli':
            bern = RelaxedBernoulli(temperature=1e-1, logits=lambda_r)
            if d['return_sampler']:
                return bern
            else:
                parameters = bern.rsample([batch_size])
        else:
            raise NotImplementedError(
                '{} not implemented in sampling.'.format(family))
        return parameters
surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2)

train_ = 1
n_steps = 1000  #0 #0 #1000 #50000 #
B = 1  #32 #0
k = 3
if train_:
    optim = torch.optim.Adam(surrogate.parameters(),
                             lr=1e-4,
                             weight_decay=1e-7)
    #Train surrogate
    for i in range(n_steps + 1):
        warmup = 1.

        cat = RelaxedOneHotCategorical(logits=logits.repeat(B, 1),
                                       temperature=torch.tensor([1.]))
        z = cat.rsample()
        logprob = cat.log_prob(z.detach()).view(B, 1)
        b = H(z)
        reward = f(b).view(B, 1)
        cz = surrogate.net(z)

        # estimator = (reward - cz) * logprob + cz
        # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0]

        gradlogprob = torch.autograd.grad([torch.mean(logprob)], [logits],
                                          create_graph=True,
                                          retain_graph=True)[0]
        gradcz = torch.autograd.grad([torch.mean(cz)], [logits],
                                     create_graph=True,
                                     retain_graph=True)[0]
Example #22
    def forward(self, t, word_counts, tau=1.2):
        batch_size = t.shape[0]

        if self.training:
            message = [
                torch.zeros((batch_size, self.vocab_size), dtype=torch.float32)
            ]
            if self.use_gpu:
                message[0] = message[0].cuda()
            message[0][:, self.bound_token_idx] = 1.0
        else:
            message = [
                torch.full((batch_size, ),
                           fill_value=self.bound_token_idx,
                           dtype=torch.int64)
            ]
            if self.use_gpu:
                message[0] = message[0].cuda()

        # h0, c0
        h = self.aff_transform(t)  # batch_size, hidden_size
        c = torch.zeros([batch_size, self.hidden_size])

        initial_length = self.max_sentence_length + 1
        seq_lengths = torch.ones([batch_size],
                                 dtype=torch.int64) * initial_length

        ce_loss = nn.CrossEntropyLoss(reduction='none')

        # Handle alpha by giving weight to the padding token
        w_counts = word_counts.clone()  # Tensor is passed by ref
        w_counts[self.bound_token_idx] *= self.bound_weight

        denominator = w_counts.sum()
        if denominator > 0:
            normalized_word_counts = w_counts / denominator
        else:
            normalized_word_counts = w_counts

        vl_loss = 0.0
        entropy = 0.0

        if self.use_gpu:
            c = c.cuda()
            seq_lengths = seq_lengths.cuda()

        input_embed_rep = []

        for i in range(self.max_sentence_length
                       ):  # or sampled <EOS>, but this is batched
            emb = torch.matmul(
                message[-1], self.embedding
            ) if self.training else self.embedding[message[-1]]
            h, c = self.lstm_cell(emb, (h, c))

            vocab_scores = self.linear_probs(h)
            p = F.softmax(vocab_scores, dim=1)
            entropy += Categorical(p).entropy()

            if self.training:
                rohc = RelaxedOneHotCategorical(tau, p)
                token = rohc.rsample()

                # Straight-through part
                token_hard = torch.zeros_like(token)
                token_hard.scatter_(-1,
                                    torch.argmax(token, dim=-1, keepdim=True),
                                    1.0)
                token = (token_hard - token).detach() + token
            else:
                if self.greedy:
                    _, token = torch.max(p, -1)
                else:
                    token = Categorical(p).sample()

            message.append(token)
            input_embed_rep.append(emb)

            self._calculate_seq_len(seq_lengths,
                                    token,
                                    initial_length,
                                    seq_pos=i + 1)

            if self.vl_loss_weight > 0.0:
                vl_loss += ce_loss(vocab_scores - normalized_word_counts,
                                   self._discretize_token(token))

        return (torch.stack(message, dim=1), seq_lengths, vl_loss,
                torch.mean(entropy) / self.max_sentence_length,
                torch.stack(input_embed_rep, dim=1))
cols = 1
fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)


col =0
row = 0
ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

n_cats = 2

# needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.tensor([], requires_grad=True)
# weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
theta = .99
weights =  torch.tensor([1-theta,theta], requires_grad=True).float()
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))

val = 1.
val2 = 0
val3 = 0
cmap='Blues'
alpha =1.
xlimits=[val3, val]
ylimits=[val2, val]
numticks = 51
x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)
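
A small sanity check of the relaxed density being tabulated above, using the same hypothetical 2-category weights, at a few points that lie on the simplex:

import torch
from torch.distributions import RelaxedOneHotCategorical

cat = RelaxedOneHotCategorical(probs=torch.tensor([0.01, 0.99]),
                               temperature=torch.tensor([1.]))
pts = torch.tensor([[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]])
print(cat.log_prob(pts))   # density rises toward the corner favoured by the 0.99 weight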
Example #24
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)



# cat = Categorical(probs= weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):

    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S) #

    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(needsoftmax_mixtureweight), retain_graph=False)[0]

    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    f_val = f(one_hot)

    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad

    grad = f_val * logprobgrad
Example #25
    def forward(self, image, question=None, feature_saving=False, cut_image_info=False, cut_question_info=False,
                ground_truth=None, save_messages=False, save_features=False):

        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            device = 'cuda'
        else:
            device = 'cpu'
        assert not (feature_saving and self.training)
        self.batch_size = image.shape[0]
        im_features = []

        if not self.train_from_symbolic:
            for i in range(image.shape[1]):
                im_features.append(self.stem_conv(image[:, i, :, :, :]).view(self.batch_size, -1, 1))
            image_info = im_features[0]
            candidate_im_features = im_features[1:]

        else:
            candidate_im_features = []
            ground_truth = ground_truth.type(torch.FloatTensor).to(device)
            image_info = ground_truth[:, 0, :]
            for i in range(image.shape[1]-1):
                candidate_im_features.append(self.bottleneck_in_fc(ground_truth[:, i+1, :]).view(self.batch_size, -1, 1))



        #detaching gradients
        candidate_im_features = [candidate.detach() for candidate in candidate_im_features]

        if self.bottleneck:

            if self.training:
                message = [torch.zeros((self.batch_size, self.vocab_size), dtype=torch.float32)]
                if self.use_gpu:
                    message[0] = message[0].cuda()
                message[0][:, self.bound_token_idx] = 1.0
            else:
                message = [torch.full((self.batch_size,), fill_value=self.bound_token_idx, dtype=torch.int64)]
                if self.use_gpu:
                    message[0] = message[0].cuda()

            # h0, c0
            flattened_im_features = image_info.view(self.batch_size, -1)
            sender_representation = self.drop(self.bottleneck_in_fc(flattened_im_features))
            h = sender_representation
            c = torch.zeros((self.batch_size, self.message_lstm_hidden_size))

            if self.use_gpu:
                c = c.cuda()

            entropy = 0.0

            # produce words one by one
            for i in range(self.max_sentence_length):

                emb = torch.matmul(message[-1], self.message_embedding) if self.training else self.message_embedding[message[-1]]

                h, c = self.lstm_cell(emb, (h, c))

                vocab_scores = self.drop(self.hidden2vocab(h))
                p = F.softmax(vocab_scores, dim=1)
                entropy += Categorical(p).entropy()

                if self.training:

                    rohc = RelaxedOneHotCategorical(self.tau, p)
                    token = rohc.rsample()

                    # Straight-through part
                    if not self.continuous_communication:
                        token_hard = torch.zeros_like(token)
                        token_hard.scatter_(-1, torch.argmax(token, dim=-1, keepdim=True), 1.0)
                        token = (token_hard - token).detach() + token

                else:
                    if self.greedy:
                        _, token = torch.max(p, -1)
                    else:
                        token = Categorical(p).sample()

                message.append(token)
            message = torch.stack(message, dim=1)

            if self.training:
                _, m = torch.max(message, dim=-1)
            else:
                m = message

            md = calc_message_distinctness(m)


            # If we feed the ground_truth to the receiver, we simply hijack the message here
            if self.use_ground_truth:
                if self.training:
                    message = batch_2_onehot(ground_truth, self.max_sentence_length, self.vocab_size)
                else:
                    message = ground_truth.type(torch.LongTensor)
                if self.use_gpu:
                    message = message.cuda()


            # Receiver part

            h = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))
            c = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))

            if self.use_gpu:
                h = h.cuda()
                c = c.cuda()

            emb = torch.matmul(message, self.message_embedding) if self.training else self.message_embedding[message]
            _, (h, c) = self.message_encoder_lstm(emb, (h[None, ...], c[None, ...]))
            hidden_receiver = h[0]
            bottleneck_out = self.drop(self.aff_transform(hidden_receiver))

            image_info = bottleneck_out

            #todo recomment
            # comm_info = {'entropy': torch.mean(entropy).item() / (self.max_sentence_length+ 1e-7), 'md': md}
            comm_info = {'entropy': 0, 'md': 0}
            if save_features:
                comm_info['image_features'] = flattened_im_features
                comm_info['sender_repr'] = sender_representation
                comm_info['receiver_repr'] = hidden_receiver
            if save_messages:
                comm_info['message'] = m

        # in case no communication bottleneck
        else:
            comm_info = None

        orig_im_features = image_info.view(self.batch_size, 1, -1)
        out = torch.zeros(self.batch_size, len(candidate_im_features)).type(torch.FloatTensor)
        if self.use_gpu:
            out = out.cuda()
        for i in range(len(candidate_im_features)):
            out[:, i] = torch.bmm(orig_im_features, candidate_im_features[i]).squeeze()

        return out, comm_info
Example #26
        surr_loss = get_loss()
        optim_surr.zero_grad()
        surr_loss.backward()
        optim_surr.step()
        if ii%1000==0:
            print (ii, surr_loss)


    for step in range(n_steps):

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits/100., dim=1)
        # print (probs)
        # fsdafsa
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster
Example #27
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B, 1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += -torch.mean(alpha.detach() *
                                (f_z.detach() - surr_pred.detach()) * logq_z +
                                alpha.detach() * surr_pred +
                                (1 - alpha.detach()) * (f_b.detach()) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        grad_logq_z = torch.mean(torch.autograd.grad([torch.mean(logq_z)],
                                                     [logits],
                                                     create_graph=True,
                                                     retain_graph=True)[0],
                                 dim=1,
                                 keepdim=True)
        grad_logq_b = torch.mean(torch.autograd.grad([torch.mean(logq_b)],
                                                     [logits],
                                                     create_graph=True,
                                                     retain_graph=True)[0],
                                 dim=1,
                                 keepdim=True)
        grad_surr = torch.mean(torch.autograd.grad([torch.mean(surr_pred)],
                                                   [logits],
                                                   create_graph=True,
                                                   retain_graph=True)[0],
                               dim=1,
                               keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # fsdfa
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # fsdfasd
        surr_loss += torch.mean(
            (alpha * (f_z.detach() - surr_pred) * grad_logq_z +
             alpha * grad_surr + (1 - alpha) *
             (f_b.detach()) * grad_logq_b)**2)

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        # fdsf
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_score = torch.autograd.grad(
            [torch.mean(
                (f_z.detach() - surr_pred.detach()) * logq_z)], [logits],
            create_graph=True,
            retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(
        alpha)
Example #28
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)


        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += - torch.mean(     alpha.detach()*(f_z.detach()  - surr_pred.detach()) * logq_z  
                                    + alpha.detach()*surr_pred 
                                    + (1-alpha.detach())*(f_b.detach()  ) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b =  torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # fsdfa
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # fsdfasd
        surr_loss += torch.mean(
                                    (alpha*(f_z.detach() - surr_pred) * grad_logq_z 
                                    + alpha*grad_surr
                                    + (1-alpha)*(f_b.detach()) * grad_logq_b )**2
                                    )

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        # fdsf
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))


    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
col = 0
row = 0
ax = plt.subplot2grid((rows, cols), (row, col),
                      frameon=False,
                      colspan=1,
                      rowspan=1)

n_cats = 2

# needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.tensor([], requires_grad=True)
# weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
theta = .99
weights = torch.tensor([1 - theta, theta], requires_grad=True).float()
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))

val = 1.
val2 = 0
val3 = 0
cmap = 'Blues'
alpha = 1.
xlimits = [val3, val]
ylimits = [val2, val]
numticks = 51
x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)
Example #30
    def forward(self,
                instruction=None,
                observation=None,
                memory=None,
                compute_message_probs=False,
                time=None):
        if not hasattr(self, 'random_corrector'):
            self.random_corrector = False
        if not hasattr(self, 'var_len'):
            self.var_len = False
        if not hasattr(self, 'script'):
            self.script = False

        if not self.script:
            memory_rnn_output, memory = self.forward_film(
                instruction=instruction,
                observation=observation,
                memory=memory)

            batch_size = instruction.size(0)
            correction_encodings = []

            entropy = 0.0

            lengths = np.array([self.corr_length] * batch_size)
            total_corr_loss = 0

            for i in range(self.corr_length):
                if i == 0:
                    # every message starts with a SOS token
                    decoder_input = torch.tensor([self.sos_id] * batch_size,
                                                 dtype=torch.long,
                                                 device=self.device)
                    decoder_input_embedded = self.word_embedding_corrector(
                        decoder_input).unsqueeze(1)
                    decoder_hidden = memory_rnn_output.unsqueeze(0)

                if self.random_corrector:
                    # randomize corrections
                    device = torch.device(
                        "cuda" if decoder_input_embedded.is_cuda else "cpu")
                    decoder_input_embedded = torch.randn(
                        decoder_input_embedded.size(), device=device)
                    decoder_hidden = torch.randn(decoder_hidden.size(),
                                                 device=device)

                rnn_output, decoder_hidden = self.decoder_rnn(
                    decoder_input_embedded, decoder_hidden)
                vocab_scores = self.out(rnn_output)
                vocab_probs = F.softmax(vocab_scores, dim=-1)

                entropy += Categorical(vocab_probs).entropy()

                tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) +
                             self.max_tau)
                tau = tau.expand(-1, self.num_embeddings).unsqueeze(1)

                if self.training:
                    # Apply Gumbel SM
                    cat_distr = RelaxedOneHotCategorical(tau, vocab_probs)
                    corr_weights = cat_distr.rsample()
                    corr_weights_hard = torch.zeros_like(corr_weights,
                                                         device=self.device)
                    corr_weights_hard.scatter_(
                        -1, torch.argmax(corr_weights, dim=-1, keepdim=True),
                        1.0)

                    # detach() detaches the output from the computation graph, so no gradient will be backprop'ed along this variable
                    corr_weights = (corr_weights_hard -
                                    corr_weights).detach() + corr_weights

                else:
                    # greedy sample
                    corr_weights = torch.zeros_like(vocab_probs,
                                                    device=self.device)
                    corr_weights.scatter_(
                        -1, torch.argmax(vocab_probs, dim=-1, keepdim=True),
                        1.0)

                if self.var_len:
                    # consider sequence done when eos receives highest value
                    max_idx = torch.argmax(corr_weights, dim=-1)
                    eos_batches = max_idx.data.eq(self.eos_id)
                    if eos_batches.dim() > 0:
                        eos_batches = eos_batches.cpu().view(-1).numpy()
                        update_idx = ((lengths > i) & eos_batches) != 0
                        lengths[update_idx] = i

                    # compute correction error through pseudo-target: sequence of eos symbols to encourage short messages
                    pseudo_target = torch.tensor(
                        [self.eos_id for j in range(batch_size)],
                        dtype=torch.long,
                        device=self.device)
                    loss = self.correction_loss(corr_weights.squeeze(1),
                                                pseudo_target)
                    total_corr_loss += loss

                correction_encodings += [corr_weights]
                decoder_input_embedded = torch.matmul(
                    corr_weights, self.word_embedding_corrector.weight)

            # one-hot vectors on forward, soft approximations on backward pass
            correction_encodings = torch.stack(correction_encodings,
                                               dim=1).squeeze(2)

            lengths = torch.tensor(lengths,
                                   dtype=torch.long,
                                   device=self.device)

            result = {
                'correction_encodings':
                correction_encodings,
                'correction_messages':
                self.decode_corrections(correction_encodings),
                'correction_entropy':
                torch.mean(entropy),
                'corrector_memory':
                memory,
                'correction_lengths':
                lengths,
                'correction_loss':
                total_corr_loss
            }

        else:
            # there is a script of pre-established guidance messages
            correction_messages = self.script[time]
            correction_encodings = self.encode_corrections(correction_messages)
            result = {
                'correction_encodings': correction_encodings,
                'correction_messages': correction_messages
            }

        return (result)
Example #31
    def forward(self, obs, memory, instr_embedding=None, tau=1.2):

        # Calculating instruction embedding
        if self.use_instr and instr_embedding is None:
            if self.lang_model == 'gru':
                _, hidden = self.instr_rnn(self.word_embedding(obs.instr))
                instr_embedding = hidden[-1]

        # Calculating the image embedding
        x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
        if self.arch.startswith("expert_filmcnn"):
            image_embedding = self.image_conv(x)

            #Calculating FiLM_embedding from image and instruction embedding
            for controler in self.controllers:
                x = controler(image_embedding, instr_embedding)
            FiLM_embedding = F.relu(self.film_pool(x))
        else:
            FiLM_embedding = self.image_conv(x)

        FiLM_embedding = FiLM_embedding.reshape(FiLM_embedding.shape[0], -1)

        #Going through the memory layer
        if self.use_memory:
            hidden = (memory[:, :self.semi_memory_size],
                      memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(FiLM_embedding, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)
        else:
            embedding = x

        if self.use_instr and "filmcnn" not in self.arch:
            embedding = torch.cat((embedding, instr_embedding), dim=1)

        if hasattr(self, 'aux_info') and self.aux_info:
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }
        else:
            extra_predictions = dict()

        memory_rnn_output = embedding
        batch_size = memory_rnn_output.shape[0]

        message = []
        for i in range(self.message_length):

            if i == 0:
                decoder_input = torch.tensor([self.sos_id] * batch_size,
                                             dtype=torch.long,
                                             device=self.device)
                decoder_input_embedded = self.word_embedding_decoder(
                    decoder_input).unsqueeze(1)
                decoder_hidden = memory_rnn_output.unsqueeze(0)

            decoder_out, decoder_hidden = self.decoder_rnn(
                decoder_input_embedded, decoder_hidden)
            vocab_scores = self.hidden2word(decoder_out)
            vocab_probs = F.softmax(vocab_scores, -1)

            tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) +
                         self.max_tau)
            tau = tau.expand(-1, self.vocab_size).unsqueeze(1)

            if self.training:
                rohc = RelaxedOneHotCategorical(tau, vocab_probs)
                token = rohc.rsample()

                # Straight-through part
                token_hard = torch.zeros_like(token)
                token_hard.scatter_(-1,
                                    torch.argmax(token, dim=-1, keepdim=True),
                                    1.0)
                token = (token_hard - token).detach() + token
            else:
                token = torch.zeros_like(vocab_probs, device=self.device)
                token.scatter_(-1,
                               torch.argmax(vocab_probs, dim=-1, keepdim=True),
                               1.0)

            message.append(token)

            decoder_input_embedded = torch.matmul(
                token, self.word_embedding_decoder.weight)
        comm = torch.stack(message, dim=1).squeeze(2)
        return comm, memory
print('Grad mean', np.mean(grads, axis=0))
print('Grad std', np.std(grads, axis=0))
print('Avg logprobgrad', np.mean(logprobgrads, axis=0))
print('Std logprobgrad', np.std(logprobgrads, axis=0))
print()

#REINFORCE P(Z)

# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats),
                                         requires_grad=True).float()
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

# cat = Categorical(probs= weights)
cat = RelaxedOneHotCategorical(probs=weights,
                               temperature=torch.tensor([1.]))  #.cuda())

# dist = Bernoulli(bern_param)
# samps = []
avg_samp = torch.zeros(n_cats)
gradspz = []
logprobgrads = []
for i in range(n):
    # samp = dist.sample()
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)  #
    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    # logprob = dist.log_prob(samp.detach())
Example #33
            optim_surr.step()

        # x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)

        probs2 = probs.cpu().data.numpy()[0]
        print()
        for iii in range(len(probs2)):
            print(str(iii)+':'+str(probs2[iii]), end =" ")
        print ()

        # print (probs.shape)
        print (probs)
        # fsdf
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        net_loss = 0
        loss = 0
        surr_loss = 0
        grads = 0
        for jj in range(k):

            cluster_S = cat.rsample()
            # print (cluster_S.shape)
            # dfsafd

            # cluster_onehot = torch.zeros(n_components).cuda()
            # cluster_onehot[jj%20] =1.
            # cluster_S = torch.softmax(cluster_onehot, dim=0).view(1,20)
Example #34
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        surr_loss = 0
        for i in range(batch_size):
            x = sample_true().cuda().view(1, 1)
            logits = encoder.net(x)
            # print (logits.shape)
            # print (torch.softmax(logits, dim=1))
            # fasd
            # cat = Categorical(probs= torch.softmax(logits, dim=0))
            cat = RelaxedOneHotCategorical(probs=torch.softmax(logits, dim=1),
                                           temperature=torch.tensor([1.
                                                                     ]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            # print (cluster_onehot)
            # print (cluster_H)
            # print (cluster_S)
            # fdsa
            logprob_cluster = cat.log_prob(cluster_S.detach())
            if logprob_cluster != logprob_cluster:
                print('nan')
            # print (logprob_cluster)
            pxz = logprob_undercomponent(
                x,
def H(soft):
    return torch.argmax(soft, dim=1)

surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2)

train_ = 1
n_steps = 1000#0 #0 #1000 #50000 #
B = 1 #32 #0
k=3
if train_:
    optim = torch.optim.Adam(surrogate.parameters(), lr=1e-4, weight_decay=1e-7)
    #Train surrogate
    for i in range(n_steps+1):
        warmup = 1.

        cat = RelaxedOneHotCategorical(logits=logits.repeat(B,1), temperature=torch.tensor([1.]))
        z = cat.rsample()
        logprob = cat.log_prob(z.detach()).view(B,1)
        b = H(z)
        reward = f(b).view(B,1) 
        cz = surrogate.net(z)

        # estimator = (reward - cz) * logprob + cz
        # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0]

        gradlogprob =  torch.autograd.grad([torch.mean(logprob)], [logits], create_graph=True, retain_graph=True)[0]
        gradcz =  torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]

        # print (reward.shape, cz.shape, gradlogprob.shape, gradcz.shape)
        # fdasf
        # grad = (reward-cz) *gradlogprob + gradcz
 def sample_simplax(probs):
     dist = RelaxedOneHotCategorical(probs=probs, temperature=torch.Tensor([1.]))
     z = dist.rsample()
     logprob = dist.log_prob(z)
     b = torch.argmax(z, dim=1)
     return z, b, logprob
Example #37
lr = .002

# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

# cat = Categorical(probs= weights)
avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):

    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights,
                                   temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)  #

    logprobgrad = torch.autograd.grad(outputs=logprob,
                                      inputs=(needsoftmax_mixtureweight),
                                      retain_graph=False)[0]

    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.

    f_val = f(one_hot)

    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad