Example #1
0
def sample_gmm(batch_size, mixture_weights, device="cuda", spacing=10., scale=5.):
    """Draw ``batch_size`` samples from a 1-D Gaussian mixture.

    Component ``c`` is ``Normal(c * spacing, scale)`` and is selected with
    probability proportional to ``mixture_weights[c]``. The defaults
    reproduce the previous hard-coded behaviour (means 0, 10, 20, ...,
    std 5, result on CUDA).

    Args:
        batch_size: number of samples to draw.
        mixture_weights: 1-D tensor of component probabilities (normalised
            by ``Categorical``).
        device: device of the returned samples. Defaults to ``"cuda"`` as
            before; pass ``"cpu"`` for CPU-only runs.
        spacing: distance between consecutive component means.
        scale: standard deviation shared by every component.

    Returns:
        Float tensor of shape ``(batch_size, 1)`` on ``device``.
    """
    cluster = Categorical(probs=mixture_weights).sample([batch_size])  # [B]
    mean = (cluster * spacing).float().to(device)
    std = torch.ones([batch_size], device=device) * scale
    samp = Normal(mean, std).sample()
    return samp.view(batch_size, 1)
Example #2
0
def sample_true2():
    """Draw a single observation from the ground-truth mixture.

    A component index is drawn from the module-level ``true_mixture_weights``.
    One value is then drawn from ``Normal(10 * component, 5)``.

    Returns:
        ``(sample, component)``. ``sample`` is a float tensor (shape ``[1]``
        when ``true_mixture_weights`` is 1-D). ``component`` is the index
        tensor returned by ``Categorical.sample()``.
    """
    weights = torch.tensor(true_mixture_weights)
    component = Categorical(probs=weights).sample()
    loc = torch.tensor([component * 10.]).float()
    scale = torch.tensor([5.0]).float()
    sample = Normal(loc, scale).sample()
    return sample, component
Example #3
0
class OneHotCategorical(Distribution):
    r"""
    Categorical distribution whose outcomes are one-hot vectors.

    Wraps a :class:`Categorical` built from ``probs`` or ``logits``. Each
    outcome is a one-hot vector of length ``probs.size(-1)``; all leading
    dimensions of ``probs`` form the batch shape.

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): unnormalized event log-probabilities,
            forwarded unchanged to ``Categorical``
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        self._categorical = Categorical(probs, logits)
        full_shape = self._categorical.probs.size()
        # Everything but the last axis indexes independent distributions;
        # the last axis enumerates the categories.
        super(OneHotCategorical, self).__init__(full_shape[:-1], full_shape[-1:])

    def sample(self, sample_shape=torch.Size()):
        """Return one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        shape = torch.Size(sample_shape)
        template = self._categorical.probs
        result = template.new(self._extended_shape(shape)).zero_()
        chosen = self._categorical.sample(shape)
        # scatter_ needs an index tensor of the same rank as its target.
        if chosen.dim() < result.dim():
            chosen = chosen.unsqueeze(-1)
        return result.scatter_(-1, chosen, 1)

    def log_prob(self, value):
        """Score ``value`` by the log-probability of its arg-max category."""
        _, category = value.max(-1)
        return self._categorical.log_prob(category)

    def entropy(self):
        """Entropy of the underlying categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return every one-hot vector, broadcast across ``batch_shape``.

        The result has shape ``(n,) + batch_shape + (n,)``, where ``n`` is
        the number of categories.
        """
        n = self.event_shape[0]
        probs = self._categorical.probs
        if not isinstance(probs, Variable):
            eye = torch.eye(n, out=probs.new(n, n))
        else:
            eye = Variable(torch.eye(n, out=probs.data.new(n, n)))
        singleton_batch = (1,) * len(self.batch_shape)
        eye = eye.view((n,) + singleton_batch + (n,))
        return eye.expand((n,) + self.batch_shape + (n,))
Example #4
0
    def test_gmm_loss(self):
        """Fit a 3-component, 2-D diagonal GMM to synthetic data with ``gmm_loss``.

        Draws 10k points from a known mixture. It then trains free GMM
        parameters with Adam on random mini-batches of 128 and prints the
        learned means, stds and weights every ``iterations // 10`` steps.

        NOTE(review): the test makes no assertions, so it fails only if
        something raises. Learned components may also come out in any order.
        """
        n_samples = 10000

        # Ground-truth mixture: per-component means, diagonal stds, weights.
        means = torch.Tensor([[0., 0.],
                              [1., 1.],
                              [-1., 1.]])
        stds = torch.Tensor([[.03, .05],
                             [.02, .1],
                             [.1, .03]])
        pi = torch.Tensor([.2, .3, .5])

        # Pick a component per sample, then add component-scaled Gaussian noise.
        cat_dist = Categorical(pi)
        indices = cat_dist.sample((n_samples,))  # Categorical.sample is already int64
        rands = torch.randn(n_samples, 2)

        samples = means[indices] + rands * stds[indices]

        class _model(nn.Module):
            """Unconstrained GMM parameters with a leading batch dimension of 1."""

            def __init__(self, gaussians):
                super().__init__()
                self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

            def forward(self, *inputs):
                # exp keeps stds positive; softmax puts the weights on the simplex.
                return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

        model = _model(3)
        optimizer = torch.optim.Adam(model.parameters())

        iterations = 100000
        log_step = iterations // 10
        cum_loss = 0
        # The context manager closes the progress bar even if training raises.
        with tqdm(total=iterations) as pbar:
            for i in range(iterations):
                batch = samples[torch.randint(0, n_samples, (128,))]
                # Call the module (not .forward) so nn.Module hooks run.
                m, s, p = model()
                loss = gmm_loss(batch, m, s, p)
                cum_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix_str("avg_loss={:10.6f}".format(
                    cum_loss / (i + 1)))
                pbar.update(1)
                if i % log_step == log_step - 1:
                    print(m)
                    print(s)
                    print(p)
Example #5
0
def sample_true(batch_size):
    """Draw ``batch_size`` observations from the ground-truth mixture.

    Component ``c`` is chosen with probability from the module-level
    ``true_mixture_weights`` and contributes ``Normal(10 * c, 5)``. No
    device transfer is done here; callers move the result themselves
    (e.g. ``.cuda()``).

    Returns:
        Float tensor of shape ``(batch_size, 1)``.
    """
    weights = torch.tensor(true_mixture_weights)
    components = Categorical(probs=weights).sample([batch_size])  # [B]
    loc = (components * 10.).float()
    scale = 5. * torch.ones([batch_size])
    draws = Normal(loc, scale).sample()
    return draws.view(batch_size, 1)
Example #6
0
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE estimator with a learned surrogate baseline for a GMM posterior.

    Draws ``k`` cluster assignments from ``Categorical(softmax(logits))``.
    For each draw it builds the score-function loss
    ``-mean((f - surrogate.net(x)).detach() * log q)``, where
    ``f = log p(x|z) + log p(z) - log q - 1``.

    Only the **last** draw's tensors are stored in ``outputs``. Earlier draws
    contribute only to the gradient statistics collected when ``get_grad``
    is true.

    Args:
        surrogate: object whose ``net(x)`` returns a baseline broadcastable
            against ``f`` (presumably ``[B, 1]`` — TODO confirm).
        x: observations, passed to ``logprob_undercomponent`` and the surrogate.
        logits: encoder logits with batch on dim 0 and clusters on dim 1. Must
            require grad, because ``torch.autograd.grad`` is taken w.r.t. it.
        mixtureweights: 1-D mixture probabilities indexed by cluster (``log``
            is taken).
        k: number of samples; must be >= 1.
        get_grad: if true, also record the mean and spread of
            d(net_loss)/d(logits) across the ``k`` draws.

    Returns:
        dict with keys ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss`` and ``surr_loss``. When ``get_grad`` is true it also
        holds ``grad_avg`` and ``grad_std``.

    Raises:
        ValueError: if ``k < 1``. Without this check, ``surr_loss`` would be
            unbound when building the result.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}

    cat = Categorical(probs=probs)

    grads = []
    for _ in range(k):
        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B, 1)
        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]

        surr_pred = surrogate.net(x)

        # The "- 1." comes from differentiating the -log q term inside f:
        # grad E_q[log p - log q] = E_q[(log p - log q - 1) * grad log q].
        outputs['f'] = f = logpxz - logq - 1.
        # Both f and the baseline are detached, so the gradient flows only
        # through log q (score-function estimator).
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)

        # Surrogate objective: squared baseline-corrected score. create_graph
        # keeps it differentiable w.r.t. the surrogate's parameters.
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq) ** 2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    if get_grad:
        grads = torch.stack(grads)  # [k, *logits.shape]
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0), dim=0)
        # Unbiased std over the k draws, reported for batch row 0 only.
        # NOTE: this is NaN when k == 1 (a single sample has no spread).
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
Example #7
0
def reinforce(x, logits, mixtureweights, k=1):
    """Score-function (REINFORCE) loss for a categorical GMM posterior.

    Draws ``k`` cluster assignments ``z ~ Categorical(softmax(logits))``.
    For each it weights ``log q(z)`` by ``(f - 1)``, with
    ``f = log p(x|z) + log p(z) - log q(z)`` detached. The results are
    averaged over the draws.

    Args:
        x: observations, passed to ``logprob_undercomponent``.
        logits: encoder logits with batch on dim 0 and clusters on dim 1.
        mixtureweights: 1-D mixture probabilities indexed by cluster.
        k: number of samples; must be >= 1.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. ``net_loss`` is the
        mean over the ``k`` draws; the other four tensors come from the last
        draw only.

    Raises:
        ValueError: if ``k < 1``. Without this check, ``net_loss / k`` would
            raise ZeroDivisionError.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = Categorical(probs=probs)

    net_loss = 0
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq
        # f is detached, so the gradient flows only through log q. The
        # "- 1." comes from differentiating the -log q term inside f:
        # grad E_q[f] = E_q[(f - 1) * grad log q].
        net_loss += - torch.mean((f.detach() - 1.) * logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq




#REINFORCE
# Monte Carlo estimate of the REINFORCE (score-function) gradient of
# E_{z ~ Categorical(logits)}[f(z)] with respect to `logits`. Each of the N
# iterations draws one sample z and records f(z) * d log q(z) / d logits.
# The mean and std of those N single-sample estimates are then printed.
# Relies on module-level names defined elsewhere: N, C, `logits` (must
# require grad, since autograd.grad is taken w.r.t. it) and `f`.
print ('REINFORCE')

grads = []
for i in range (N):

    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    # autograd.grad without grad_outputs needs a single-element `logprob`
    # (presumably logits is 1-D or [1, C] — TODO confirm). retain_graph=True
    # keeps any graph upstream of `logits` alive for later iterations.
    gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0]
    grads.append(reward*gradlogprob)

print ()
# One row per single-sample estimate; assumes each gradient has C elements.
grads = torch.stack(grads).view(N,C)
grad_mean_reinforce = torch.mean(grads,dim=0)
grad_std_reinforce = torch.std(grads,dim=0)

print ('REINFORCE')
print ('mean:', grad_mean_reinforce)
print ('std:', grad_std_reinforce)
print ()
Example #9
0

def G(rewards, start=0, end=None):
    """Return the total of ``rewards[start:end]``, i.e. the return collected from ``start``.

    ``rewards`` must support slicing; an empty slice sums to ``0``.
    """
    segment = rewards[slice(start, end)]
    return sum(segment, 0)


if __name__ == "__main__":

    # NOTE(review): this looks like a truncated fragment of a REINFORCE
    # training script. Nothing below consumes `log_probs` or
    # `cumulative_returns`, so no policy update happens in the visible code.
    # Relies on module-level `env`, `actor`, `NUM_EPISODES` and `DISCOUNT`,
    # and unpacks `env.step` into four values.
    for episode in range(NUM_EPISODES):
        s, done = env.reset(), False
        states, rewards, log_probs = [], [], []

        # Roll out one full episode with the current stochastic policy.
        while not done:
            s = torch.from_numpy(s).float()
            # The positional argument is `probs`, so actor(s) presumably
            # returns action probabilities — TODO confirm.
            p = Categorical(actor(s))
            a = p.sample()
            with torch.no_grad():
                succ, r, done, _ = env.step(a.numpy())

            states.append(s)
            rewards.append(r)
            log_probs.append(p.log_prob(a))

            s = succ

        # Each reward r_t is scaled by DISCOUNT**t (discounting from the
        # episode start). cumulative_returns[t] sums those terms from t to
        # the end. NOTE(review): G re-sums every suffix, so this is O(T^2) in
        # the episode length.
        discounted_rewards = [DISCOUNT**t * r for t, r in enumerate(rewards)]
        cumulative_returns = [
            G(discounted_rewards, t) for t, _ in enumerate(discounted_rewards)
        ]

        states = torch.stack(states)
def simplax():
    """Train a GMM inference network (``encoder``) and periodically log/plot diagnostics.

    Despite the name, the active training loop draws hard cluster assignments
    from ``Categorical(softmax(encoder(x)))``. It updates only the encoder,
    using a score-function (REINFORCE) loss. The relaxed-sample surrogate
    machinery is used only by the plotting/diagnostic helpers and by the
    disabled surrogate update.

    Depends on many module-level names not defined here, e.g.
    ``n_components``, ``exp_dir``, ``sample_true``, ``true_posterior``,
    ``logprob_undercomponent``, ``NN3``, ``H``, ``check_nan``,
    ``L2_mixtureweights``, ``KL_mixutreweights``, ``compute_kl_batch``,
    ``copmute_logpx``, ``plot_curve`` and ``true_mixture_weights``.
    Requires CUDA (tensors and networks are moved with ``.cuda()``).

    Side effects: prints progress, saves PNG plots under ``exp_dir``, and
    pickles ``{'steps': ..., 'losses': ...}`` to ``exp_dir + "data_simplax.p"``.
    """

    # The nested helpers below read `encoder`, `surrogate`,
    # `needsoftmax_mixtureweight`, `temp` and `batch_size` from this
    # function's scope; those names are assigned further down.

    def show_surr_preds():
        """Plot surrogate predictions against log p(x, z) along a grid of x values.

        For each of ``rows`` panels: draw one observation, encode it, and take
        one relaxed sample ``cluster_S`` (temperature 1) plus its hard version
        ``cluster_H = H(cluster_S)``. Then, with that sample held fixed,
        evaluate the surrogate and ``logprob_undercomponent`` at 40 evenly
        spaced x in [-9, 205]. Saves ``exp_dir + 'gmm_surr.png'``.
        """

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            # Replace the single observation with an evaluation grid and
            # repeat the sampled cluster once per grid point.
            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz  # logprob_cluster is only NaN-checked above; it is not subtracted here

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()


        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()


    def plot_dist():
        """Plot the current model mixture density and save it.

        Draws each weighted component
        ``softmax(needsoftmax_mixtureweight)[c] * N(x; 10*c, 5)`` and their
        sum over 300 points in [-9, 205]. Saves ``exp_dir + 'gmm_plot_dist.png'``.
        """

        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)


        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))

        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            # NOTE(review): `x` is a NumPy float here. Newer PyTorch versions
            # with argument validation enabled reject non-Tensor values in
            # log_prob — confirm against the torch version in use. One call
            # per grid point; evaluating all of `xs` at once would be faster.
            for x in xs:
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()


                ys.append(component_i)

            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')

        # Total mixture density (sum of the weighted components).
        ax.plot(xs, sum_, label='')

        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()



    def get_loss():
        """Return the surrogate's L1 regression loss on a fresh batch.

        Samples ``batch_size`` observations and draws a relaxed one-hot
        cluster at temperature ``temp``. Returns
        ``mean |logpxz.detach() - surrogate([cluster_S, x])|``.
        """

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        # NOTE(review): f is computed but unused; the loss targets logpxz only.
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) # [B, n_components + 1]
        surr_pred = surrogate.net(surr_input)

        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))

        return surr_loss


    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        """Bar-plot the true posterior against the encoder's q(z|x) for one observation.

        The title shows ``L2_mixtureweights`` and ``KL_mixutreweights(p=true, q=q)``.
        Saves ``exp_dir + 'posteriors' + name + '.png'``. The parameter
        shadows the enclosing variable of the same name.
        """

        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)

        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()

        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)


        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        # Side-by-side bars: true posterior, then q shifted right by `width`.
        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)

        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()



    def inference_error(needsoftmax_mixtureweight):
        """Average L2 and KL between the true posterior and q(z|x) over 10 draws.

        Each draw uses a single fresh observation. Returns
        ``(mean_l2, mean_kl)``; the KL argument order matches ``plot_posteriors``
        (presumably KL(true || q) — confirm against ``KL_mixutreweights``).
        """

        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):

            x = sample_true(1).cuda()
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)

            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

            error_sum+=error
            kl_sum += kl

        return error_sum/n, kl_sum/n



    #SIMPLAX
    # Learnable (pre-softmax) mixture weights of the model GMM.
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")

    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()

    # encoder: x -> logits over clusters; surrogate: [cluster repr, x] -> scalar.
    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # Only the encoder gets an optimizer. No optimizer exists for the mixture
    # weights or the surrogate, so neither is stepped in this function.
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    # Surrogate-only updates per step. NOTE(review): any value > 0 uses
    # `optim_surr`, which is not defined in this function (NameError unless
    # a module-level `optim_surr` exists).
    surrugate_steps = 0
    k = 1
    # Metric histories, filled every 500 steps and passed to plot_curve.
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    for step in range(n_steps):

        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)

        # Hard (non-relaxed) categorical posterior q(z|x).
        cat = Categorical(probs=probs)

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            cluster_H = cat.sample()

            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)

            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster

            # The surrogate sees the encoder probabilities (not a sample) plus x.
            surr_input = torch.cat([probs, x], dim=1) # [B, n_components + 1]
            surr_pred = surrogate.net(surr_input)

            # REINFORCE loss for the encoder: logpxz is detached, so the
            # gradient flows only through log q(z|x). NOTE(review): unlike f,
            # the weight omits -log q, so in expectation this follows
            # E_q[log p(x,z)] without the entropy term — confirm intended.
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # `loss` is currently unused; `surr_loss` is only logged.
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))

        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k

        # retain_graph=True keeps the graph alive because the logging block
        # below calls torch.autograd.grad through probs / logprob_cluster.
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()


        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(), 
                            'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax(
                                        needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))

            if step > 0:
                # Record metric histories for plotting.
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax(
                                            needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())

                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)

                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)

                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)

                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())

                # NOTE(review): always true here, because inf_losses was just
                # appended to above; the `else` branch is unreachable.
                if len(inf_losses) > 0:
                    print ('probs', probs[0])
                    print('logpxz', logpxz[0])
                    print('pred', surr_pred[0])
                    print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                    print ('logq', logprob_cluster[0])
                    print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])


                    # Compare mean absolute gradients w.r.t. the encoder probs of:
                    # the baseline-corrected REINFORCE term, the surrogate
                    # prediction ("reparam"), and log q alone.
                    output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0] 
                    output2 = torch.mean(surr_pred, dim=0)[0]
                    output3 = torch.mean(logprob_cluster, dim=0)[0]
                    grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                    grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                    grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]
                    grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                    grad_reparam = torch.mean(torch.abs(grad_reparam))
                    grad3 = torch.mean(torch.abs(grad3))
                    grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                    grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())

                    print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                    print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                    print ('logqz grad:', grad3.cpu().data.detach().numpy())

                    print ('current mixuture weights')
                    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                    print()

                else:
                    grad_reparam_list.append(0.)
                    grad_reinforce_list.append(0.)                    


            # Every 1000 steps (once enough history exists), refresh all plots.
            if len(surr_losses) > 3  and step %1000==0:
                plot_curve(steps=steps_list,  thetaloss=L2_losses, 
                            infloss=inf_losses, surrloss=surr_losses,
                            grad_reinforce_list=grad_reinforce_list, 
                            grad_reparam_list=grad_reparam_list,
                            f_list=f_list, logpxz_list=logpxz_list,
                            logprob_cluster_list=logprob_cluster_list,
                            inf_losses_kl=inf_losses_kl,
                            kl_losses_2=kl_losses_2,
                            logpx_list=logpx_list)


                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()


    # Persist the mixture-weight L2 error curve.
    data_dict = {}

    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses

    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
Example #11
0
def valor(args):
    if not hasattr(args, "get"):
        args.get = args.__dict__.get
    env_fn = args.get('env_fn', lambda: gym.make('HalfCheetah-v2'))
    actor_critic = args.get('actor_critic', ActorCritic)
    ac_kwargs = args.get('ac_kwargs', {})
    disc = args.get('disc', Discriminator)
    dc_kwargs = args.get('dc_kwargs', {})
    seed = args.get('seed', 0)
    episodes_per_epoch = args.get('episodes_per_epoch', 40)
    epochs = args.get('epochs', 50)
    gamma = args.get('gamma', 0.99)
    pi_lr = args.get('pi_lr', 3e-4)
    vf_lr = args.get('vf_lr', 1e-3)
    dc_lr = args.get('dc_lr', 2e-3)
    train_v_iters = args.get('train_v_iters', 80)
    train_dc_iters = args.get('train_dc_iters', 50)
    train_dc_interv = args.get('train_dc_interv', 2)
    lam = args.get('lam', 0.97)
    max_ep_len = args.get('max_ep_len', 1000)
    logger_kwargs = args.get('logger_kwargs', {})
    context_dim = args.get('context_dim', 4)
    max_context_dim = args.get('max_context_dim', 64)
    save_freq = args.get('save_freq', 10)
    k = args.get('k', 1)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    ac_kwargs['action_space'] = env.action_space

    # Model
    actor_critic = actor_critic(input_dim=obs_dim[0] + max_context_dim,
                                **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=max_context_dim, **dc_kwargs)

    # Buffer
    local_episodes_per_epoch = episodes_per_epoch  # int(episodes_per_epoch / num_procs())
    buffer = Buffer(max_context_dim, obs_dim[0], act_dim[0],
                    local_episodes_per_epoch, max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(
        count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' %
               var_counts)

    # Optimizers
    #Optimizer for RL Policy
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)

    #Optimizer for value function (for actor-critic)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)

    #Optimizer for decoder
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    #pdb.set_trace()

    # Parameters Sync
    #sync_all_params(actor_critic.parameters())
    #sync_all_params(disc.parameters())
    '''
    Training function
    '''
    def update(e):
        obs, act, adv, pos, ret, logp_old = [
            torch.Tensor(x) for x in buffer.retrieve_all()
        ]

        # Policy
        #pdb.set_trace()
        _, logp, _ = actor_critic.policy(obs, act, batch=False)
        #pdb.set_trace()
        entropy = (-logp).mean()

        # Policy loss
        pi_loss = -(logp * (k * adv + pos)).mean()

        # Train policy (Go through policy update)
        train_pi.zero_grad()
        pi_loss.backward()
        # average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = actor_critic.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            # average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                # average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old),
                     LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    #Resets observations, rewards, done boolean
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    #Creates context distribution where each logit is equal to one (This is first place to make change)
    context_dim_prob_dict = {
        i: 1 / context_dim if i < context_dim else 0
        for i in range(max_context_dim)
    }
    last_phi_dict = {i: 0 for i in range(context_dim)}
    context_dist = Categorical(
        probs=torch.Tensor(list(context_dim_prob_dict.values())))
    total_t = 0

    for epoch in range(epochs):
        #Sets actor critic and decoder (discriminator) into eval mode
        actor_critic.eval()
        disc.eval()

        #Runs the policy local_episodes_per_epoch before updating the policy
        for index in range(local_episodes_per_epoch):
            # Sample from context distribution and one-hot encode it (Step 2)
            # Every time we run the policy we sample a new context

            c = context_dist.sample()
            c_onehot = F.one_hot(c, max_context_dim).squeeze().float()
            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)),
                     c_onehot.reshape(1, -1)], 1)
                '''
                Feeds in observation and context into actor_critic which spits out a distribution 
                Label is a sample from the observation
                pi is the action sampled
                logp is the log probability of some other action a
                logp_pi is the log probability of pi 
                v_t is the value function
                '''
                a, _, logp_t, v_t = actor_critic(concat_obs)

                #Stores context and all other info about the state in the buffer
                buffer.store(c,
                             concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    # Key stuff with discriminator
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    #Context
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    #Feed in differences between each state in your trajectory and a specific context
                    #Here, this is just the log probability of the label it thinks it is
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        # Sets actor_critic and discriminator into training mode
        actor_critic.train()
        disc.train()

        update(epoch)
        #Need to implement curriculum learning here to update context distribution
        ''' 
            #Psuedocode:
            Loop through each of d episodes taken in local_episodes_per_epoch and check log probability from discrimantor
            If >= 0.86, increase k in the following manner: k = min(int(1.5*k + 1), Kmax)
            Kmax = 64
        '''

        decoder_accs = []
        stag_num = 10
        stag_pct = 0.05

        if (epoch + 1) % train_dc_interv == 0 and epoch > 0:
            #pdb.set_trace()
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("Context: ", con)
            print("num_contexts", len(con))
            _, logp_dc, _ = disc(s_diff, con)
            log_p_context_sample = logp_dc.mean().detach().numpy()

            print("Log Probability context sample", log_p_context_sample)

            decoder_accuracy = np.exp(log_p_context_sample)
            print("Decoder Accuracy", decoder_accuracy)

            logger.store(LogProbabilityContext=log_p_context_sample,
                         DecoderAccuracy=decoder_accuracy)
            '''
            Create score (phi(i)) = -log_p_context_sample.mean() for each specific context 
            Assign phis to each specific context
            Get p(i) in the following manner: (phi(i) + epsilon)
            Get Probabilities by doing p(i)/sum of all p(i)'s 
            '''
            logp_np = logp_dc.detach().numpy()
            con_np = con.detach().numpy()
            phi_dict = {i: 0 for i in range(context_dim)}
            count_dict = {i: 0 for i in range(context_dim)}
            for i in range(len(logp_np)):
                current_con = con_np[i]
                phi_dict[current_con] += logp_np[i]
                count_dict[current_con] += 1
            print(phi_dict)

            phi_dict = {
                k: last_phi_dict[k] if count_dict[k] == 0 else
                (-1) * v / count_dict[k]
                for (k, v) in phi_dict.items()
            }
            sorted_dict = dict(
                sorted(phi_dict.items(),
                       key=lambda item: item[1],
                       reverse=True))
            sorted_dict_keys = list(sorted_dict.keys())
            rank_dict = {
                sorted_dict_keys[i]: 1 / (i + 1)
                for i in range(len(sorted_dict_keys))
            }
            rank_dict_sum = sum(list(rank_dict.values()))
            context_dim_prob_dict = {
                k: rank_dict[k] / rank_dict_sum if k < context_dim else 0
                for k in context_dim_prob_dict.keys()
            }
            print(context_dim_prob_dict)

            decoder_accs.append(decoder_accuracy)
            stagnated = (len(decoder_accs) > stag_num
                         and (decoder_accs[-stag_num - 1] - decoder_accuracy) /
                         stag_num < stag_pct)
            if stagnated:
                new_context_dim = max(int(0.75 * context_dim), 5)
            elif decoder_accuracy >= 0.86:
                new_context_dim = min(int(1.5 * context_dim + 1),
                                      max_context_dim)
            if stagnated or decoder_accuracy >= 0.86:
                print("new_context_dim: ", new_context_dim)
                new_context_prob_arr = np.array(
                    new_context_dim * [1 / new_context_dim] +
                    (max_context_dim - new_context_dim) * [0])
                context_dist = Categorical(
                    probs=ptu.from_numpy(new_context_prob_arr))
                context_dim = new_context_dim

            for i in range(context_dim):
                if i in phi_dict:
                    last_phi_dict[i] = phi_dict[i]
                elif i not in last_phi_dict:
                    last_phi_dict[i] = max(phi_dict.values())

            buffer.clear_dc_buff()
        else:
            logger.store(LogProbabilityContext=0, DecoderAccuracy=0)

        # Log
        logger.store(ContextDim=context_dim)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.log_tabular('LogProbabilityContext', average_only=True)
        logger.log_tabular('DecoderAccuracy', average_only=True)
        logger.log_tabular('ContextDim', average_only=True)
        logger.dump_tabular()
Example #12
0
    def run(self, episodes, steps, train=False, render_once=1e10, saveonce=10):
        """Play ``episodes`` episodes of ``steps`` steps each, optionally training.

        Every step feeds the flattened environment state to ``self.model``,
        samples an action from a ``Categorical`` built on the model output,
        sends the corresponding one-hot vector to ``self.env.act`` and
        accumulates the reward into ``self.episode_rewards``.  When ``train``
        is true each transition is passed to ``self.trainer.store_records``;
        after every episode the trainer is updated and the recorder is fed,
        with saving/plotting every ``saveonce`` episodes.

        Args:
            episodes: number of episodes to run.
            steps: number of environment steps per episode.
            train: store transitions and update the trainer when true.
            render_once: while training, drawing is enabled only on every
                ``render_once``-th episode (the default effectively never).
            saveonce: save/plot period, in episodes, while training.
        """
        if train:
            assert self.recorder.log_message is not None, "log_message is necessary during training, Instantiate Runner with log message"

        # Decide once whether the model is recurrent ("mem").  Every later
        # check reuses this flag, so a model without a ``type`` attribute is
        # consistently treated as non-recurrent instead of crashing on an
        # unguarded ``self.model.type`` read inside the step loop.
        is_mem = getattr(self.model, "type", None) == "mem"
        if is_mem:
            print("Recurrent Model")
        # NOTE(review): the message reads as if the opposite condition was
        # intended (recurrent models later read ``hidden_states``) — confirm.
        assert not hasattr(self.model,
                           "hidden_states"), "no hidden_states list attribute"
        self.env.display_neural_image = self.visual_activations

        for episode in range(episodes):
            self.episode_rewards = []
            self.env.reset()
            # Always draw when evaluating; while training only on every
            # render_once-th episode.
            draw_episode = not train or episode % render_once == render_once - 1
            self.env.enable_draw = draw_episode

            if is_mem:
                self.model.reset()

            state = self.env.get_state().reshape(-1)
            bar = tqdm(range(steps),
                       bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
            trewards = 0

            for step in bar:
                state = T.from_numpy(state).float()
                actions = self.model(state).view(-1)
                dist = Categorical(actions)
                action = dist.sample()
                prob = dist.probs[action]

                one_hot_action = np.zeros(self.nactions)
                one_hot_action[action] = 1.0
                newstate, reward = self.env.act(one_hot_action)
                trewards += reward
                self.episode_rewards.append(trewards)
                if train:
                    hidden = self.model.hidden_states[-2] if is_mem else []
                    self.trainer.store_records(
                        (state.tolist(), action, reward, prob, hidden, False))

                state = newstate.reshape(-1)
                if is_mem and self.visual_activations:
                    flat_activations = T.cat(self.activations, dim=0).reshape(-1)
                    self.neural_image_values = flat_activations.detach().numpy()
                    self.activations = []
                    # Refresh the weight visualisation on the first step of
                    # every 10th episode.
                    if episode % 10 == 0 and step == 0:
                        self.update_weights()
                        self.neural_weights = self.weights
                        self.weight_change = True
                    if self.model.hidden_vectors is not None:
                        self.hidden_state = self.model.hidden_vectors
                else:
                    self.activations = []

                bar.set_description(f"Episode: {episode:4} Rewards : {trewards}")
                if train:
                    self.env.step()
                else:
                    self.env.step(speed=0)

                self.event_handler()
                self.window.fill((0, 0, 0))
                if self.visual_activations and draw_episode:
                    if is_mem:
                        self.draw_neural_image()
                    self.window.blit(self.env.win, (0, 0))

            if train:
                self.trainer.update()
                self.recorder.newdata(trewards)
                if episode % saveonce == saveonce - 1:
                    self.recorder.save()
                    self.recorder.plot()
                    if self.recorder.final_reward >= self.current_max_reward:
                        self.recorder.save_model(self.model)
                        self.current_max_reward = self.recorder.final_reward
        print("******* Run Complete *******")
Example #13
0
    def decode_one_batch_rl(self, greedy, batch, s_t_1, c_t_1, enc_outputs,
                            enc_features, enc_padding_mask, extend_vocab_zeros,
                            enc_batch_extended, coverage_t, device):
        """Decode one batch without teacher forcing, for RL training.

        With ``greedy=False`` every step samples a token from the decoder's
        output distribution and records its log-probability.  With
        ``greedy=True`` the arg-max token is taken (the baseline) and no
        log-probabilities are kept.

        Returns:
            (decoded_seqs, log_probs): ``decoded_seqs`` holds one decoded
            string per batch element ("xxx" when fewer than two words were
            produced).  When sampling, ``log_probs`` is a tensor with the
            mean token log-probability of each sequence over the positions up
            to and including its first end token.  When greedy it is an empty
            list.
        """
        # No teacher forcing for RL
        dec_batch, _, max_dec_len, _, _ = get_output_from_batch(
            self.params, batch, device)
        log_probs = []
        decode_ids = []
        # we create the dec_padding_mask at the runtime: a position counts
        # while its sequence has not yet emitted the end token
        dec_padding_mask = []
        y_t = dec_batch[:, 0]
        mask_t = torch.ones(len(enc_outputs), dtype=torch.long, device=device)
        # there is at least one token in the decoded seqs, which is STOP_DECODING
        for _ in range(min(max_dec_len, self.params.max_dec_steps)):
            y_t_1 = y_t
            # first we have coverage_t_1, then we have a_t
            final_dist, s_t_1, c_t_1, _, coverage_t_plus = self.model.decoder(
                y_t_1, s_t_1, c_t_1, enc_outputs, enc_features,
                enc_padding_mask, extend_vocab_zeros, enc_batch_extended,
                coverage_t)
            # Feed the updated coverage vector into the next decoding step
            # (it was previously discarded, freezing coverage at its initial
            # value).
            coverage_t = coverage_t_plus
            if not greedy:
                # sampling
                multi_dist = Categorical(final_dist)
                y_t = multi_dist.sample()
                log_probs.append(multi_dist.log_prob(y_t))

                y_t = y_t.detach()
                dec_padding_mask.append(mask_t.detach().clone())
                # Stop counting positions for sequences that just emitted the
                # end token.  Use a logical AND: adding two bool tensors yields
                # a bool tensor, so the old ``(...) + (...) == 2`` test was
                # always false and the mask never shrank.
                mask_t[(mask_t == 1) & (y_t == self.end_id)] = 0
            else:
                # baseline
                y_t = final_dist.max(1)[1].detach()

            decode_ids.append(y_t)
            # for next input: ids outside the fixed vocabulary (copied OOV
            # words) are replaced by the UNK id
            is_oov = (y_t >= self.vocab.size()).long()
            y_t = (1 - is_oov) * y_t + is_oov * self.unk_id

        decode_ids = torch.stack(decode_ids, 1)

        if not greedy:
            dec_padding_mask = torch.stack(dec_padding_mask, 1).float()
            log_probs = torch.stack(log_probs, 1) * dec_padding_mask
            dec_lens = dec_padding_mask.sum(1)
            log_probs = log_probs.sum(1) / dec_lens
            if (dec_lens == 0).any():
                print("Decode lengths encounter zero!")
                print(dec_lens)

        decoded_seqs = []
        for i in range(len(enc_outputs)):
            dec_ids = decode_ids[i].cpu().numpy()
            article_oovs = batch.art_oovs[i]
            dec_words = data.outputids2decwords(dec_ids, self.vocab,
                                                article_oovs,
                                                self.params.pointer_gen)
            if len(dec_words) < 2:
                dec_seq = "xxx"
            else:
                dec_seq = " ".join(dec_words)
            decoded_seqs.append(dec_seq)

        return decoded_seqs, log_probs
Example #14
0
def reinforce(n_components, needsoftmax_mixtureweight=None):
    """Train a categorical inference network q(z|x) on 1-D GMM samples.

    Each step draws ``batch_size`` points with ``sample_gmm`` using mixture
    weights ``softmax(needsoftmax_mixtureweight)``.  It encodes them with
    ``NN3`` into logits over ``n_components`` clusters, samples clusters from
    ``Categorical(softmax(logits))`` and takes an Adam step on the surrogate
    loss below.  Every 100 steps it prints progress and records
    ``(step, loss)`` in ``data_dict``, which is pickled to
    ``exp_dir + "data_simplax.p"`` at the end.

    Args:
        n_components: number of mixture components / categories.
        needsoftmax_mixtureweight: optional unnormalised mixture-weight
            logits; random normal logits are drawn when ``None``.

    Note: all tensors and the encoder are placed on "cuda".
    """
    if needsoftmax_mixtureweight is None:
        needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")
    else:
        needsoftmax_mixtureweight = torch.tensor(needsoftmax_mixtureweight,
                                                 requires_grad=True, device="cuda")

    mixtureweights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()  # [C]

    encoder = NN3(input_size=1, output_size=n_components, n_residual_blocks=2).cuda()
    optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-5, weight_decay=1e-7)

    batch_size = 10
    n_steps = 300000
    k = 1  # number of cluster samples per step

    # History written to disk at the end (previously never filled, so the
    # pickle was always empty).
    data_dict = {'steps': [], 'losses': []}

    for step in range(n_steps):
        x = sample_gmm(batch_size, mixture_weights=mixtureweights)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = Categorical(probs=probs)

        net_loss = 0
        for _ in range(k):
            cluster_H = cat.sample()
            logq = cat.log_prob(cluster_H).view(batch_size, 1)
            logpx_given_z = logprob_undercomponent(x, component=cluster_H)
            logpz = torch.log(mixtureweights[cluster_H]).view(batch_size, 1)
            logpxz = logpx_given_z + logpz  # [B,1]
            # Objective currently optimised: score function weighted by -log q
            # (an entropy-style surrogate).  The REINFORCE-on-ELBO alternative
            # that was tried is:
            #     net_loss += -torch.mean((logpxz.detach() - 100.) * logq)
            net_loss += - torch.mean(-logq.detach() * logq)

        net_loss = net_loss / k

        optim_net.zero_grad()
        # The graph is rebuilt from fresh samples every step, so it does not
        # need to be retained after backward.
        net_loss.backward()
        optim_net.step()

        if step % 100 == 0:
            data_dict['steps'].append(step)
            data_dict['losses'].append(net_loss.item())

            print()
            print(
                'S:{:5d}'.format(step),
                'Loss:{:.4f}'.format(to_print1(net_loss)),
                'ELBO:{:.4f}'.format(to_print1(logpxz - logq)),
                'lpx|z:{:.4f}'.format(to_print1(logpx_given_z)),
                'lpz:{:.4f}'.format(to_print1(logpz)),
                'lqz:{:.4f}'.format(to_print1(logq)),
            )
            print(to_print2(probs[0]))

    with open(exp_dir + "data_simplax.p", "wb") as f:
        pickle.dump(data_dict, f)
    print('saved data')
Example #15
0
model = model.to(device)

# parameter
temperature = args.temperature
softmax = nn.Softmax(0)
char2int = HPchar2int_v2()
int2char = {v: k for k, v in char2int.items()}

# inference: extend the prompt one sampled character at a time
input_sent = '哈利走進霍格華茲'
input_ids = list(map(char2int.get, list(input_sent)))
# Fail with a clear message instead of an opaque LongTensor TypeError when
# the prompt contains characters outside the vocabulary.
unknown_chars = [ch for ch, idx in zip(input_sent, input_ids) if idx is None]
if unknown_chars:
    raise ValueError(f"characters not in vocabulary: {unknown_chars}")

model.eval()
with torch.no_grad():
    while len(input_ids) < args.max_len:

        input_tensor = torch.LongTensor(input_ids).unsqueeze(0).to(device)
        masks_tensor = FutureMask(input_tensor).to(device)

        outputs, _ = model(input_ids=input_tensor, input_mask=masks_tensor)

        # temperature-scaled distribution over the next character
        outputs = outputs[0, -1, :] / temperature
        outputs = softmax(outputs)
        sampler = Categorical(outputs)
        input_ids.append(sampler.sample().cpu().item())

input_ids = list(map(int2char.get, input_ids))
target_chars = ''.join(input_ids)
# Raw string: '\[' in a normal string literal is an invalid escape sequence
# (SyntaxWarning since Python 3.12).
target_chars = re.sub(r'\[NL\]', '\n', target_chars)

print(target_chars)
Example #16
0
class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution, parameterized by :attr:`probs` or
    :attr:`logits` exactly like :class:`torch.distributions.Categorical`.

    Every sample is a vector of length ``probs.size(-1)`` holding a single 1
    at the sampled category and 0 elsewhere.

    .. note:: ``probs`` must be non-negative, finite and have a non-zero sum;
              it is normalized to sum to 1 along the last dimension, and the
              :attr:`probs` property reports that normalized value.
              ``logits`` are read as unnormalized log probabilities (any real
              number) and are normalized in the same way; :attr:`logits`
              reports the normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {
        'probs': constraints.simplex,
        'logits': constraints.real_vector,
    }
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # All parameter handling is delegated to a wrapped Categorical; the
        # event is the last (category) dimension of its parameter.
        self._categorical = Categorical(probs, logits)
        super().__init__(self._categorical.batch_shape,
                         self._categorical.param_shape[-1:],
                         validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        expanded = self._get_checked_instance(OneHotCategorical, _instance)
        target_shape = torch.Size(batch_shape)
        expanded._categorical = self._categorical.expand(target_shape)
        super(OneHotCategorical, expanded).__init__(target_shape,
                                                    self.event_shape,
                                                    validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        return self._categorical._param

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        # The mean of a one-hot vector is the probability vector itself.
        return self._categorical.probs

    @property
    def variance(self):
        # Each coordinate is a Bernoulli(p) indicator: variance p * (1 - p).
        p = self._categorical.probs
        return p * (1 - p)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        shape = torch.Size(sample_shape)
        template = self._categorical.probs
        n_categories = self._categorical._num_events
        idx = self._categorical.sample(shape)
        # Cast to the dtype/device of the probabilities.
        return torch.nn.functional.one_hot(idx, n_categories).to(template)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Position of the 1 in each one-hot vector is the category index.
        return self._categorical.log_prob(value.max(dim=-1).indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        n = self.event_shape[0]
        eye = torch.eye(n,
                        dtype=self._param.dtype,
                        device=self._param.device)
        # Put the enumeration axis first, with singleton batch dimensions.
        leading = (n, ) + (1, ) * len(self.batch_shape)
        eye = eye.view(leading + (n, ))
        if not expand:
            return eye
        return eye.expand((n, ) + self.batch_shape + (n, ))
Example #17
0
 def get_action_and_value(self, x, action=None):
     """Return (action, log-prob of action, policy entropy, critic value) for ``x``.

     The actor output is used as the logits of a Categorical policy; a new
     action is sampled only when ``action`` is None.
     """
     dist = Categorical(logits=self.actor(x))
     chosen = dist.sample() if action is None else action
     return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(x)
    # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
    # dist_bernoulli = Bernoulli(bern_param)
    C= 2
    n_components = C
    B=1
    probs = torch.ones(B,C)
    bern_param = bern_param.view(B,1)
    aa = 1 - bern_param
    probs = torch.cat([aa, bern_param], dim=1)

    cat = Categorical(probs= probs)

    grads = []
    for i in range(n):
        b = cat.sample()
        logprob = cat.log_prob(b.detach())
        # b_ = torch.argmax(z, dim=1)

        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        grad = f(b) * logprobgrad

        grads.append(grad[0][0].data.numpy())

    print ('Grad Estimator: Reinfoce categorical')
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
    print ()

    reinforce_cat_grad_means.append(np.mean(grads))
    reinforce_cat_grad_stds.append(np.std(grads))
Example #19
0
class DIPTransform:
    """Randomised Deep Image Prior (DIP) transform.

    Each call fits ``self.net`` to the input image for a randomly drawn number
    of iterations and returns the network output as a PIL image.

    NOTE(review): ``self.net`` is built once in ``__init__`` and never
    re-initialised, so every call continues optimising from the weights left
    by the previous call. Confirm this is intended.
    """

    def __init__(self,
                 image_shape,
                 num_iters,
                 stop_iters,
                 input_depth=32,
                 input_noise_std=0,
                 optimizer='adam',
                 lr=1e-2,
                 plot_every=0,
                 device='cpu'):
        """
        :param image_shape: image shape (CxHxW)
        :param num_iters: number of iterations to overfit
        :param stop_iters: dict int->float specifying a categorical distribution over percentages in (0, 100]; KEY is run for KEY/100 * NUM_ITERS iterations with probability VALUE.
        :param input_depth: depth of random input. Default value taken from paper.
        :param input_noise_std: stdev of noise added to base random input at each iter. They say in the paper that this helps sometimes (seemingly using 0.03), but default is 0 (no noise).
        :param optimizer: supported optimizers are 'adam' and 'LBFGS'
        :param lr: learning rate. Per paper and code, 1e-2 works best.
        :param plot_every: how often to save images. Doesn't save any images if set to 0 (default).
        :param device: torch device string the network and tensors are moved to, e.g. 'cuda' to run on the GPU.
        For example, with NUM_ITERS = 100 and STOP_ITERS = {10: 0.5, 50: 0.3, 100: 0.2}, there is a 50% chance that the transform runs for 10 iterations, a 30% chance it runs for 50 iterations, and a 20% chance it runs for the full 100 iterations.
        """
        self.iters, probs = zip(*stop_iters.items())
        self.probs = Categorical(torch.tensor(probs))
        self.num_iters = num_iters
        self.image_shape = image_shape
        self.input_depth = input_depth
        self.input_noise_std = input_noise_std
        self.plot_every = plot_every
        self.device = device
        self.optimizer = optimizer
        self.lr = lr
        self.loss = torch.nn.MSELoss()
        # const strings from provided DIP code
        self.opt_over = 'net'
        self.const_input = 'noise'

        # initialize network
        self.net = skip(input_depth,
                        3,
                        num_channels_down=[8, 16, 32],
                        num_channels_up=[8, 16, 32],
                        num_channels_skip=[0, 0, 4],
                        upsample_mode='bilinear',
                        need_sigmoid=True,
                        need_bias=True,
                        pad='zeros',
                        act_fun='LeakyReLU').to(self.device)

    def sample_iters(self):
        """Draw a stop percentage from ``stop_iters`` and convert it to an iteration count."""
        return self.iters[int(self.probs.sample())] * self.num_iters // 100

    def run(self, image, num_iters):
        """Optimise ``self.net`` for ``num_iters`` iterations to reproduce ``image``.

        :param image: tensor whose shape after the leading dimension equals ``self.image_shape``
        :param num_iters: number of optimisation iterations
        :return: output of ``self.net`` on the fixed noise input after optimisation
        """
        assert image.shape[
            1:] == self.image_shape, 'Wrong shape. Expected {}, got {}.'.format(
                self.image_shape, image.shape[1:])
        net_input = get_noise(self.input_depth, self.const_input,
                              self.image_shape[1:]).to(self.device)
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        params = get_params(self.opt_over, self.net, net_input)

        def closure(iter_num):
            # Propagate the loss: optimizers such as LBFGS require it.
            return self.closure(net_input_saved, image, noise, iter_num)

        optimize(self.optimizer,
                 params,
                 closure,
                 self.lr,
                 num_iters,
                 pass_iter=True)
        transformed = self.net(net_input)
        return transformed

    def closure(self, net_input_saved, image, noise, iter_num):
        """Run one forward/backward pass and return the MSE loss.

        When ``input_noise_std > 0``, ``noise`` is resampled in place and added
        to the saved input. When plotting is enabled and the loss is below 0.01,
        the current output is saved every ``plot_every`` iterations.
        """
        net_input = net_input_saved
        if self.input_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() *
                                           self.input_noise_std)
        out = self.net(net_input)
        total_loss = self.loss(out, image)
        total_loss.backward()
        if self.plot_every > 0 and iter_num % self.plot_every == 0 and total_loss < 0.01:
            out_np = torch_to_np(out)
            plot_image_grid([np.clip(out_np, 0, 1)],
                            factor=4,
                            nrow=1,
                            show=False,
                            save_path=f'results_dip/imgs/{iter_num}.png')
        # torch.optim.LBFGS.step() needs the closure to return the loss.
        return total_loss

    def __call__(self, sample):
        """
        Takes in, transforms, and returns PIL image given by SAMPLE. Transformation is a random number of iterations of DIP.
        Distribution of number of iterations is specified when the transform is initialized.
        """
        np_img = pil_to_np(sample)
        torch_img = np_to_torch(np_img).to(self.device)
        if self.plot_every > 0:
            # Reference image for the per-iteration plots; honour plot_every == 0
            # ("doesn't save images") instead of always writing it.
            plot_image_grid([np.clip(np_img, 0, 1)],
                            factor=4,
                            nrow=1,
                            show=False,
                            save_path='results_dip/imgs/true.png')
        num_iters = self.sample_iters()
        transformed = self.run(torch_img, num_iters)
        return np_to_pil(torch_to_np(transformed))
Example #20
0
 def sample_action(self, state):
     """Sample an action index from the categorical policy whose logits are ``self(state)``."""
     return Categorical(logits=self(state)).sample()
def select_actions(pi, deterministic=False):
    """Choose actions from the action probabilities ``pi``.

    :param pi: action probabilities; the argmax is taken over dim 1, so
        presumably shaped (batch, n_actions). TODO confirm with callers.
    :param deterministic: if True, return the greedy action instead of sampling.
    :return: in greedy mode, ``argmax(pi, dim=1).item()`` as a Python number
        (``.item()`` only works when that argmax has a single element). In
        sampling mode, a tensor of sampled indices with a trailing unit
        dimension added.
    """
    if deterministic:
        return torch.argmax(pi, dim=1).item()
    # Build (and validate) the distribution only when it is actually sampled.
    return Categorical(pi).sample().unsqueeze(-1)
Example #22
0
    L2_losses = []
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            # print (logits.shape)
            # print (torch.softmax(logits, dim=0))
            # fsfd
            cat = Categorical(probs= torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            # print (logprob_cluster)
            pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
            f = pxz - logprob_cluster
            # print (f)
            # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        loss = loss / batch_size
        net_loss = net_loss / batch_size

        # print (loss, net_loss)

        loss.backward(retain_graph=True)  
        optim.step()
def gen_IDENTIFY(model, task, state, turn, vocab, topk=1, target=True,
                 var='A', exclude=frozenset()):
    """Generate an IDENTIFY (or DISTINGUISH/COMPARE_REF) logical-form string.

    :param model: forwarded to get_term_probs / gen_IDENTIFY1 / gen_IDENTIFY2.
    :param task: list of items (indexed 0..2 in the DISTINGUISH branch);
        may be rewritten by gen_task or reordered by sample_patch.
    :param state: dialogue state; this function reads ``.form``, ``.prev``,
        ``.intent`` and ``.distribution``.
    :param turn: pair ``(speaker, sim_flag)``. Speaker 'S' enables the
        speaker-specific branches, and ``sim_flag`` is passed as
        ``simulation=`` to get_term_probs.
    :param vocab: only forwarded to gen_IDENTIFY1 / gen_IDENTIFY2.
    :param topk: unused here (gen_IDENTIFY1/2 are always called with topk=1).
    :param target: if truthy, the form refers to variable 'T'; otherwise to
        ``var`` (or 'A' in the first-turn branch).
    :param var: variable name used when ``target`` is falsy.
    :param exclude: extra term ids (strings) the final fallback must avoid.
    :return: the logical-form string.
    :raises TypeError: if the MIS_ID* environment variable needed for the
        current intent is unset (``float(None)``).
    """
    sim_flag = turn[1]
    turn = turn[0]

    # Every numeric term id already mentioned anywhere in the dialogue history
    # (walking back through state.prev); the fallback branch avoids these.
    clr_terms = set(exclude)
    cur_state = state
    while cur_state:
        clr_terms.update(re.findall(r'[0-9]+', cur_state.form))
        cur_state = cur_state.prev

    # Speaker-side "mistake" probability, configured per intent via env vars.
    mistake_env = {2: 'MIS_ID1', 6: 'MIS_ID2', 7: 'MIS_ID3'}
    prob_mistake = 0.0
    if turn == 'S' and state.intent in mistake_env:
        prob_mistake = float(os.getenv(mistake_env[state.intent]))

    sampled_val = random.random()

    if turn == 'S' and sampled_val <= prob_mistake and state.intent in (2, 6, 7):
        print('Sampled Value: %f, Mistake Chance: %f, Previous Move: %d' %
              (sampled_val, prob_mistake, state.intent))
        task = gen_task(task, random.choice([1, 2]))

    def draw_term(term_probs):
        # Sample one term id from a categorical over the model's term probabilities.
        return str(int(Categorical(term_probs).sample()))

    if turn == 'S' and (state.prev is None or not target):
        clr_term = draw_term(get_term_probs(model, task, simulation=sim_flag))
        return 'IDENTIFY(' + ('T' if target else 'A') + ',' + clr_term + ')'
    if turn == 'S' and target and state.intent == 2:
        if random.random() <= 0.8:
            clr_term = draw_term(get_term_probs(model, task, simulation=sim_flag))
            return 'IDENTIFY(T,' + clr_term + ')'
        term_probs1 = get_term_probs(model, [task[1], task[0], task[2]],
                                     simulation=sim_flag)
        term_probs2 = get_term_probs(model, [task[2], task[0], task[1]],
                                     simulation=sim_flag)
        clr_term1 = draw_term(term_probs1)
        clr_term2 = draw_term(term_probs2)
        return ('DISTINGUISH(T,A) AND COMPARE_REF(A,[and],B,C) AND IDENTIFY(B,' +
                clr_term1 + ') AND IDENTIFY(C,' + clr_term2 + ')')
    if turn == 'S' and state.intent == 7:
        return gen_IDENTIFY1(model, task, state, turn, vocab, topk=1)
    if turn == 'S' and state.intent == 6:
        sorted_ind = np.argsort(state.distribution)[::-1]
        if 0 in sorted_ind[:2]:
            return gen_IDENTIFY1(model, task, state, turn, vocab, topk=1)
        return gen_IDENTIFY2(model,
                             task,
                             state, (turn, sim_flag),
                             vocab,
                             topk=1)

    if target and len(task) > 1:
        # Move the item chosen by sample_patch to the front, keeping the rest in order.
        ix = sample_patch(state.distribution, state)
        task = [task[ix]] + [item for i, item in enumerate(task) if i != ix]

    start_var = 'T' if target else var
    term_probs = get_term_probs(model, task, simulation=False, turn=turn)
    # At most 20 candidates; fewer if the model scores fewer terms.
    candidates = term_probs.topk(k=min(20, term_probs.size(0)), dim=0).indices
    # Prefer the most probable term not yet used in the dialogue. If every
    # candidate is already used, fall back to the single most probable term.
    for idx in candidates:
        clr_term = str(int(idx))
        if clr_term not in clr_terms:
            break
    else:
        clr_term = str(int(candidates[0]))

    return 'IDENTIFY(' + start_var + ',' + clr_term + ')'
Example #24
0
def compute_loss_policy(pred, args, L, Lambda):
    """Score-function loss for sampled +/-1 node labellings.

    Draws ``args.num_ysampling`` labellings per graph from ``softmax(pred)``,
    scores each with the quadratic form ``y^T L y / 4`` plus a balance
    penalty ``Lambda * (sum y)^2``, and weights these scores by the
    labellings' log-probabilities.

    :param pred: scores of size bs x N x 2 (``softmax`` is the module-level
        helper; its dim choice is not visible here).
    :param args: reads ``num_nodes``, ``edge_density``, ``num_ysampling`` and
        ``problem``. When ``problem == 'max'``, the quadratic term is negated in
        the loss.
    :param L: bs x N x N matrices (an unbatched N x N matrix is also accepted).
    :param Lambda: weight of the balance penalty.
    :return: ``(loss, acc, z, inb)``.
        - ``loss``: the batch-mean weighted objective.
        - ``acc``: the probability-weighted mean quadratic score.
        - ``z``: ``acc`` normalised against ``d / 4``, where
          ``d = int(num_nodes * edge_density)``.
        - ``inb``: the rounded probability-weighted ``|sum y|``.
    """
    L = L.type(dtype)
    if L.dim() == 2:
        L = L.unsqueeze(0)  # accept a single unbatched N x N matrix
    pred_prob = softmax(pred).type(dtype)  # bs x N x 2
    d = int(args.num_nodes * args.edge_density)

    m = Categorical(pred_prob)
    y_sampled = m.sample((args.num_ysampling, )).type(dtype)  # S x bs x N
    # Sum over nodes -> log-probability of each whole labelling; S x bs -> bs x S.
    log_prob_sum = m.log_prob(y_sampled).type(dtype).sum(dim=-1).t()
    y_label = (y_sampled * 2 - 1).permute(1, 2, 0)  # bs x N x S, entries in {-1, +1}

    # diag(Y^T L Y) per sample without materialising the S x S product.
    c = 0.25 * (torch.bmm(L, y_label) * y_label).sum(dim=1)  # bs x S
    penalty = Lambda * y_label.sum(dim=1).pow(2)  # bs x S
    cost = -c if args.problem == 'max' else c
    c_plus_penalty = cost + penalty

    loss = (c_plus_penalty * log_prob_sum).sum(dim=-1).mean()
    # Normalised sample weights; softmax avoids exp() underflowing to 0/0 = NaN.
    w = torch.softmax(log_prob_sum, dim=-1)  # bs x S
    acc = (c * w).sum(dim=-1).mean()
    z = (acc / args.num_nodes - d / 4) / np.sqrt(d / 4)
    inb = (y_label.sum(dim=1).abs() * w).sum(dim=-1).mean()
    inb = torch.round(inb)
    return loss, acc, z, inb
Example #25
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1.

    See also: :class:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
        validate_args (bool, optional): whether to validate parameters and samples
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # The event is the last parameter dimension (one slot per category).
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Create a new tensor via the wrapped Categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def variance(self):
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = torch.zeros(self._extended_shape(sample_shape),
                              dtype=probs.dtype,
                              device=probs.device)
        indices = self._categorical.sample(sample_shape)
        # Categorical samples lack the event dim; add it so scatter_ can index it.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of the category at the argmax of each one-hot ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Return all one-hot vectors of the support, stacked along dim 0.

        Shape is ``(n,) + batch_shape + (n,)`` when ``expand`` is true (the
        default, matching the previous behaviour). Otherwise it is
        ``(n,) + (1,) * len(batch_shape) + (n,)``.
        """
        n = self.event_shape[0]
        probs = self._categorical.probs
        values = torch.eye(n, dtype=probs.dtype, device=probs.device)
        values = values.view((n, ) + (1, ) * len(self.batch_shape) + (n, ))
        if expand:
            values = values.expand((n, ) + self.batch_shape + (n, ))
        return values
Example #26
0
    def forward(self, hidden_state=None):
        """
        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
        discrete sampling.
        Hidden state here represents the encoded image/metadata - initializes the RNN from it.
        """

        hidden_state = self.input_module(hidden_state)
        state, batch_size = self._init_state(hidden_state, type(self.rnn))

        # Init output
        if not self.vqvae and self.discrete_communication and self.rl:
            output = [
                torch.zeros(
                    (batch_size, self.vocab_size),
                    dtype=torch.float32,
                    device=self.device,
                )
            ]
            output[0][:, self.sos_id] = 1.0
        else:
            # In vqvae case, there is no sos symbol, since all words come from the unordered embedding table.
            # It is not possible to index code words by sos or eos symbols, since the number of codewords
            # is not necessarily the vocab size!
            output = [
                torch.zeros(
                    (batch_size, self.vocab_size),
                    dtype=torch.float32,
                    device=self.device,
                )
            ]

        # Keep track of sequence lengths
        initial_length = self.output_len + 1  # add the sos token
        seq_lengths = (
            torch.ones([batch_size], dtype=torch.int64, device=self.device)
            * initial_length
        )  # [initial_length, initial_length, ..., initial_length]. This gets reduced whenever it ends somewhere.

        embeds = []  # keep track of the embedded sequence
        sentence_probability = torch.zeros(
            (batch_size, self.vocab_size), device=self.device
        )
        losses_2_3 = torch.empty(self.output_len, device=self.device)
        entropy = torch.empty((batch_size, self.output_len), device=self.device)
        message_logits = torch.empty((batch_size, self.output_len), device=self.device)

        if self.vqvae:
            distance_computer = EmbeddingtableDistances(self.e)

        for i in range(self.output_len):

            emb = torch.matmul(output[-1], self.embedding)

            embeds.append(emb)

            state = self.rnn(emb, state)

            if type(self.rnn) is nn.LSTMCell:
                h, _ = state
            else:
                h = state

            indices = [None] * batch_size

            if not self.rl:
                if not self.vqvae:
                    # That's the original baseline setting
                    p = F.softmax(self.linear_out(h), dim=1)
                    token, sentence_probability = self.calculate_token_gumbel_softmax(
                        p, self.tau, sentence_probability, batch_size
                    )
                else:
                    pre_quant = self.linear_out(h)

                    if not self.discrete_communication:
                        token = self.vq.apply(pre_quant, self.e, indices)
                    else:
                        distances = distance_computer(pre_quant)
                        softmin = F.softmax(-distances, dim=1)
                        if not self.gumbel_softmax:
                            token = self.hard_max.apply(
                                softmin, indices, self.discrete_latent_number
                            )  # This also updates the indices
                        else:
                            _, indices[:] = torch.max(softmin, dim=1)
                            token, _ = self.calculate_token_gumbel_softmax(
                                softmin, self.tau, 0, batch_size
                            )

            else:
                if not self.vqvae:
                    all_logits = F.log_softmax(self.linear_out(h) / self.tau, dim=1)
                else:
                    pre_quant = self.linear_out(h)
                    distances = distance_computer(pre_quant)
                    all_logits = F.log_softmax(-distances / self.tau, dim=1)
                    _, indices[:] = torch.max(all_logits, dim=1)

                distr = Categorical(logits=all_logits)
                entropy[:, i] = distr.entropy()

                if self.training:
                    token_index = distr.sample()
                    token = to_one_hot(token_index, n_dims=self.vocab_size)
                else:
                    token_index = all_logits.argmax(dim=1)
                    token = to_one_hot(token_index, n_dims=self.vocab_size)
                message_logits[:, i] = distr.log_prob(token_index)

            if not (self.vqvae and not self.discrete_communication and not self.rl):
                # Whenever we have a meaningful eos symbol, we prune the messages in the end
                self._calculate_seq_len(
                    seq_lengths, token, initial_length, seq_pos=i + 1
                )

            if self.vqvae:
                loss_2 = torch.mean(
                    torch.norm(pre_quant.detach() - self.e[indices], dim=1) ** 2
                )
                loss_3 = torch.mean(
                    torch.norm(pre_quant - self.e[indices].detach(), dim=1) ** 2
                )
                loss_2_3 = (
                    loss_2 + self.beta * loss_3
                )  # This corresponds to the second and third loss term in VQ-VAE
                losses_2_3[i] = loss_2_3

            token = token.to(self.device)
            output.append(token)

        messages = torch.stack(output, dim=1)
        loss_2_3_out = torch.mean(losses_2_3)

        return (
            messages,
            seq_lengths,
            entropy,
            torch.stack(embeds, dim=1),
            sentence_probability,
            loss_2_3_out,
            message_logits,
        )
Example #27
0
    neglogprobs = torch.zeros((args.episode_length,), device=device)
    entropys = torch.zeros((args.episode_length,), device=device)
    
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.episode_length):
        global_step += 1
        obs[step] = next_obs.copy()
        
        # ALGO LOGIC: put action logic here
        logits, std = pg.forward([obs[step]])
        values[step] = vf.forward([obs[step]])

        # ALGO LOGIC: `env.action_space` specific logic
        if isinstance(env.action_space, Discrete):
            probs = Categorical(logits=logits)
            action = probs.sample()
            actions[step], neglogprobs[step], entropys[step] = action.tolist()[0], -probs.log_prob(action), probs.entropy()

        elif isinstance(env.action_space, Box):
            probs = Normal(logits, std)
            action = probs.sample()
            clipped_action = torch.clamp(action, torch.min(torch.Tensor(env.action_space.low)), torch.min(torch.Tensor(env.action_space.high)))
            actions[step], neglogprobs[step], entropys[step] = clipped_action.tolist()[0], -probs.log_prob(action).sum(), probs.entropy().sum()
    
        elif isinstance(env.action_space, MultiDiscrete):
            logits_categories = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
            action = []
            probs_categories = []
            probs_entropies = torch.zeros((logits.shape[0]))
            neglogprob = torch.zeros((logits.shape[0]))
            for i in range(len(logits_categories)):
Example #28
0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.distributions.binomial import Binomial
from torch.distributions.categorical import Categorical

# NOTE(review): unfinished draft. As written this is a SyntaxError: the `else`
# below has no colon and no body. `u` is also an undefined name, and `x` is
# never used. Presumably meant to draw 100000 samples from a fair two-way
# categorical and branch on the outcome -- TODO complete or remove.
S = Categorical(torch.tensor([0.5, 0.5]))
for i in range(100000):
    s = S.sample()
    x = []
    if s==1:
        u
    else
    




# REINFORCE (score-function) gradient estimator.
# Collects N single-sample estimates f(samp) * d log p(samp) / d logits and
# prints their per-component mean and standard deviation.
# Relies on `logits` (must require grad; presumably C entries), `f`, `N` and `C`
# being defined earlier in the script.
print('REINFORCE')

# The distribution is the same for every draw, so build it (and its autograd
# graph) once; retain_graph=True keeps that graph alive across grad calls.
dist = Categorical(logits=logits)
grads = []
for i in range(N):
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    gradlogprob = torch.autograd.grad(outputs=logprob,
                                      inputs=logits,
                                      retain_graph=True)[0]
    grads.append(reward * gradlogprob)

print()
grads = torch.stack(grads).view(N, C)
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)

print('REINFORCE')
print('mean:', grad_mean_reinforce)
print('std:', grad_std_reinforce)
print()
Example #30
0
    L2_losses = []
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            # print (logits.shape)
            # print (torch.softmax(logits, dim=0))
            # fsfd
            cat = Categorical(probs=torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            # print (logprob_cluster)
            pxz = logprob_undercomponent(
                x,
                component=cluster,
                needsoftmax_mixtureweight=needsoftmax_mixtureweight,
                cuda=False)
            f = pxz - logprob_cluster
            # print (f)
            # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        loss = loss / batch_size
        net_loss = net_loss / batch_size
print('True')
print('[-.5478, .1122, .4422]')
print('dif:', np.abs(grad_mean_simplax.numpy() - true))
print()

# REINFORCE (score-function) gradient estimator.
# Collects N single-sample estimates f(samp) * d log p(samp) / d logits.
# Relies on `logits` (must require grad), `f`, `N` and `C` from earlier in the script.
print('REINFORCE')

# Loop-invariant distribution: build it once; retain_graph=True keeps its
# autograd graph alive for every grad call below.
dist = Categorical(logits=logits)
grads = []
for i in range(N):
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    gradlogprob = torch.autograd.grad(outputs=logprob,
                                      inputs=logits,
                                      retain_graph=True)[0]
    grads.append(reward * gradlogprob)

print()
grads = torch.stack(grads).view(N, C)
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)

print('REINFORCE')
print('mean:', grad_mean_reinforce)
Example #32
0
 def throw(self):
     """Draw one outcome from a categorical distribution with probabilities ``self.dist``."""
     return Categorical(self.dist).sample()
Example #33
0
                    tgt_mask=None,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=src_key_padding_mask,
                    memory=encoder_output)
                '''
            Now make prediction on next word
            we're only interested in the values of the embeddings of the current 
            step we're on which we can get by output[decoder_step,:,:]
            which has dim (batch_size,vocab_size)
            '''
                word_probs = F.softmax(output[decoding_step, :, :], dim=1)

                #IMPLEMENTATION FOR REINFORCE VERY EASY https://pytorch.org/docs/stable/distributions.html
                #MAKE SURE THAT THIS BACKPROPS THROUGH EVERY ACTION WE CHOOSE
                m = Categorical(word_probs)
                chosen_word = m.sample(
                )  #this generates along dim 1 of word_probs which is what we want since dim 1 contains distributions
                dec_input = torch.cat(
                    [dec_input, chosen_word.view(1, -1)], dim=0
                )  #append chosen_word as row to end of decoder input for next iteration

        policy_net.train()
        reward = get_bleu_scores(trg_tensor,
                                 dec_input,
                                 dataset_dict['TGT'],
                                 BLEU1=False).to(main_params.device)

        #now that we have reward, go back through the path and get gradients
        eos_reached = torch.zeros(main_params.batch_size,
                                  dtype=torch.uint8).type(torch.BoolTensor).to(
                                      main_params.device)
        encoder_output_detached = encoder_output.clone().detach()
Example #34
0
 def get_action(self, x, action=None):
     """Return (action, log_prob(action), entropy) from the actor's logits on ``self.forward(x)``; samples when no action is given."""
     hidden = self.forward(x)
     dist = Categorical(logits=self.actor(hidden))
     picked = dist.sample() if action is None else action
     log_p = dist.log_prob(picked)
     return picked, log_p, dist.entropy()
Example #35
0
def get_action(model, device, state):
    """Sample one action for ``state`` from the policy output of ``model``.

    ``model`` must return a 3-tuple whose first item holds the action
    probabilities; the two value outputs are ignored here. The sampled
    action is returned as a squeezed NumPy array on the host.
    """
    obs = torch.Tensor(state).to(device)
    probs, _value_ext, _value_int = model(obs)
    sampled = Categorical(probs).sample()
    return sampled.data.cpu().numpy().squeeze()
Example #36
0
    def run_episode(self):
        '''
        Collect experiences from one self-play episode of 'hungry_geese'.

        All four geese in the environment are driven by this agent
        (``self.predict``). For every goose still on the board at every step
        it records:

        * the policy entropy,
        * the log-probability of the sampled move,
        * the value estimate,
        * a shaped reward (1 if the move eats food, else 0).

        When the episode ends, the last reward of each goose gets a survival
        bonus of ``turns / 200``. Discounted returns are then computed with
        ``self.GAMMA``. The flattened returns, values, entropies and
        log-probs are passed to ``self.memory.add``; nothing is returned.
        '''

        # One history list per goose (the env is reset with 4 agents below).
        observations = [[] for i in range(4)]
        # Seeded with None so actions[i][-1] (the previous move, used for the
        # action mask) is defined on the very first step.
        actions = [[None] for i in range(4)]
        rewards = [[] for i in range(4)]

        entropy = [[] for i in range(4)]
        log_pbs = [[] for i in range(4)]
        values = [[] for i in range(4)]

        env = make('hungry_geese', debug=False)
        frame = env.reset(num_agents=4)

        # Keep stepping while at least one goose is still ACTIVE.
        while any(entry['status'] == 'ACTIVE' for entry in frame):
            # NOTE(review): `step` is read but not used below.
            step = frame[0]['observation']['step']
            food = frame[0]['observation']['food']
            geese = frame[0]['observation']['geese']

            for i in range(4):
                agent = geese[i]
                # Skip geese with an empty body list (presumably eliminated).
                if not agent:
                    continue

                # Shallow-copy geese/food into this goose's observation history.
                obs = {'index': i, 'geese': geese[:], 'food': food[:]}
                observations[i].append(obs)

                # The model is given this goose's whole observation history.
                logits, value = self.predict(observations[i])

                # Mask invalid actions to boost learning speed.
                action_mask = get_action_mask(i, geese, actions[i][-1])
                logits = torch.where(action_mask, logits, to_tensor(-1e5))

                policy = Categorical(logits=logits)
                move = policy.sample()

                # Naive reward mechanism to encourage eating.
                # Feel free to change it to improve model performance.
                if is_eating(agent, move, food):
                    rewards[i].append(1)
                else:
                    rewards[i].append(0)

                actions[i].append(action_of(move))
                entropy[i].append(policy.entropy())
                log_pbs[i].append(policy.log_prob(move))
                values[i].append(value)

            # Next frame
            # NOTE(review): tails_of presumably picks each goose's latest action.
            frame = env.step(tails_of(actions))

        # Episode is over.
        # Assign final reward for each agent.
        for i in range(4):
            score = frame[i]['reward']
            # NOTE(review): assumes the env encodes reward as
            # turns * 100 + goose length — TODO confirm. `length` is unused.
            turns, length = divmod(score, 100)

            # Encourage surviving.
            rewards[i][-1] += turns / 200

            # Ensure shapes are consistent.
            assert len(rewards[i]) == len(log_pbs[i]) == len(values[i])

        # Calculate actual returns.
        # Discounted return G_t = r_t + GAMMA * G_{t+1}, filled backwards per goose.
        Q = []
        for i in range(4):
            length = len(rewards[i])
            returns = torch.zeros(length).detach()
            val = 0
            for t in reversed(range(length)):
                val = rewards[i][t] + self.GAMMA * val
                returns[t] = val

            Q.append(returns)

        # Flatten training data.
        # NOTE(review): `flatten` presumably concatenates the per-goose lists.
        Q = torch.hstack(Q).detach()
        V = torch.hstack(flatten(values))
        E = torch.hstack(flatten(entropy))
        log_probs = torch.hstack(flatten(log_pbs))

        # Again ensure shapes of training data are consistent.
        assert Q.shape == V.shape == E.shape == log_probs.shape

        self.memory.add(Q, V, E, log_probs)
Example #37
0
    def make_action(self, state, test=False):
        """Choose an action for ``state`` using the configured exploration method.

        The strategy is picked by the prefix of ``self.args.exploration_method``:

        * ``greedy``: arg-max of the network output.
        * ``epsilon``: epsilon-greedy. ``epsilon_exp`` decays the threshold with
          the global ``episodes_done_num``; otherwise the threshold is 0.1.
        * ``boltzmann``: sample from softmax(output / temperature).
        * ``thompson``: arg-max of a forward pass with dropout enabled (rate 0.3)
          during training, or without dropout during evaluation.

        Args:
            state: Observation. When ``test`` is true it is converted to a
                tensor, permuted with (2, 0, 1) and given a leading batch
                dimension (presumably an HWC frame — TODO confirm).
            test: Evaluation mode. Never explores randomly in the epsilon
                strategy, and returns a Python int instead of a tensor.

        Returns:
            A column tensor of long action indices, or an ``int`` when ``test``.

        Raises:
            ValueError: if the exploration method is not recognised.
        """
        if test:
            state = torch.tensor(
                state, device='cuda' if use_cuda else 'cpu'
            ).permute(2, 0, 1).unsqueeze(0)

        method = self.args.exploration_method

        if method.startswith('greedy'):
            with torch.no_grad():
                scores = torch.softmax(self.online_net(state), 1)
                chosen = scores.max(1)[1].view(-1, 1)
            return chosen.item() if test else chosen

        if method.startswith('epsilon'):
            roll = random.random()

            if method.startswith('epsilon_exp'):
                global episodes_done_num
                eps_start, eps_end, eps_decay = 0.9, 0.1, 200
                eps_threshold = eps_end + (eps_start - eps_end) * np.exp(
                    -1. * episodes_done_num / eps_decay)
            else:
                eps_threshold = .1

            if not (roll > eps_threshold or test):
                # Explore: pick a uniformly random action.
                return torch.tensor([[random.randrange(self.num_actions)]],
                                    device='cuda' if use_cuda else 'cpu',
                                    dtype=torch.long)

            # Exploit (always in evaluation): action with the largest output.
            with torch.no_grad():
                best = self.online_net(state).max(1)[1].view(1, 1)
            return best.item() if test else best

        if method.startswith('boltzmann'):
            with torch.no_grad():
                tempered = self.online_net(state) / self.args.boltzmann_temperature
                probs = torch.softmax(tempered, 1)
            picked = Categorical(probs).sample().view(-1, 1)
            return picked.item() if test else picked

        if method.startswith('thompson'):
            if test:
                rate, use_thompson = 0, False
            else:
                rate, use_thompson = 0.3, True
            with torch.no_grad():
                probs = torch.softmax(
                    self.online_net.forward(state,
                                            dropout_rate=rate,
                                            thompson=use_thompson), 1)
            chosen = probs.max(1)[1].view(-1, 1)
            return chosen.item() if test else chosen

        raise ValueError("Unknown exploration method")
Example #38
0
    def run(self, episodes, steps, train=False, render_once=1e10, saveonce=10):
        """Roll out the policy for several episodes, optionally training it.

        Each episode resets the environment, then for ``steps`` steps:

        * samples an action from the ``Categorical`` built on ``self.model``'s output,
        * sends it to the environment as a one-hot vector,
        * accumulates the reward.

        When ``train`` is true, every (reward, log-prob) pair is stored in
        ``self.trainer``, which is updated and cleared at the end of the episode.

        Args:
            episodes: Number of episodes to run.
            steps: Number of environment steps per episode.
            train: Record transitions, update the trainer and persist results.
            render_once: While training, drawing is enabled only on every
                ``render_once``-th episode (always enabled when not training).
            saveonce: While training, every ``saveonce``-th episode the recorder
                saves and plots its data. The model is also saved when the
                recorder's final reward is at least the best seen so far.
        """
        if train:
            assert self.recorder.log_message is not None, "log_message is necessary during training, Instantiate Runner with log message"

        # Recurrent models keep hidden state that must be reset each episode.
        reset_model = False
        if hasattr(self.model, "type") and self.model.type == "mem":
            print("Recurrent Model")
            reset_model = True
        self.env.display_neural_image = self.visual_activations
        for episode in range(episodes):

            self.env.reset()
            # Always draw in evaluation; while training only every render_once-th episode.
            self.env.enable_draw = not train or episode % render_once == render_once - 1

            if reset_model:
                self.model.reset()

            state = self.env.get_state().reshape(-1)
            bar = tqdm(range(steps),
                       bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
            trewards = 0

            for step in bar:

                state = T.from_numpy(state).float()
                actions = self.model(state)

                # The model output parameterises a categorical policy.
                c = Categorical(actions)
                action = c.sample()
                log_prob = c.log_prob(action)

                # The environment expects a one-hot action vector.
                u = np.zeros(self.nactions)
                u[action] = 1.0
                newstate, reward = self.env.act(u)
                state = newstate.reshape(-1)
                trewards += reward

                if train:
                    self.trainer.store_records(reward, log_prob)

                if self.visual_activations:
                    u = T.cat(self.activations, dim=0).reshape(-1)
                    self.env.neural_image_values = u.detach().numpy()
                    self.activations = []
                    # Refresh the weight display on the first step of every 10th episode.
                    if episode % 10 == 0 and step == 0:
                        self.update_weights()
                        self.env.neural_weights = self.weights
                        self.env.weight_change = True
                    if self.model.hidden_vectors is not None:
                        self.env.hidden_state = self.model.hidden_vectors

                bar.set_description(f"Episode: {episode:4} Rewards : {trewards}")
                if train:
                    self.env.step()
                else:
                    self.env.step(speed=0)

            if train:
                self.trainer.update()
                self.trainer.clear_memory()
                self.recorder.newdata(trewards)
                if episode % saveonce == saveonce - 1:
                    self.recorder.save()
                    self.recorder.plot()
                    # Keep only the best model seen so far.
                    if self.recorder.final_reward >= self.current_max_reward:
                        self.recorder.save_model(self.model)
                        self.current_max_reward = self.recorder.final_reward
        print("******* Run Complete *******")
Example #39
0
    def forward(self):
        '''
        Sample one architecture from the ENAS controller.

        Port of
        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L126

        For each layer the LSTM first samples a branch (operation) id. Then,
        for layers after the first, it samples one binary skip decision per
        earlier layer. The LSTM state is threaded through both calls of every
        layer, as ``prev_c, prev_h = next_c, next_h`` does in the reference.

        Results are stored on ``self`` instead of being returned:

        * ``sample_arc``: dict ``str(layer_id) -> [branch_id]`` plus the skip
          tensor for layers > 0.
        * ``sample_entropy`` / ``sample_log_prob``: summed over all decisions.
        * ``skip_count``: total number of skips sampled (0 for a single layer).
        * ``skip_penaltys``: mean over layers of
          ``sum(sigmoid(logit) * log(sigmoid(logit) / skip_targets))``
          (0 for a single layer).

        Raises:
            NotImplementedError: if ``self.search_whole_channels`` is false.
        '''
        h0 = None  # setting h0 to None will initialize LSTM state with 0s

        anchors = []
        anchors_w_1 = []

        arc_seq = {}
        entropys = []
        log_probs = []
        skip_count = []
        skip_penaltys = []

        inputs = self.g_emb.weight
        # Target (no-skip, skip) probabilities for the skip penalty, created on
        # the controller's own device instead of a hard-coded .cuda().
        skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target],
                                    device=inputs.device)

        for layer_id in range(self.num_layers):
            if not self.search_whole_channels:
                # https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L171
                raise NotImplementedError(
                    "Not implemented error: search_whole_channels = False")

            # Branch (operation) decision.
            inputs = inputs.unsqueeze(0)
            output, h0 = self.w_lstm(inputs, h0)
            output = output.squeeze(0)

            logit = self.w_soft(output)
            if self.temperature is not None:
                logit = logit / self.temperature
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)

            branch_id_dist = Categorical(logits=logit)
            branch_id = branch_id_dist.sample()

            arc_seq[str(layer_id)] = [branch_id]

            log_probs.append(branch_id_dist.log_prob(branch_id).view(-1))
            entropys.append(branch_id_dist.entropy().view(-1))

            inputs = self.w_emb(branch_id).unsqueeze(0)

            # Feed the chosen branch back in. The returned state must be kept;
            # previously it was dropped, so the next layer saw a stale state.
            output, h0 = self.w_lstm(inputs, h0)
            output = output.squeeze(0)

            if layer_id > 0:
                # Skip decisions: attend over all previous layers' anchors.
                query = torch.cat(anchors_w_1, dim=0)
                query = torch.tanh(query + self.w_attn_2(output))
                query = self.v_attn(query)
                logit = torch.cat([-query, query], dim=1)
                if self.temperature is not None:
                    logit = logit / self.temperature
                if self.tanh_constant is not None:
                    logit = self.tanh_constant * torch.tanh(logit)

                skip_dist = Categorical(logits=logit)
                skip = skip_dist.sample()
                skip = skip.view(layer_id)

                arc_seq[str(layer_id)].append(skip)

                skip_prob = torch.sigmoid(logit)
                kl = skip_prob * torch.log(skip_prob / skip_targets)
                skip_penaltys.append(torch.sum(kl))

                log_probs.append(torch.sum(skip_dist.log_prob(skip)).view(-1))
                entropys.append(torch.sum(skip_dist.entropy()).view(-1))

                # Calculate average hidden state of all nodes that got skips
                # and use it as input for next step
                skip = skip.type(torch.float)
                skip = skip.view(1, layer_id)
                skip_count.append(torch.sum(skip))
                inputs = torch.matmul(skip, torch.cat(anchors, dim=0))
                inputs = inputs / (1.0 + torch.sum(skip))

            else:
                inputs = self.g_emb.weight

            anchors.append(output)
            anchors_w_1.append(self.w_attn_1(output))

        self.sample_arc = arc_seq

        self.sample_entropy = torch.sum(torch.cat(entropys))
        self.sample_log_prob = torch.sum(torch.cat(log_probs))

        if skip_count:
            self.skip_count = torch.sum(torch.stack(skip_count))
            self.skip_penaltys = torch.mean(torch.stack(skip_penaltys))
        else:
            # A single-layer network has no skip decisions; torch.stack([]) raises.
            self.skip_count = skip_targets.new_zeros(())
            self.skip_penaltys = skip_targets.new_zeros(())
Example #40
0
def train(model,
          iterator,
          optimizer,
          criterion,
          tag_pad_idx,
          tag_unk_idx,
          inside_word_idx,
          UD_TAGS=None,
          noise=0):
    """Train ``model`` for one epoch, optionally injecting label noise.

    Args:
        model: Tagger; ``model(text)`` with ``text`` of shape
            [sent len, batch size] returns [sent len, batch size, output dim].
        iterator: Sized iterable of batches exposing ``.text`` and ``.udtags``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Loss applied to the flattened predictions and tags.
        tag_pad_idx, tag_unk_idx, inside_word_idx: Passed to
            ``categorical_accuracy``. ``tag_pad_idx`` also keeps padding
            positions from being noised.
        UD_TAGS: Tag field whose ``vocab`` (``itos``/``freqs``) defines the noise
            distribution. Required when ``noise > 0``.
        noise: Probability in [0, 1] that each non-pad gold tag is replaced by a
            tag sampled from the empirical tag-frequency distribution.

    Returns:
        ``(mean batch loss, accuracy)`` over the epoch.

    Raises:
        ValueError: if ``noise > 1``, or ``UD_TAGS`` is None while ``noise > 0``.
    """
    epoch_loss = 0
    epoch_correct = 0
    epoch_n_label = 0

    model.train()

    if noise > 0:
        # Validate once up front rather than asserting inside the batch loop.
        if noise > 1:
            raise ValueError(f"noise must be in [0, 1], got {noise}")
        if UD_TAGS is None:
            raise ValueError("UD_TAGS is required when noise > 0")
        vocab = UD_TAGS.vocab
        # Tag frequencies in index order; tags never counted get 0.
        counts = [vocab.freqs.get(vocab.itos[k], 0) for k in range(len(vocab))]
        c = Categorical(
            torch.tensor(counts).cuda() /
            float(sum(vocab.freqs.values())))
        b = Bernoulli(probs=torch.tensor([noise]).cuda())

    for batch in tqdm(iterator):

        text = batch.text
        tags = batch.udtags

        optimizer.zero_grad()

        # text = [sent len, batch size]

        predictions = model(text)

        # predictions = [sent len, batch size, output dim]
        # tags = [sent len, batch size]

        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)

        if noise > 0:
            non_pad_elements = tags != tag_pad_idx
            # b has batch shape (1,), so samples come back as [N, 1].
            prob_mask = (b.sample(tags.shape) == 1).squeeze(1).to(
                tags.device) & non_pad_elements
            noisy_preds = c.sample((torch.sum(prob_mask).item(), ))
            # Copy first: `tags` is a view of batch.udtags, and writing noise in
            # place would corrupt the batch if the iterator reuses it.
            tags = tags.clone()
            tags[prob_mask] = noisy_preds.to(tags.device)

        # predictions = [sent len * batch size, output dim]
        # tags = [sent len * batch size]

        loss = criterion(predictions, tags)

        correct, n_labels = categorical_accuracy(predictions, tags,
                                                 tag_pad_idx, tag_unk_idx,
                                                 inside_word_idx)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_correct += correct.item()
        epoch_n_label += n_labels

    return epoch_loss / len(iterator), epoch_correct / epoch_n_label
Example #41
0
 def get_action(self, state):
     """Agent interface: sample an action for ``state``.

     The model output (detached) is used as logits of a categorical
     distribution. The first element of the sampled batch is returned as a
     plain Python value.
     """
     inputs = torch.FloatTensor(state).to(self.device)
     dist = Categorical(logits=self.model(inputs).detach())
     as_list = dist.sample().cpu().data.numpy().tolist()
     return as_list[0]
Example #42
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # All parameter handling is delegated to an inner Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Create a new tensor like the inner Categorical's parameter."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        """Normalized event probabilities."""
        return self._categorical.probs

    @property
    def logits(self):
        """Event log-probabilities (unnormalized logits)."""
        return self._categorical.logits

    @property
    def mean(self):
        """Mean of a one-hot vector: the probabilities themselves."""
        return self._categorical.probs

    @property
    def variance(self):
        """Per-component Bernoulli variance ``p * (1 - p)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of one-hot ``value`` (scored via its arg-max index)."""
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Return every one-hot vector in the support.

        Args:
            expand (bool): If True (default), expand the result to
                ``(n,) + batch_shape + (n,)``. Otherwise return the compact view
                ``(n,) + (1,) * len(batch_shape) + (n,)``.
        """
        n = self.event_shape[0]
        param = self._categorical._param
        # Build the identity directly; the legacy `new((n, n))` treats a plain
        # tuple as data and only worked because eye(out=) resized it.
        values = torch.eye(n, dtype=param.dtype, device=param.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
Example #43
0
 def decode_teacher_forcing(self, y, vl, vg):
     """Decode words step by step, mixing teacher forcing with sampling.

     At each step ``self.teacher_forcing.get_tfr()`` is compared with
     ``random.random()`` (always teacher-forced when ``teacher_forcing`` is
     None) to choose between feeding the gold word ``y[:, 0, j]`` and a word
     sampled from the model's prediction.

     Args:
         y: Target word indices, indexed as ``y[:, 0, j]``; presumably
             (batch, sentences, words) with 0 as padding — TODO confirm.
         vl: Visual features placed in the decoder state and pooled by the
             saliency weights at the end.
         vg: Visual features (dropout applied) that initialise ``h`` and ``c``.

     Returns:
         ``(words, dis)``:

         * ``words``: per-step word scores stacked on dim 1. Positions with no
           target word get a uniform ``1 / embed_num`` placeholder.
         * ``dis``: the joint head output reshaped to
           ``(batch, DISEASE_NUM, 2)``.
     """
     # Position j is decoded only if some target word there is > 0.
     mask = ((y > 0).sum(dim=(0, 1)) > 0)

     def use_teacher_word():
         # Short-circuits, so random.random() is drawn only when a ratio exists.
         return self.teacher_forcing is None or self.teacher_forcing.get_tfr(
         ) >= random.random()

     # Initialize word states
     w = vg.new_full((vg.shape[0], 1),
                     PretrainedEmbeddings.INDEX_START,
                     dtype=torch.long)
     vg = self.dropout(vg)
     h = self.init_h(vg)
     c = self.init_c(vg)
     states = (vl, h, c, None)
     # Process words
     words, hs, alphas = [], [], []
     for j in range(self.max_word):
         if mask[j]:
             p, states = self.proc_word(w, states)
             words.append(p)
             _, h, _, alpha = states
             hs.append(h)
             if alpha is not None:
                 alphas.append(alpha)
             if use_teacher_word():
                 w = y[:, 0, j]
             else:
                 # Free-running: sample the next word from the prediction.
                 p = softmax(p, dim=-1)
                 cat = Categorical(probs=p)
                 w = cat.sample()
         else:
             p = vg.new_ones(vg.shape[0], self.embed_num) / self.embed_num
             words.append(p)
             if use_teacher_word():
                 w = y[:, 0, j]
             else:
                 # Word indices must stay long like every other branch;
                 # new_zeros would otherwise inherit vg's float dtype.
                 w = vg.new_zeros(vg.shape[0], dtype=torch.long)
     words = torch.stack(words, dim=1)
     # Attention Encoded Text Embedding
     # Concat hidden states NxTxd
     hs = torch.stack(hs, dim=1)
     # Projection to NxTxr rows
     g = softmax(self.aete2(tanh(self.aete1(hs))), dim=2)
     # Weighted sum over T to Nxrxd
     m = g.permute(0, 2, 1).matmul(hs)
     # AETE embedding Nxd
     x1 = m.max(dim=1)[0]
     # Saliency Weighted Global Average Pooling
     # NOTE(review): assumes proc_word always returns attention maps;
     # torch.stack raises on an empty `alphas` list.
     alphas = torch.stack(alphas, dim=1)
     if self.multi_image > 1:
         # Spatial attention maps NxMxT
         aws = (alphas *
                g.max(dim=2)[0].unsqueeze(dim=-1).unsqueeze(dim=-1)).sum(
                    dim=1)
         # SWGAP Nx1024
         x2 = (aws.unsqueeze(dim=-1) * vl).sum(dim=2)
         x2 = x2.max(dim=1)[0]
     else:
         # Spatial attention maps NxT
         aws = (alphas * g.max(dim=2)[0].unsqueeze(dim=-1)).sum(dim=1)
         # SWGAP Nx1024
         x2 = (aws.unsqueeze(dim=-1) * vl).sum(dim=1)
     # Joint Nx14
     dis = self.joint(torch.cat([x1, x2], dim=1))
     dis = dis.view((dis.shape[0], self.DISEASE_NUM, 2))
     return words, dis