示例#1
0
文件: models.py 项目: mukami12/REM
    def forward(self, x, S):
        """One REM training step: two importance-sampling passes and a loss.

        :param x: input batch; flattened to (bsz, x_dim).
        :param S: number of importance samples per datapoint.
        :return: scalar REM loss (decoder loss p_loss + encoder loss q_loss).
        """
        x = x.view(-1, self.x_dim)
        bsz = x.size(0)

        ### get w and \alpha and L(\theta)
        # Encoder proposal q_phi(z|x) and S reparameterized samples.
        mu, logvar = self.encoder(x)
        q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
        z_q = q_phi.rsample((S, ))
        recon_batch = self.decoder(z_q)
        x_dist = Bernoulli(logits=recon_batch)
        log_lik = x_dist.log_prob(x).sum(-1)
        log_prior = self.prior.log_prob(z_q).sum(-1)
        log_q = q_phi.log_prob(z_q).sum(-1)
        # Self-normalized importance weights alpha over the S dimension;
        # gradients are blocked through the weights (detach).
        log_w = log_lik + log_prior - log_q
        tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
        alpha = torch.exp(log_w - tmp_alpha).detach()
        if self.version == 'v1':
            p_loss = -alpha * (log_lik + log_prior)

        ### get moment-matched proposal
        # Weighted mean of the q samples -> proposal mean (no gradient).
        mu_r = alpha.unsqueeze(2) * z_q
        mu_r = mu_r.sum(0).detach()
        # Outer products of centered samples -> per-datapoint covariance.
        z_minus_mu_r = z_q - mu_r.unsqueeze(0)
        reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1)
        reshaped_diff_t = reshaped_diff.permute(0, 2, 1)
        outer = torch.bmm(reshaped_diff, reshaped_diff_t)
        outer = outer.view(S, bsz, self.z_dim, self.z_dim)
        Sigma_r = outer.mean(0) * S / (S - 1)  # Bessel-style S/(S-1) correction
        Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6  ## ridging

        ### get v, \beta, and L(\phi)
        # NOTE(review): torch.cholesky is deprecated in recent PyTorch in
        # favor of torch.linalg.cholesky -- confirm target torch version.
        L = torch.cholesky(Sigma_r)
        r_phi = MultivariateNormal(loc=mu_r, scale_tril=L)

        # Second IS pass under the moment-matched proposal r_phi.
        z = r_phi.rsample((S, ))
        z_r = z.detach()
        recon_batch_r = self.decoder(z_r)
        x_dist_r = Bernoulli(logits=recon_batch_r)
        log_lik_r = x_dist_r.log_prob(x).sum(-1)
        log_prior_r = self.prior.log_prob(z_r).sum(-1)
        log_r = r_phi.log_prob(z_r)
        log_v = log_lik_r + log_prior_r - log_r
        tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0)
        beta = torch.exp(log_v - tmp_beta).detach()
        # Encoder loss: beta-weighted log q evaluated at the r_phi samples.
        log_q = q_phi.log_prob(z_r).sum(-1)
        q_loss = -beta * log_q

        if self.version == 'v2':
            p_loss = -beta * (log_lik_r + log_prior_r)

        # NOTE(review): p_loss is only bound when self.version is 'v1' or
        # 'v2'; any other value raises NameError here -- confirm the set of
        # allowed versions upstream.
        rem_loss = torch.sum(q_loss + p_loss, 0).sum()
        return rem_loss
示例#2
0
def binary_loss(output, x):
    """Negative log-likelihood of three Bernoulli heads given data ``x``.

    ``output`` is a tuple ``(p_L, p_T, p_R)`` of Bernoulli success
    probabilities; columns 0, 1, 2 of ``x`` hold the matching binary
    targets.  Returns the negated mean (over the batch) of the summed
    per-sample log-likelihoods.
    """
    targets = (x[:, i].view(-1, 1) for i in range(3))
    total_log_lik = sum(
        Bernoulli(p).log_prob(t) for p, t in zip(output, targets))
    return -torch.mean(total_log_lik)
示例#3
0
def continuous_outcome_loss(output, x):
    """Negative log-likelihood for two Bernoulli heads and one Gaussian head.

    ``output`` unpacks as ``(p_L, p_T, mu_R, log_sigma_R)``; columns 0 and 1
    of ``x`` are binary targets for the Bernoulli heads, column 2 is the
    continuous outcome scored under Normal(mu_R, exp(log_sigma_R)).
    Returns the negated mean of the summed per-sample log-likelihoods.
    """
    p_L, p_T, mu_R, log_sigma_R = output

    def column(i):
        # One target column reshaped to (batch, 1) to match the parameters.
        return x[:, i].view(-1, 1)

    log_lik = Bernoulli(p_L).log_prob(column(0))
    log_lik = log_lik + Bernoulli(p_T).log_prob(column(1))
    log_lik = log_lik + Normal(mu_R, torch.exp(log_sigma_R)).log_prob(column(2))
    return -torch.mean(log_lik)
示例#4
0
文件: models.py 项目: mukami12/REM
    def log_lik(self, loader, n_samples):
        """Get log marginal estimate via importance sampling

        :param loader: iterable of (data, label) batches; labels are ignored.
        :param n_samples: number of importance samples per datapoint.
        :return: average negative log-likelihood per datapoint.
        """
        nll = 0
        for i, (data, _) in enumerate(loader):
            data = data.view(-1, self.x_dim).to(device)
            bsz = data.size(0)
            mu, logvar = self.encoder(data)

            ### get moment-matched proposal
            # First pass: self-normalized importance weights under q_phi.
            q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
            z_q = q_phi.rsample((n_samples, ))
            recon_batch = self.decoder(z_q)
            x_dist = Bernoulli(logits=recon_batch)
            log_lik = x_dist.log_prob(data).sum(-1)
            log_prior = self.prior.log_prob(z_q).sum(-1)
            log_q = q_phi.log_prob(z_q).sum(-1)
            log_w = log_lik + log_prior - log_q
            tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
            alpha = torch.exp(log_w - tmp_alpha).detach()

            # Weighted mean of the q samples becomes the new proposal mean.
            mu_r = alpha.unsqueeze(2) * z_q
            mu_r = mu_r.sum(0).detach()

            # NOTE(review): only the mean is moment-matched; the proposal
            # reuses the encoder's scale rather than a matched covariance --
            # confirm this simplification is intended.
            nll_proposal = Normal(loc=mu_r, scale=torch.exp(0.5 * logvar))

            bsz = data.size(0)  # NOTE(review): redundant; assigned above

            # Second pass: standard IS estimate of log p(x) per datapoint.
            z = nll_proposal.rsample((n_samples, ))
            recon_batch = self.decoder(z)
            x_dist = Bernoulli(logits=recon_batch)
            log_lik = x_dist.log_prob(data).sum(-1)
            log_prior = self.prior.log_prob(z).sum(-1)
            log_r = nll_proposal.log_prob(z).sum(-1)

            # log (1/S) sum_s w_s, then summed over the batch.
            loss = log_lik + log_prior - log_r
            ll = torch.logsumexp(loss, dim=0) - math.log(n_samples)
            ll = ll.sum()
            nll += -ll.item()

            # Effectively disabled progress logging (period is huge).
            if i > 0 and i % 20000000 == 0:
                print('i: {}/{}'.format(i, len(loader)))

        # Normalize by the dataset size to get a per-datapoint NLL.
        nll /= len(loader.dataset)
        return nll
示例#5
0
def make_decisions(logits):
    """Sample two sequential binary decisions from ``logits``.

    Decision 1 is drawn from a Bernoulli over column 0 of ``logits``;
    depending on its outcome, decision 2 uses column 1 (when b1 == 0) or
    column 2 (when b1 == 1).  Returns both samples with their
    log-probabilities.
    """
    first_dist = Bernoulli(logits=logits[:, 0])
    first = first_dist.sample()
    first_logprob = first_dist.log_prob(first)

    # Pick the column for the second decision based on the first outcome.
    branch_column = 1 if first == 0 else 2
    second_dist = Bernoulli(logits=logits[:, branch_column])
    second = second_dist.sample()
    second_logprob = second_dist.log_prob(second)

    return first, first_logprob, second, second_logprob
示例#6
0
def make_decisions(logits):
    """Draw two chained Bernoulli decisions.

    The first decision samples from column 0 of ``logits``; its outcome
    selects column 1 or 2 for the second decision.  Returns the two samples
    and their log-probabilities.
    """
    d1 = Bernoulli(logits=logits[:, 0])
    b1 = d1.sample()
    logprob1 = d1.log_prob(b1)

    # Second-stage distribution depends on the first sample.
    col = 1 if b1 == 0 else 2
    d2 = Bernoulli(logits=logits[:, col])
    b2 = d2.sample()
    logprob2 = d2.log_prob(b2)

    return b1, logprob1, b2, logprob2
示例#7
0
    def compute_log_pdf_bernoulli(self, fs_samples, target_matrix):
        """Element-wise Bernoulli log-density of ``target_matrix``.

        ``fs_samples`` are real-valued scores; previously they were squashed
        through a sigmoid and passed as probabilities.  Passing them directly
        via ``logits=`` is mathematically identical but numerically stabler
        for large |fs_samples| (avoids sigmoid saturating to exactly 0/1 and
        the resulting -inf log-probabilities).

        :param fs_samples: logit-scale function samples, any shape.
        :param target_matrix: 0/1 targets broadcastable to ``fs_samples``.
        :return: tensor of per-element log-probabilities.
        """
        dist = Bernoulli(logits=fs_samples)
        log_pdf = dist.log_prob(target_matrix)
        return log_pdf
    def reward_forward(self, prob, locations, orig_window_length, full_image,
                       other_full_image):
        """
        forward with policy gradient
        :param prob: probability maps
        :param locations: locations recording where the patches are extracted
        :param orig_window_length: patches per full image; used to derive the
            replication count and to advance the image index
        :param full_image: ground truth full image
        :param other_full_image: another ground truth full image
        :return: (log_probs, broadcast_rewards) for the REINFORCE update
        """
        # Bernoulli sampling: one binary mask per patch plus its log-prob
        # (the differentiable part of the policy-gradient objective).
        batch_size = prob.size(0)
        bernoulli_dist = Bernoulli(prob)
        samples = bernoulli_dist.sample()
        log_probs = bernoulli_dist.log_prob(samples)

        # put back
        # Reward computation only -- no gradients flow through this section.
        with torch.no_grad():
            repeat_times = int(np.ceil(batch_size / orig_window_length))

            target_full_images = other_full_image.repeat(repeat_times, 1, 1, 1)
            inpaint_full_images = full_image.repeat(repeat_times, 1, 1, 1)

            # j th full image
            # Paste each sampled patch into its source rectangle; every
            # orig_window_length patches advance to the next full image.
            j = 0
            for batch_idx in range(batch_size):
                sample = samples[batch_idx]
                y1, x1, y2, x2 = locations[batch_idx]
                # sample = torch.where(sample >= 0.5, torch.ones_like(sample), torch.zeros_like(sample))
                inpaint_full_images[j, :, y1:y2, x1:x2] = sample.detach()

                if (batch_idx + 1) % orig_window_length == 0:
                    j += 1

            # calculate the reward over the re-composed root and ground truth root
            rewards = self.forward(inpaint_full_images, target_full_images)
            # broadcast the rewards to each element of the feature maps
            broadcast_rewards = torch.zeros(batch_size, 1)
            broadcast_rewards = broadcast_rewards.to(device)
            # j th full image
            # Every patch inherits the reward of the image it belongs to.
            j = 0
            for batch_idx in range(batch_size):
                broadcast_rewards[batch_idx] = rewards[j]
                if (batch_idx + 1) % orig_window_length == 0:
                    j += 1

        # Expand the per-patch scalar reward to the spatial map shape so it
        # can multiply log_probs element-wise.
        broadcast_rewards = broadcast_rewards.view(broadcast_rewards.size(0),
                                                   1, 1, 1)
        image_size = prob.size(2)
        broadcast_rewards = broadcast_rewards.repeat(1, 1, image_size,
                                                     image_size)

        return log_probs, broadcast_rewards
示例#9
0
 def action(self, x):
     """Sample a (message, action) pair for observation ``x``.

     The numpy observation is promoted to a double-precision batch of one,
     the network produces Normal message parameters plus a Bernoulli action
     probability, and one sample is drawn from each distribution.  Returns
     the concatenated message+action vector and the joint log-probability.
     """
     obs = T.from_numpy(x).double().unsqueeze(0)
     # x = x.double().unsqueeze(0)
     msg_means, msg_sds, act_probs = self.forward(obs)
     act_dbn = Bernoulli(act_probs)
     act = act_dbn.sample()
     msg_dbn = Normal(msg_means, msg_sds)
     msg = msg_dbn.sample()
     joint_log_prob = act_dbn.log_prob(act) + msg_dbn.log_prob(msg).sum()
     result = T.cat((msg[0, :], act[0].double()))
     return result, joint_log_prob
示例#10
0
    def forward(self, target, output):
        """Bernoulli reconstruction loss (negative log-likelihood).

        :param target: ground-truth input, shape (B, C, W, H)
        :param output: reconstruction logits, shape (B, C, W, H)
        :return: scalar mean over the batch of the per-sample NLL
        """
        nll = -Bernoulli(logits=output).log_prob(target)
        # Sum over channel/spatial dims, then average over the batch.
        per_sample = nll.sum(dim=[1, 2, 3])
        return torch.mean(per_sample)
示例#11
0
    def forward(self, z, x=None):
        """Decode latent ``z`` to Bernoulli logits; sample ``x`` if not given.

        :param z: latent batch fed through fc1 then the conv stack.
        :param x: optional observed image; sampled from the decoder's
            Bernoulli when None.
        :return: (x, score) where score is the per-example log-likelihood.
        """
        y = z

        y = F.relu(self.fc1(y))
        y = y.view(-1, 16, 5, 5)
        y = F.relu(self.conv1(y))
        y = F.relu(self.conv2(y))
        y = F.relu(self.conv3(y))
        # NOTE(review): conv2 runs a second time here (it already ran two
        # lines above) -- possibly a typo for a dedicated output conv; confirm.
        y = self.conv2(y)
        # Keep channel 0 and drop the first row/column of the feature map.
        y = y[:, 0, 1:, 1:]
        dist = Bernoulli(logits = y)
        if x is None: x = dist.sample()
        # Log-likelihood summed over both spatial dims -> shape (B,).
        score = dist.log_prob(x.float())
        score = score.sum(dim=2).sum(dim=1)
        return x, score
示例#12
0
def calculate_likelihood(X, S):
    """Importance-sampling estimate of log p(x) for each row of ``X``.

    ``S`` holds pre-drawn standard-normal noise; after the transpose its
    first dimension indexes the N_sim simulations.  Uses the module-level
    ``model`` and runs entirely detached (no gradients).  Latent dimension
    is hard-coded to 100 in the prior below.
    """

    model.eval()

    N_obs = X.size(0)
    N_sim = S.size(1)
    S = torch.transpose(S, 1, 0)

    # parameters from the encoder
    z_mu, z_logvar = model.encoder(X)
    z_mu = z_mu.detach()

    z_logvar = z_logvar.detach()

    std = z_logvar.mul(0.5).exp_()
    std = std.detach()
    #We simulate the hidden state for the K number of simulations
    z = z_mu + std * S
    z = z.detach()

    # Reconstruction term log p(x|z_j) for every simulation j.
    RE = torch.zeros((N_sim, N_obs)).cuda()

    for j in range(N_sim):
        ber = Bernoulli(model.decoder(z[j, :, :]).detach())
        RE[j] = ber.log_prob(X).sum(1).detach()
    RE = RE.detach()
    # Standard-normal prior over the 100-dim latent space.
    log_norm = MultivariateNormal(
        torch.zeros(100).cuda(),
        torch.diag(torch.ones(100)).cuda())

    log_p_z = log_norm.log_prob(z)
    log_p_z = log_p_z.detach()
    log_mult = Normal(z_mu, std)

    # Posterior density of the same samples; KL = log q - log p.
    log_q_z = log_mult.log_prob(z).sum(2)
    log_q_z = log_q_z.detach()
    KL = -(log_p_z - log_q_z)

    L = (RE - KL)

    # log (1/N_sim) sum_j exp(L_j), computed in numpy via logsumexp.
    log_lik = logsumexp(L.detach().cpu().numpy(), axis=0)

    log_lik = (log_lik - np.log(N_sim))

    return log_lik
示例#13
0
    def f(self, x, z, logits, hard=False):
        """Return (log p(x|z), log q(z|x), alpha) for the ELBO objectives.

        :param x: image batch in [0, 1], shape (B, ...).
        :param z: latent sample (hard 0/1 when ``hard``, relaxed otherwise).
        :param logits: Bernoulli logits of q(z|x).
        :param hard: evaluate log q under Bernoulli (True) or under
            RelaxedBernoulli with temperature 1 (False).
        """

        B = x.shape[0]

        # image likelihood given b
        # b = harden(z).detach()
        x_hat = self.generator.forward(z)
        alpha = torch.sigmoid(x_hat)
        # Beta likelihood centered on alpha; beta_scale sets concentration.
        beta = Beta(alpha*self.beta_scale, (1.-alpha)*self.beta_scale)
        # Uniform dequantization noise, clamped strictly inside (0, 1) so the
        # Beta log-density stays finite.
        x_noise = torch.clamp(x + torch.FloatTensor(x.shape).uniform_(0., 1./256.).cuda(), min=1e-5, max=1-1e-5)
        logpx = beta.log_prob(x_noise) #[120,3,112,112]  # add uniform noise here
        logpx = torch.sum(logpx.view(B, -1),1) # [PB]  * self.w_logpx

        # prior is constant I think 
        # for q(b|x), we just want to increase its entropy 
        if hard:
            dist = Bernoulli(logits=logits)
        else:
            dist = RelaxedBernoulli(torch.Tensor([1.]).cuda(), logits=logits)
            
        # z is detached: this term scores the sample, it does not
        # reparameterize through it.
        logqb = dist.log_prob(z.detach())
        logqb = torch.sum(logqb,1)

        return logpx, logqb, alpha
#                         early_stop=5)
#     print ('Done training\n')
# # fada
# else:
#     # net_relax.load_params_v3(save_dir=home+'/Downloads/tmmpp/', step=30551, name='') #.499
#     net_relax.load_params_v3(save_dir=home+'/Documents/Grad_Estimators/new/', step=1607, name='') #.4
# print()

# Score-function (REINFORCE) estimate of d E[f(b)] / d bern_param.
dist = Bernoulli(bern_param)
samps = []
grads = []
logprobgrads = []
for i in range(n):
    samp = dist.sample()

    # d log p(samp) / d bern_param; the sample itself is a constant.
    logprob = dist.log_prob(samp.detach())
    logprobgrad = torch.autograd.grad(outputs=logprob,
                                      inputs=(bern_param),
                                      retain_graph=True)[0]
    # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
    # fsdfa

    samps.append(samp.numpy())
    # REINFORCE estimator with a zero baseline: (f(b) - 0) * grad log p(b).
    grads.append((f(samp.numpy()) - 0.) * logprobgrad.numpy())
    logprobgrads.append(logprobgrad.numpy())

# print (grads[:10])

print('Grad Estimator: REINFORCE')
print('Avg samp', np.mean(samps))
print('Grad mean', np.mean(grads))
# Repeat the REINFORCE gradient-variance experiment for each theta.
reinforce_cat_grad_stds = []
for theta in thetas:

    print()
    print('theta:', theta)
    # theta = .01 #.99 #.1 #95 #.3 #.9 #.05 #.3
    bern_param = torch.tensor([theta], requires_grad=True)

    dist = Bernoulli(bern_param)
    samps = []
    grads = []
    logprobgrads = []
    for i in range(n):
        samp = dist.sample()

        # Score-function gradient d log p(samp) / d bern_param.
        logprob = dist.log_prob(samp.detach())
        logprobgrad = torch.autograd.grad(outputs=logprob,
                                          inputs=(bern_param),
                                          retain_graph=True)[0]
        # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
        # fsdfa

        samps.append(samp.numpy())
        # REINFORCE estimate with a zero baseline.
        grads.append((f(samp.numpy()) - 0.) * logprobgrad.numpy())
        logprobgrads.append(logprobgrad.numpy())

    # print (grads[:10])

    print('Grad Estimator: REINFORCE')
    # print ('Avg samp', np.mean(samps))
    print('Grad mean', np.mean(grads))
示例#16
0
    def forward(self, grad_est_type, x=None, warmup=1., inf_net=None): #, k=1): #, marginf_type=0):
        """Compute the discrete-latent ELBO and one of several surrogate
        gradient-estimator objectives.

        :param grad_est_type: 'SimpLAX', 'RELAX', 'SimpLAX_nosoft' or
            'RELAX_nosoft'; any other value leaves cost_all / surr_cost
            unbound and raises NameError below.
        :param x: image batch, shape (B, ...).
        :param warmup: only referenced by commented-out code here.
        :param inf_net: unused in this method.
        :return: dict of diagnostics and the tensors the optimizer consumes.
        """

        outputs = {}
        B = x.shape[0]

        #Samples from relaxed bernoulli 
        z, logits, logqz = self.q.sample(x) 

        # Debug guard: dump statistics and crash -- `fdsfad` is an undefined
        # name, so this deliberately raises NameError when logqz goes NaN.
        if isnan(logqz).any():
            print(torch.sum(isnan(logqz).float()).data.item())
            print(torch.mean(logits).data.item())
            print(torch.max(logits).data.item())
            print(torch.min(logits).data.item())
            print(torch.max(z).data.item())
            print(torch.min(z).data.item())
            fdsfad

        
        # Compute discrete ELBO
        # fhard is the (detached) hard-sample objective whose gradient the
        # estimators below approximate.
        b = harden(z).detach()
        logpx_b, logq_b, alpha1 = self.f(x, b, logits, hard=True)
        fhard = (logpx_b - logq_b).detach()
        

        if grad_est_type == 'SimpLAX':
            # Control Variate
            logpx_z, logq_z, alpha2 = self.f(x, z, logits, hard=False)
            fsoft = logpx_z.detach() #- logq_z
            c = self.surr(x, z).view(B)

            # REINFORCE with Control Variate
            Adv = (fhard - fsoft - c).detach()
            cost1 = Adv * logqz

            # Unbiased gradient of fhard/elbo
            cost_all = cost1 + c + fsoft # + logpx_b

            # Surrogate loss
            surr_cost = torch.abs(fhard - fsoft - c)#**2



        elif grad_est_type == 'RELAX':

            #p(z|b)
            # Reparameterized conditional sample z_tilde ~ p(z|b) used as the
            # RELAX control variate input.
            theta = logit_to_prob(logits)
            v = torch.rand(z.shape[0], z.shape[1]).cuda()
            v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
            # z_tilde = logits.detach() + torch.log(v_prime) - torch.log1p(-v_prime)
            z_tilde = logits + torch.log(v_prime) - torch.log1p(-v_prime)
            z_tilde = torch.sigmoid(z_tilde)

            # Control Variate
            logpx_z, logq_z, alpha2 = self.f(x, z, logits, hard=False)
            fsoft = logpx_z.detach() #- logq_z
            c_ztilde = self.surr(x, z_tilde).view(B)
            c_z = self.surr(x, z).view(B)

            # REINFORCE with Control Variate
            dist_bern = Bernoulli(logits=logits)
            logqb = dist_bern.log_prob(b.detach())
            logqb = torch.sum(logqb,1)

            Adv = (fhard - fsoft - c_ztilde).detach()
            cost1 = Adv * logqb

            # Unbiased gradient of fhard/elbo
            cost_all = cost1 + fsoft + c_z - c_ztilde#+ logpx_b

            # Surrogate loss
            surr_cost = torch.abs(fhard - fsoft - c_ztilde)#**2




        elif grad_est_type == 'SimpLAX_nosoft':
            # Control Variate
            logpx_z, logq_z, alpha2 = self.f(x, z, logits, hard=False)
            # fsoft = logpx_z.detach() #- logq_z
            c = self.surr(x, z).view(B)

            # REINFORCE with Control Variate
            Adv = (fhard - c).detach()
            cost1 = Adv * logqz

            # Unbiased gradient of fhard/elbo
            cost_all = cost1 + c  # + logpx_b

            # Surrogate loss
            surr_cost = torch.abs(fhard - c)#**2



        elif grad_est_type == 'RELAX_nosoft':

            #p(z|b)
            theta = logit_to_prob(logits)
            v = torch.rand(z.shape[0], z.shape[1]).cuda()
            v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
            z_tilde = logits + torch.log(v_prime) - torch.log1p(-v_prime)
            z_tilde = torch.sigmoid(z_tilde)

            # Control Variate
            logpx_z, logq_z, alpha2 = self.f(x, z, logits, hard=False)
            # fsoft = logpx_z.detach() #- logq_z
            c_ztilde = self.surr(x, z_tilde).view(B)
            c_z = self.surr(x, z).view(B)

            # REINFORCE with Control Variate
            dist_bern = Bernoulli(logits=logits)
            logqb = dist_bern.log_prob(b.detach())
            logqb = torch.sum(logqb,1)

            Adv = (fhard - c_ztilde).detach()

            # print (Adv.shape, logqb.shape)
            cost1 = Adv * logqb

            # Unbiased gradient of fhard/elbo
            # print (cost1.shape, c_z.shape, c_ztilde.shape)
            # fsdf
            cost_all = cost1 + c_z - c_ztilde#+ logpx_b

            # Surrogate loss
            surr_cost = torch.abs(fhard - c_ztilde)#**2






        # Confirm generator grad isnt in encoder grad
        # logprobgrad = torch.autograd.grad(outputs=torch.mean(fhard), inputs=(logits), retain_graph=True)[0]
        # print (logprobgrad.shape, torch.max(logprobgrad), torch.min(logprobgrad))

        # logprobgrad = torch.autograd.grad(outputs=torch.mean(fsoft), inputs=(logits), retain_graph=True)[0]
        # print (logprobgrad.shape, torch.max(logprobgrad), torch.min(logprobgrad))
        # fsdfads


        outputs['logpx'] = torch.mean(logpx_b)
        outputs['x_recon'] = alpha1
        # outputs['welbo'] = torch.mean(logpx + warmup*( logpz - logqz))
        outputs['welbo'] = torch.mean(cost_all) #torch.mean(logpx_b + warmup*(KL))
        # NOTE(review): 138.63 looks like a dataset-specific constant offset
        # -- confirm its provenance before relying on 'elbo'.
        outputs['elbo'] = torch.mean(logpx_b - logq_b - 138.63)
        # outputs['logws'] = log_ws
        outputs['z'] = z
        outputs['logpz'] = torch.zeros(1) #torch.mean(logpz)
        outputs['logqz'] = torch.mean(logq_b)
        outputs['surr_cost'] = torch.mean(surr_cost)

        outputs['fhard'] = torch.mean(fhard)
        # outputs['fsoft'] = torch.mean(fsoft)
        # outputs['c'] = torch.mean(c)
        outputs['logq_z'] = torch.mean(logq_z)
        outputs['logits'] = logits

        return outputs
示例#17
0
def train(model_name='RL_net', tbd='logs', sparsity_lambda=0.5, ch=3):
    """Train the feature-selection policy network with REINFORCE.

    :param model_name: checkpoint file name (saved to ./saved_rl_checkpoints/).
    :param tbd: TensorBoard run sub-directory under RL_runs/.
    :param sparsity_lambda: weight of the |#selected - 5| sparsity penalty.
    :param ch: channel parameter forwarded to RLNet.
    """
    no_epochs = 30
    lr = 1e-3
    batch_size = 1024

    print('Loading dataset ...')
    train_loader, valid_loader, test_loader = get_dataset(
        minibatch_size=batch_size)
    fraud_net = load_fraudnet()
    net = RLNet(ch=ch).to(device='cuda').train()
    print('Model:')
    print(net)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    best_valid_reward = -1 * float('inf')
    writer = SummaryWriter('RL_runs/' + tbd)
    # Loss module is stateless -- build it once instead of per batch.
    bce_fn = nn.BCELoss(reduction='none')

    for i in range(no_epochs):

        for b, data in enumerate(train_loader):

            inputs, labels = data
            inputs, labels = smote_func(inputs, labels)

            y_logits = net(inputs)
            # torch.sigmoid: F.sigmoid is deprecated.
            y_probs = torch.sigmoid(y_logits)

            # Sample a binary feature mask per example; keep its log-prob
            # for the policy-gradient update.
            m = Bernoulli(probs=y_probs)
            selected_features = m.sample()
            number_selected = selected_features.sum(1)  #fraction selected

            # Unselected features fall back to their default values.
            selected_inputs = inputs * selected_features + (
                1 - selected_features) * default_features

            y_pred = fraud_net(selected_inputs)

            bce_loss = bce_fn(y_pred, labels)

            # Reward: classification quality minus a sparsity penalty that
            # pulls the number of selected features towards 5.
            reward = -1 * bce_loss - sparsity_lambda * torch.abs(
                number_selected - 5).unsqueeze(1)

            # REINFORCE: minimize -log_prob * reward.
            log_probs = m.log_prob(selected_features)
            loss = -log_probs * reward
            loss = loss.mean()

            # BUGFIX: was `if b % 100:`, which logged on 99 of every 100
            # batches; log every 100th batch instead.  The stray debug
            # `print(log_probs.shape, reward.shape)` was removed.
            if b % 100 == 0:
                print(
                    'Epochs: {}, batch: {}, loss: {}, reward: {}, bce loss: {}, fraction selected: {}'
                    .format(i, b, loss,
                            reward.mean().item(),
                            bce_loss.mean().item(),
                            number_selected.mean().item()))
                sys.stdout.flush()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Per-epoch scalars use the last batch's values.
        writer.add_scalar('train_loss', loss, i)
        writer.add_scalar('reward', reward.mean(), i)
        writer.add_scalar('bce_loss', bce_loss.mean(), i)
        writer.add_scalar('number_selected', number_selected.mean(), i)

        valid_result = evaluate(valid_loader, fraud_net, net, sparsity_lambda)
        valid_result.update({'name': 'Valid_Epoch_{}'.format(i)})
        print(valid_result)

        valid_reward = valid_result['reward']

        # Checkpoint + run the test split on every validation improvement.
        if valid_reward > best_valid_reward:
            torch.save(net.state_dict(),
                       './saved_rl_checkpoints/' + model_name + '.th')
            best_valid_reward = valid_reward

            test_result = evaluate(test_loader, fraud_net, net,
                                   sparsity_lambda)
            test_result.update({'name': 'test_Epoch_{}'.format(i)})
            print(test_result)

            print('Saving checkpoint at epoch: {}'.format(i))

    print(test_result)
示例#18
0
    # Policy-gradient optimization of bern_param on H-transformed samples.
    steps = []
    losses= []
    for step in range(total_steps):

        dist = Bernoulli(logits=bern_param)

        optim.zero_grad()

        # Draw 20 samples, pass each through H, stack into a (20, 1) column.
        bs = []
        for i in range(20):
            samps = dist.sample()
            bs.append(H(samps))
        bs = torch.FloatTensor(bs).unsqueeze(1)

        logprob = dist.log_prob(bs)
        # logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]

        # REINFORCE surrogate loss: mean of f(b) * log p(b).
        loss = torch.mean(f(bs) * logprob)

        #review the pytorch_toy and the RL code to see how PG was done 

        loss.backward()  
        optim.step()

        # Record progress every 50 steps; print a summary every 500.
        if step%50 ==0:
            if step %500==0:
                print (step, torch.mean(f(bs)).numpy(), bern_param.detach().numpy(), logit_to_prob(bern_param).detach().numpy())
            losses.append(torch.mean(f(bs)).numpy())
            steps.append(step)







# Score-function (REINFORCE) estimate of d E[f(b)] / d bern_param.
dist = Bernoulli(bern_param)
samps = []
grads = []
logprobgrads = []
for i in range(n):
    samp = dist.sample()

    # d log p(samp) / d bern_param; the sample itself is a constant.
    logprob = dist.log_prob(samp.detach())
    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
    # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
    # fsdfa

    samps.append(samp.numpy())
    # REINFORCE estimator with a zero baseline.
    grads.append( (f(samp.numpy()) - 0.) * logprobgrad.numpy())
    logprobgrads.append(logprobgrad.numpy())


# print (grads[:10])

print ('Grad Estimator: REINFORCE')
print ('Avg samp', np.mean(samps))
print ('Grad mean', np.mean(grads))
print ('Grad std', np.std(grads))
示例#20
0
            # Train the surrogate net to regress f(b) from relaxed samples z.
            optim_NN.zero_grad()

            losses = 0
            for ii in range(10):
                #Sample p(z)
                z = sample_Gumbel(probs=torch.exp(logits))

                b = H(z)

                #Sample p(z|b)
                z_tilde = sample_conditional_Gumbel(probs=torch.exp(logits),
                                                    b=b)

                dist_bern = Bernoulli(logits=logits)
                logpb = dist_bern.log_prob(b)

                f_b = f(b)
                pred = surrogate.net(z)
                # print (z)

                # Mean-squared error between true f(b) and surrogate output.
                NN_loss = torch.mean((f_b - pred)**2)

                losses += NN_loss

            # Single backward pass over the accumulated 10-sample loss.
            losses.backward()
            optim_NN.step()

            zs.append(to_print(z)[0])

            if step % 50 == 0:
示例#21
0
def evaluate(test_loader, fraudnet, net, sparsity_lambda):
    """Evaluate the feature-selection policy ``net`` against ``fraudnet``.

    Runs one gradient-free pass over ``test_loader``, sampling a feature
    mask per example, and returns a dict of classification metrics plus the
    average loss, reward, and number of selected features.

    :param test_loader: iterable of (inputs, labels) batches.
    :param fraudnet: frozen classifier scored on the masked inputs.
    :param net: policy network producing per-feature selection logits.
    :param sparsity_lambda: weight of the |#selected - 5| penalty.
    :return: metrics dict from calc_metrics_classification, extended with
        bce/loss/reward/number_selected aggregates.
    """
    total_loss = 0.0
    total_reward = 0.0
    total_bce_loss = 0.0
    total_number_selected = 0.0
    steps = 0.0
    net = net.eval()

    labels_list = []
    predicted_list = []
    loss_list = []

    # Loss module is stateless -- build it once instead of per batch.
    bce_fn = nn.BCELoss(reduction='none')

    with torch.no_grad():
        for b, data in enumerate(test_loader):

            inputs, labels = data
            inputs, labels = smote_func(inputs, labels)

            y_logits = net(inputs)
            # torch.sigmoid: F.sigmoid is deprecated.
            y_probs = torch.sigmoid(y_logits)

            # Sample a feature mask; unselected entries use default values.
            m = Bernoulli(probs=y_probs)
            selected_features = m.sample()
            number_selected = selected_features.sum(1)  #fraction selected

            selected_inputs = inputs * selected_features + (
                1 - selected_features) * default_features
            y_pred = fraudnet(selected_inputs)

            bce_loss = bce_fn(y_pred, labels)

            # Reward: classification quality minus sparsity penalty pulling
            # the selection count towards 5.  (The unused number_dropped
            # variable was removed.)
            reward = -1 * bce_loss - sparsity_lambda * torch.abs(
                number_selected - 5).unsqueeze(1)

            log_probs = m.log_prob(selected_features)
            loss = -log_probs * reward

            total_loss += loss.mean().item()
            total_bce_loss += bce_loss.mean().item()
            total_reward += reward.mean().item()
            total_number_selected += number_selected.mean().item()
            steps += 1

            predicted_list.append(y_pred.cpu().data.numpy())
            labels_list.append(labels.cpu().data.numpy())
            loss_list.append(bce_loss.cpu().data.numpy())

    # Flatten the per-batch arrays and split the BCE loss by true class.
    predicted_list = np.array([x for y in predicted_list for x in y])
    labels_list = np.array([x for y in labels_list for x in y])
    actual_labels = (labels_list >= 0.5).astype(np.int32)
    loss_list = np.array([x for y in loss_list for x in y])
    positive_loss = loss_list[labels_list >= 0.5].mean()
    negative_loss = loss_list[labels_list < 0.5].mean()
    overall_loss = loss_list.mean()

    result = calc_metrics_classification(actual_labels, predicted_list)
    result.update({
        'positive_bce_loss': positive_loss,
        'negative_bce_loss': negative_loss,
        'overall_bce_loss': overall_loss
    })

    total_loss /= steps
    total_reward /= steps
    total_number_selected /= steps
    result.update({
        'loss': total_loss,
        'reward': total_reward,
        'number_selected': total_number_selected
    })
    return result
示例#22
0
 def logpdf(self, f, y):
     """Bernoulli log-density of labels ``y`` under probs sigmoid(f)."""
     probs = torch.nn.Sigmoid()(f).flatten()
     bernoulli = Ber(probs=probs)
     return bernoulli.log_prob(y)
    # Fresh leaf parameter for this theta so autograd reaches it directly.
    bern_param = torch.tensor([theta], requires_grad=True)







    dist = Bernoulli(bern_param)
    samps = []
    grads = []
    logprobgrads = []
    for i in range(n):
        samp = dist.sample()

        # Score-function gradient d log p(samp) / d bern_param.
        logprob = dist.log_prob(samp.detach())
        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
        # fsdfa

        samps.append(samp.numpy())
        # REINFORCE estimate with a zero baseline.
        grads.append( (f(samp.numpy()) - 0.) * logprobgrad.numpy())
        logprobgrads.append(logprobgrad.numpy())


    # print (grads[:10])

    print ('Grad Estimator: REINFORCE')
    # print ('Avg samp', np.mean(samps))
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
示例#24
0
# Toy setup: 1 chain (B), 3 decision logits (C), N Monte-Carlo repetitions.
B = 1
C = 3
N = 2000

# Normalize prelogits into log-probabilities and make them trainable.
prelogits = torch.zeros([B, C])
logits = prelogits - logsumexp(prelogits)
# logits = torch.tensor(logits.clone().detach(), requires_grad=True)
logits.requires_grad_(True)

grads = []
for i in range(N):
    dist1 = Bernoulli(logits=logits[:, 0])

    # Decision 1
    b1 = dist1.sample()
    logprob1 = dist1.log_prob(b1)

    if b1 == 0:
        dist2 = Bernoulli(logits=logits[:, 1])
    else:
        dist2 = Bernoulli(logits=logits[:, 2])

    # Decision 2
    b2 = dist2.sample()
    logprob2 = dist2.log_prob(b2)

    if b1 == 0 and b2 == 0:
        reward = 1
    elif b1 == 0 and b2 == 1:
        reward = 2
    elif b1 == 1 and b2 == 0:
示例#25
0
# REINFORCE gradient-variance sweep over thetas (baseline-free estimator).
pz_grad_stds = []
for theta in thetas:

    #     print ()
    print('theta:', theta)
    #     # theta = .01 #.99 #.1 #95 #.3 #.9 #.05 #.3
    bern_param = torch.tensor([theta], requires_grad=True)

    dist = Bernoulli(bern_param)
    samps = []
    grads = []
    logprobgrads = []
    for i in range(n):
        samp = dist.sample()

        # Score-function gradient d log p(samp) / d bern_param.
        logprob = dist.log_prob(samp.detach())
        logprobgrad = torch.autograd.grad(outputs=logprob,
                                          inputs=(bern_param),
                                          retain_graph=True)[0]
        # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
        # fsdfa

        samps.append(samp.numpy())
        # REINFORCE estimate with a zero baseline.
        grads.append((f(samp.numpy()) - 0.) * logprobgrad.numpy())
        logprobgrads.append(logprobgrad.numpy())

    # print (grads[:10])

    print('Grad Estimator: REINFORCE')
    # print ('Avg samp', np.mean(samps))
    print('Grad mean', np.mean(grads))
示例#26
0
    def inference_step(self, prev, x):
        """
        Given previous (or initial) state and input image, predict the next
        inference step (next object).

        :param prev: State from the previous step (z_pres, z_where, z_what,
            LSTM state h/c and baseline-LSTM state bl_h/bl_c).
        :param x: image batch, shape (B, C, H, W).
        :return: dict with the new State, per-sample KL terms, the baseline
            value, and the z_pres log-likelihood.
        """

        bs = x.size(0)
        
        # Flatten the image
        x_flat = x.view(bs, -1)
        
        # Feed (x, z_{<t}) through the LSTM cell, get encoding h
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm(lstm_input, (prev.h, prev.c))

        # Predictor presence and location from h
        z_pres_p, z_where_loc, z_where_scale = self.predictor(h)
        
        # If previous z_pres is 0, force z_pres to 0
        z_pres_p = z_pres_p * prev.z_pres
        
        # Numerical stability
        eps = 1e-12
        z_pres_p = z_pres_p.clamp(min=eps, max=1.0-eps)

        # sample z_pres
        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()

        # If previous z_pres is 0, then this z_pres should also be 0.
        # However, this is sampled from a Bernoulli whose probability is at
        # least eps. In the unlucky event that the sample is 1, we force this
        # to 0 as well.
        z_pres = z_pres * prev.z_pres
        
        # Likelihood: log q(z_pres[i] | x, z_{<i}) (if z_pres[i-1]=1, else 0)
        # Mask with prev.z_pres instead of z_pres, i.e. if already at the
        # previous step there was no object.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()  # shape (B,)

        # Sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()
        
        # Get object from image - shape (B, 1, Hobj, Wobj)
        obj = self.spatial_transf.inverse(x, z_where)
        
        # Predictor z_what
        z_what_loc, z_what_scale = self.encoder(obj)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # Compute baseline for this z_pres:
        # b_i(z_{<i}) depending on previous step latent variables only.
        # lstm_input is detached so the baseline regressor does not backprop
        # into the main inference LSTM.
        bl_h, bl_c = self.bl_lstm(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        baseline_value = self.bl_regressor(bl_h).squeeze()  # shape (B,)

        # The baseline is not used if z_pres[t-1] is 0 (object not present in
        # the previous step). Mask it out to be on the safe side.
        baseline_value = baseline_value * prev.z_pres.squeeze()
        
        # KL for the current step, sum over data dimension: shape (B,)
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)

        # When z_pres[i] is 0, zwhere and zwhat are not used -> set KL=0
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()

        # When z_pres[i-1] is 0, zpres is not used -> set KL=0
        kl_pres = kl_pres * prev.z_pres.squeeze()
        
        kl = (kl_pres + kl_where + kl_what)

        # New state
        new_state = State(
            z_pres=z_pres,
            z_where=z_where,
            z_what=z_what,
            h=h,
            c=c,
            bl_c=bl_c,
            bl_h=bl_h,
            )

        out = {
            'state': new_state,
            'kl': kl,
            'kl_pres': kl_pres,
            'kl_where': kl_where,
            'kl_what': kl_what,
            'baseline_value': baseline_value,
            'z_pres_likelihood': z_pres_likelihood,
        }
        return out
示例#27
0
    def propagate(self, x, x_embed, prev_relation, prev_temporal):
        """
        Propagate step for a single object in a single time step.
        
        In this process, even empty objects are encoded in relation h's. This
        can be avoided by directly passing the previous relation state on to the
        next spatial step. May do this later.
        
        Args:
            x: original image. Size (B, C, H, W)
            x_embed: extracted image feature. Size (B, N).
                NOTE(review): unused in this method -- prediction conditions
                on the proposal glimpse embedding instead.
            prev_relation: see RelationState
            prev_temporal: see TemporalState
            
        Returns:
            temporal_state: TemporalState
            relation_state:  RelationState
            kl: kl divergence summed over all z's. Shape (B,)
            z_pres_likelihood: masked log q(z_pres|x). Shape (B,)
        """
        
        # First, encode relation and temporal info to get current h^{T, i}_t and
        # current h^{R, i}_t
        # Each being (B, N)
        h_rel, c_rel = self.rnn_relation(prev_relation.object.get_encoding(),
                                         (prev_relation.h, prev_relation.c))
        h_tem, c_tem = self.rnn_temporal(prev_temporal.object.get_encoding(),
                                         (prev_temporal.h, prev_temporal.c))
        
        # Compute proposal region to look at: previous location plus a
        # predicted delta.
        # (B, 4)
        proposal_region_delta = self.propagate_proposal(h_tem)
        proposal_region = prev_temporal.object.z_where + proposal_region_delta
        proposal = self.image_to_glimpse(x, proposal_region, inverse=True)
        proposal_embed = self.proposal_embedding(proposal)
        
        # (B, N)
        # Predict where and pres, using h^R, h^T and the proposal embedding
        predict_input = torch.cat((h_rel, h_tem, proposal_embed), dim=-1)
        # (B, 4), (B, 4), (B, 1)
        # Note we only predict the *delta* for z_where here.
        z_where_delta_loc, z_where_delta_scale, z_pres_prob = self.propagate_predict(predict_input)
        
        # Sample from z_pres posterior. Shape (B, 1)
        # NOTE: don't use zero probability otherwise log q(z|x) will not work
        z_pres_post = Bernoulli(z_pres_prob)
        z_pres = z_pres_post.sample()
        # Mask z_pres: once an object has disappeared it stays absent. You
        # don't have to do this for where and what because
        #   - their KL will be masked
        #   - they will be masked when computing the likelihood
        z_pres = z_pres * prev_temporal.object.z_pres
        
        # Sample from z_where posterior, (B, 4)
        z_where_delta_post = Normal(z_where_delta_loc, z_where_delta_scale)
        z_where_delta = z_where_delta_post.rsample()
        # NOTE(review): this reads prev_temporal.z_where while the proposal
        # above reads prev_temporal.object.z_where -- confirm both resolve to
        # the same tensor on TemporalState.
        z_where = prev_temporal.z_where + z_where_delta
        # Mask
        z_where = z_where * z_pres
        
        # Extract glimpse from x, shape (B, 1, H, W)
        glimpse = self.image_to_glimpse(x, z_where, inverse=True)
        # Predicted soft mask over the glimpse. This is important for handling
        # overlap between objects.
        glimpse_mask = self.glimpse_mask(h_tem)
        glimpse = glimpse * glimpse_mask
        
        # Compute posterior distribution over z_what and sample
        z_what_delta_loc, z_what_delta_scale = self.propagate_encoder(glimpse, h_tem, h_rel)
        z_what_delta_post = Normal(z_what_delta_loc, z_what_delta_scale)
        # (B, N)
        z_what_delta = z_what_delta_post.rsample()
        z_what = prev_temporal.z_what + z_what_delta
        # Mask
        z_what = z_what * z_pres
        
        # Now we compute KL divergence and discrete likelihood. Before that, we
        # will need to compute the recursive prior. This is parametrized by
        # previous object state (z) and hidden states from the LSTM.
        
        
        # Compute prior for current step
        (z_what_delta_loc_prior, z_what_delta_scale_prior, z_where_delta_loc_prior,
            z_where_delta_scale_prior, z_pres_prob_prior) = (
            self.propagate_prior(h_tem))
        
        # TODO: demand that scale to be small to guarantee consistency
        if DEBUG:
            # In debug mode, override the learned priors with the fixed values
            # from the arch config.
            z_what_delta_loc_prior = arch.z_what_delta_loc_prior.expand_as(z_what_delta_loc_prior)
            z_what_delta_scale_prior = arch.z_what_delta_scale_prior.expand_as(z_what_delta_scale_prior)
            z_where_delta_scale_prior = arch.z_where_delta_scale_prior.expand_as(z_where_delta_scale_prior)


        # Construct prior distributions
        z_what_delta_prior = Normal(z_what_delta_loc_prior, z_what_delta_scale_prior)
        z_where_delta_prior = Normal(z_where_delta_loc_prior, z_where_delta_scale_prior)
        z_pres_prior = Bernoulli(z_pres_prob_prior)
        
        # Compute KL divergence. Each (B, N)
        kl_z_what = kl_divergence(z_what_delta_post, z_what_delta_prior)
        kl_z_where = kl_divergence(z_where_delta_post, z_where_delta_prior)
        kl_z_pres = kl_divergence(z_pres_post, z_pres_prior)
        
        # Mask these terms.
        # Note for kl_z_pres, we need z_pres of the *previous* time step:
        # even when z_pres[t] = 0, p(z_pres[t] | z_pres[t-1]) cannot be
        # ignored.
        
        # Also note that we do not mask where and what here. That's OK since we
        # will mask the image outside of the function later.
        
        kl_z_what = kl_z_what * z_pres
        kl_z_where = kl_z_where * z_pres
        kl_z_pres = kl_z_pres * prev_temporal.object.z_pres
        
        vis_logger['kl_pres_list'].append(kl_z_pres.mean())
        vis_logger['kl_what_list'].append(kl_z_what.mean())
        vis_logger['kl_where_list'].append(kl_z_where.mean())

        # (B,) here, after reduction
        kl = kl_z_what.sum(dim=-1) + kl_z_where.sum(dim=-1) + kl_z_pres.sum(dim=-1)
        
        # Finally, we compute the discrete likelihoods.
        z_pres_likelihood = z_pres_post.log_prob(z_pres)
        
        # Mask: this term only matters when the object existed at the previous
        # step (it does not depend on model parameters otherwise). (B, 1) here
        z_pres_likelihood = z_pres_likelihood * prev_temporal.object.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()
        
        # Compute id. If z_pres is 1, then inherit that id. Otherwise set it to
        # zero
        id = prev_temporal.object.id * z_pres
        
        B = x.size(0)
        # Collect terms into new states.
        object_state = ObjectState(z_pres, z_where, z_what, id, z_pres_prob=z_pres_prob, object_enc=glimpse.view(B, -1), mask=glimpse_mask.view(B, -1), proposal=proposal_region)
        temporal_state = TemporalState(object_state, h_tem, c_tem)
        relation_state = RelationState(object_state, h_rel, c_rel)
        
        return temporal_state, relation_state, kl, z_pres_likelihood
    

#     print ()
    print ('theta:', theta)
#     # theta = .01 #.99 #.1 #95 #.3 #.9 #.05 #.3
    bern_param = torch.tensor([theta], requires_grad=True)


    dist = Bernoulli(bern_param)
    samps = []
    grads = []
    logprobgrads = []
    for i in range(n):
        samp = dist.sample()

        logprob = dist.log_prob(samp.detach())
        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        # print (samp.data.numpy(), logprob.data.numpy(), logprobgrad.data.numpy())
        # fsdfa

        samps.append(samp.numpy())
        grads.append( (f(samp.numpy()) - 0.) * logprobgrad.numpy())
        logprobgrads.append(logprobgrad.numpy())


    # print (grads[:10])

    print ('Grad Estimator: REINFORCE')
    # print ('Avg samp', np.mean(samps))
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
示例#29
0
    def discover(self, x, x_embed, prev_relation):
        """
        Discover step for a single object in a single time step.
        
        This is basically the same as propagate, but without a temporal state
        input. However, to share the same Predict module, we will use an empty
        temporal state instead.
        
        There are multiple code replications here. I do this because refactoring
        would not be a good abstraction.
        
        Args:
            x: original image. Size (B, C, H, W)
            x_embed: extracted image feature. Size (B, N)
            prev_relation: see RelationState
            
        Returns:
            temporal_state: TemporalState
            relation_state:  RelationState
            kl: kl divergence for all z's. (B,)
            z_pres_likelihood: q(z_pres|x). (B,)
        """
        # First, encode relation info to get current h^{R, i}_t
        # Each being (B, N)
        h_rel, c_rel = self.rnn_relation(prev_relation.object.get_encoding(),
                                         (prev_relation.h, prev_relation.c))
        # (B, N)
        # Predict where and pres, using h^R, and x
        predict_input = torch.cat((h_rel, x_embed), dim=-1)
        # (B, 4), (B, 4), (B, 1)
        z_where_loc, z_where_scale, z_pres_prob = self.discover_predict(predict_input)

        # Sample from z_pres posterior. Shape (B, 1)
        # NOTE: don't use zero probability otherwise log q(z|x) will not work
        z_pres_post = Bernoulli(z_pres_prob)
        z_pres = z_pres_post.sample()
        # Mask z_pres. You don't have to do this for where and what because
        #   - their KL will be masked
        #   - they will be masked when computing the likelihood
        z_pres = z_pres * prev_relation.object.z_pres

        # Sample from z_where posterior, (B, 4)
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()
        # Mask
        z_where = z_where * z_pres

        # Extract glimpse from x, shape (B, 1, H, W)
        glimpse = self.image_to_glimpse(x, z_where, inverse=True)

        # Compute posterior distribution over z_what and sample
        z_what_loc, z_what_scale = self.discover_encoder(glimpse)
        z_what_post = Normal(z_what_loc, z_what_scale)
        # (B, N)
        z_what = z_what_post.rsample()
        # Mask
        z_what = z_what * z_pres

        # Construct prior distributions. Unlike propagate, discover uses fixed
        # priors taken from the arch config.
        z_what_prior = Normal(arch.z_what_loc_prior, arch.z_what_scale_prior)
        z_where_prior = Normal(arch.z_where_loc_prior, arch.z_where_scale_prior)
        z_pres_prior = Bernoulli(arch.z_pres_prob_prior)

        # Compute KL divergence. Each (B, N)
        kl_z_what = kl_divergence(z_what_post, z_what_prior)
        kl_z_where = kl_divergence(z_where_post, z_where_prior)
        kl_z_pres = kl_divergence(z_pres_post, z_pres_prior)

        # Mask these terms.
        # Note for kl_z_pres, we need z_pres of the *previous* time step:
        # even when z_pres[t] = 0, p(z_pres[t] | z_pres[t-1]) cannot be
        # ignored.

        # Also note that we do not mask where and what here. That's OK since we
        # will mask the image outside of the function later.

        kl_z_what = kl_z_what * z_pres
        kl_z_where = kl_z_where * z_pres
        kl_z_pres = kl_z_pres * prev_relation.object.z_pres
        
        vis_logger['kl_pres_list'].append(kl_z_pres.mean())
        vis_logger['kl_what_list'].append(kl_z_what.mean())
        vis_logger['kl_where_list'].append(kl_z_where.mean())
        
        # (B,) here, after reduction
        kl = kl_z_what.sum(dim=-1) + kl_z_where.sum(dim=-1) + kl_z_pres.sum(dim=-1)

        # Finally, we compute the discrete likelihoods.
        z_pres_likelihood = z_pres_post.log_prob(z_pres)

        # Mask: this term only matters when the object existed at the previous
        # step (it does not depend on model parameters otherwise). (B, 1) here
        z_pres_likelihood = z_pres_likelihood * prev_relation.object.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()
        
        
        # Compute id. If z_pres = 1, highest_id += 1, and use that id. Otherwise
        # we do not change the highest id and set id to zero.
        # NOTE(review): highest_id is incremented by the (B, 1) tensor z_pres,
        # so it appears to be a per-batch-element counter -- confirm.
        self.highest_id += z_pres
        id = self.highest_id * z_pres

        B = x.size(0)
        # Collect terms into new states.
        object_state = ObjectState(z_pres, z_where, z_what, id=id, z_pres_prob=z_pres_prob, object_enc=glimpse.view(B, -1), mask=torch.zeros_like(glimpse.view(B, -1)), proposal=torch.zeros_like(z_where))
        
        # For temporal and prior state, we will use the initial state.
        
        temporal_state = TemporalState.get_initial_state(B, object_state)
        relation_state = RelationState(object_state, h_rel, c_rel)


        return temporal_state, relation_state, kl, z_pres_likelihood
示例#30
0
    def infer_step(self, prev, x):
        """
        Run one inference step of the AIR recurrence.

        Given the previous state, infer the latents for the next object slot
        (z_pres, z_where, z_what) together with the quantities needed for
        training. We assume processing continues even when z_pres is 0 (the
        results are masked instead), so the whole batch can be handled at once.

        :param prev: AIRState with the previous step's latents and the
            LSTM / baseline-LSTM hidden states.
        :param x: image batch; flattened internally, so shape (B, ...).
        :return: (new_state, kl, baseline_value, z_pres_likelihood) where
            new_state is an AIRState and kl, baseline_value and
            z_pres_likelihood each have shape (B,).
        """

        B = x.size(0)

        # Flatten x
        x_flat = x.view(B, -1)

        # First, compute h_t that encodes (x, z[1:i-1])
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm_cell(lstm_input, (prev.h, prev.c))

        # Predict presence and location
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # In theory, if z_pres is 0, we don't need to continue computation. But
        # for batch processing, we will do this anyway.

        # Once absent, stay absent: force the presence probability to 0 if the
        # previous step's z_pres was 0.
        z_pres_p = z_pres_p * prev.z_pres

        # Numerical stability: keep the probability strictly inside (0, 1),
        # otherwise Bernoulli.log_prob and its gradient blow up. clamp also
        # covers values merely *near* the boundary, unlike the previous
        # `== 0` / `== 1` patch which only fixed exact hits, and matches the
        # sibling implementation elsewhere in this file.
        eps = 1e-6
        z_pres_p = z_pres_p.clamp(min=eps, max=1.0 - eps)

        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()
        # The clamp makes p >= eps, so the sample can be 1 even when the
        # previous object was absent -- mask it back to 0 in that case.
        z_pres = z_pres * prev.z_pres

        # Likelihood. Note we must use prev.z_pres instead of z_pres because
        # q(z_pres[i]=0 | z_pres[i-1]=1) is non-zero.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        # (B,)
        z_pres_likelihood = z_pres_likelihood.squeeze()

        # sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()

        # extract object glimpse
        # (B, 1, Hobj, Wobj)
        obj = self.image_to_object(x, z_where, inverse=True)

        # predict z_what from the glimpse
        z_what_loc, z_what_scale = self.encoder(obj)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # Compute baseline for this z_pres. The baseline network must not
        # backprop into the model, hence the detach on its input.
        bl_h, bl_c = self.bl_rnn(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        # (B,)
        baseline_value = self.bl_predict(bl_h).squeeze()
        # If z_pres[i-1] is 0, the reinforce term will not be dependent on phi.
        # In this case, we don't need the term. So we set it to zero.
        # At the same time, we must set the learning signal to zero as this
        # will matter in baseline loss computation.
        baseline_value = baseline_value * prev.z_pres.squeeze()

        # Compute KL as we go, summing over the data dimension: shape (B,)
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)

        # For where and what, when z_pres[i] is 0 they are deterministic, so
        # their KL contributes nothing.
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()
        # For pres, mask with the *previous* z_pres: q(z_pres[i] | ...) still
        # matters when the current sample is 0 but the previous object existed.
        kl_pres = kl_pres * prev.z_pres.squeeze()

        kl = (kl_pres + kl_where + kl_what)

        # new state
        new_state = AIRState(z_pres=z_pres,
                             z_where=z_where,
                             z_what=z_what,
                             h=h,
                             c=c,
                             bl_c=bl_c,
                             bl_h=bl_h,
                             z_pres_p=z_pres_p)

        # Logging
        if DEBUG:
            vis_logger['z_pres_p_list'].append(z_pres_p[0])
            vis_logger['z_pres_list'].append(z_pres[0])
            vis_logger['z_where_list'].append(z_where[0])
            vis_logger['object_enc_list'].append(obj[0])
            vis_logger['kl_pres_list'].append(kl_pres.mean())
            vis_logger['kl_what_list'].append(kl_what.mean())
            vis_logger['kl_where_list'].append(kl_where.mean())

        return new_state, kl, baseline_value, z_pres_likelihood
示例#31
0
    obs = torch.Tensor(env.reset()).to(device).view(-1, 4)
    print(
        'episode: ',
        i,
    )
    model.zero_grad()
    optimiser.zero_grad()
    for t in range(100):
        # env.render() # returns a true or false value depending upon the display
        p = model.forward(obs)  # 0 or 1 force left or right application
        m = Bernoulli(p)
        a = m.sample()
        action = a.bool().item()
        obs, reward, done, info = env.step(action)
        obs = torch.Tensor(obs).to(device).view(-1, 4)
        loss += m.log_prob(a)
        R += reward * (gamma**t)
        if done:
            # done is a boolean which indicates if it reaches the terminal state
            loss = -1 * loss * R
            loss.backward()
            optimiser.step()
            print("finished after {} timesteps".format(t + 1))
            reward_list.append(t + 1)
            break
env.close()

# Plot the per-episode survival times collected in reward_list during training.
plt.figure()
plt.plot(reward_list, label='reward')
plt.legend()
plt.show()
示例#32
0
def log_density(actions, prob):
    """Return the per-sample Bernoulli log-likelihood of *actions*.

    Each element of *prob* parametrizes an independent Bernoulli; the
    element-wise log-probabilities of *actions* are summed over dimension 1
    with keepdim, so the result has shape (B, 1).
    """
    dist = Bernoulli(prob)
    per_dim = dist.log_prob(actions)
    return per_dim.sum(dim=1, keepdim=True)
    def train_generator(self, data_in, target_in, train=True, epoch=0):
        """Train the rationale generator with REINFORCE.

        The generator samples a binary rationale mask per token; the (frozen)
        encoder/decoder score the masked input, and the generator is updated
        with a policy-gradient loss whose reward combines prediction quality
        (negative weighted BCE) and a sparsity bonus.

        Args:
            data_in: list of tokenized documents (variable length).
            target_in: list of targets aligned with data_in.
            train: if True, take optimizer steps; otherwise only evaluate.
            epoch: current epoch index (used for logging only).

        Returns:
            (average loss scaled back to per-document units, flat list of
            sigmoid predictions for all batches).
        """

        # Sort documents by (noised) length so that batches are roughly
        # length-homogeneous.
        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in data_in], noise_frac=0.1)
        data = [data_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.encoder.train()
        self.decoder.train()
        self.generator.train()
        bsize = self.bsize
        N = len(data)
        loss_total = 0

        # Batch start offsets, visited in shuffled order.
        batches = list(range(0, N, bsize))
        total_iter = len(batches)
        batches = shuffle(batches)
        predictions = []

        for idx, n in enumerate(batches):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            # Sample a binary keep/drop mask (the "rationale") per token.
            probs = self.generator(batch_data)
            m = Bernoulli(probs=probs)
            rationale = m.sample().squeeze(-1)
            # Zero out dropped tokens in the input sequence.
            batch_data.seq = batch_data.seq * rationale.long()  #(B,L)
            masks = batch_data.masks.float()

            # Score the masked input with frozen encoder/decoder -- no
            # gradients flow into them here.
            with torch.no_grad():
                self.encoder(batch_data)
                self.decoder(batch_data)

                batch_target = target[n:n + bsize]
                batch_target = torch.Tensor(batch_target).to(device)

                if len(batch_target.shape) == 1:  #(B, )
                    batch_target = batch_target.unsqueeze(-1)  #(B, 1)

                # Class-weighted BCE: positives are up-weighted by pos_weight.
                bce_loss = self.criterion(batch_data.predict, batch_target)
                weight = batch_target * self.pos_weight + (1 - batch_target)
                bce_loss = (bce_loss * weight).mean(1)

                predict = torch.sigmoid(
                    batch_data.predict).cpu().data.numpy().tolist()
                predictions.append(predict)

            # Sparsity reward: fraction of real (non-pad) tokens dropped.
            lengths = (batch_data.lengths - 2)  #excl <s> and <eos>
            temp = (1 - rationale) * (1 - masks)
            sparsity_reward = temp.sum(1) / (lengths.float())
            total_reward = -1 * bce_loss + self.configuration['model'][
                'generator']['sparsity_lambda'] * sparsity_reward

            # REINFORCE: -log q(rationale) * reward, summed over tokens.
            log_probs = m.log_prob(rationale.unsqueeze(-1)).squeeze(-1)
            loss = -log_probs * total_reward.unsqueeze(-1)
            loss = loss.sum(1).mean(0)

            if train:
                self.generator_optim.zero_grad()
                loss.backward()
                self.generator_optim.step()
                print(
                    "Epoch: {}, Step: {} Loss {}, Total Reward: {}, BCE loss: {} Sparsity Reward: {} (sparsity_lambda = {})"
                    .format(
                        epoch, idx, loss, total_reward.mean(), bce_loss.mean(),
                        sparsity_reward.mean(), self.configuration['model']
                        ['generator']['sparsity_lambda']))
                # NOTE(review): n_iters is computed but never used -- looks
                # like a leftover from removed step-based logging.
                n_iters = total_iter * epoch + idx
                sys.stdout.flush()
            loss_total += float(loss.data.cpu().item())

        # Flatten per-batch prediction lists into one flat list.
        predictions = [x for y in predictions for x in y]

        return loss_total * bsize / N, predictions
示例#34
0
 def pdf(self, f, y):
     """Return the Bernoulli likelihood p(y | f) under a sigmoid link.

     The latent values f are squashed through a sigmoid to obtain per-element
     success probabilities, and the pointwise probability mass of y is
     recovered as exp(log_prob(y)).
     """
     probs = torch.sigmoid(f).flatten()
     likelihood = Ber(probs=probs)
     return torch.exp(likelihood.log_prob(y))
示例#35
0

prelogits = torch.zeros([B,C])
logits = prelogits - logsumexp(prelogits)
# logits = torch.tensor(logits.clone().detach(), requires_grad=True)
logits.requires_grad_(True)



grads = []
for i in range(N):
    dist1 = Bernoulli(logits=logits[:,0])

    # Decision 1
    b1 = dist1.sample()
    logprob1 = dist1.log_prob(b1)

    if b1 ==0:
        dist2 = Bernoulli(logits=logits[:,1])
    else:
        dist2 = Bernoulli(logits=logits[:,2])

    # Decision 2
    b2 = dist2.sample()
    logprob2 = dist2.log_prob(b2)

    if b1 == 0 and b2 == 0:
        reward = 1
    elif b1 == 0 and b2 == 1:
        reward = 2
    elif b1 == 1 and b2 == 0: