torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs = torch.stack(zs)
    z_tildes = torch.stack(z_tildes)

    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob


# Train a learned surrogate (control variate) for the RELAX estimator.
# NN3, C, B, f, logits and sample_relax_given_class are defined in earlier
# cells not visible in this excerpt — assumes B == 1 given the .view(B, 1)
# on a single-element tensor below; TODO confirm.
surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2)

train_ = 1   # flag: 1 trains the surrogate, 0 skips training
n_steps = 300
if train_:
    optim = torch.optim.Adam(surrogate.parameters(),
                             lr=1e-4,
                             weight_decay=1e-7)
    # Train surrogate: iterate over every class c so each discrete outcome
    # is visited once per step.
    for i in range(n_steps):
        for c in range(C):
            # Fixed class index reshaped to the (B, 1) batch layout.
            samp = torch.tensor([c]).view(B, 1)
            # z ~ p(z), z_tilde ~ p(z | b = samp), plus log p(b = samp).
            z, z_tilde, logprob = sample_relax_given_class(logits, samp)

            # Surrogate evaluated at the conditional relaxation.
            cz_tilde = surrogate.net(z_tilde)
            reward = f(samp)
            # NOTE(review): the rest of this training step (loss/backward)
            # is truncated in this excerpt.
# Summarize the REINFORCE gradient estimates: mean and std over the sample
# dimension. `grads` is built in an earlier cell not visible here —
# presumably a stack of per-sample gradient estimates; TODO confirm.
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)

print('REINFORCE')
print('mean:', grad_mean_reinforce)
print('std:', grad_std_reinforce)
print()
# print ('True')
# print ('[-.5478, .1122, .4422]')
# print ('dif:', np.abs(grad_mean_reinforce.numpy() -  true))
# print ()

# REINFORCE with a learned critic baseline (next section).
print('REINFORCE with critic')

# Fit a critic (baseline) that predicts the reward from the logits, by
# minimizing mean squared error. NN3, C, f, logits and Categorical come
# from earlier cells not visible in this excerpt.
critic = NN3(input_size=C, output_size=1, n_residual_blocks=1)
optim = torch.optim.Adam(critic.parameters(), lr=1e-4, weight_decay=1e-7)
for i in range(500):

    # Sample a discrete action b ~ Categorical(logits) and score it.
    dist = Categorical(logits=logits)
    b = dist.sample()
    reward = f(b)

    # Critic prediction from the (input) logits; regression target is reward.
    c = critic.net(logits)
    loss = torch.mean((reward - c)**2)

    optim.zero_grad()
    # NOTE(review): retain_graph=True suggests part of the graph is reused
    # across iterations (logits is shared) — confirm it is actually needed.
    loss.backward(retain_graph=True)
    optim.step()

    # Periodic logging — the body of this branch is truncated in this excerpt.
    if i % 100 == 0:
Example #3
0
                    shell=True)

    train_ = 1  # if 0, it loads data instead of training

    if train_:

        # Single-parameter Bernoulli problem: logit of p = 0.5, optimized
        # directly via gradient descent.
        logits = torch.log(torch.tensor(np.array([.5]))).float()
        logits.requires_grad_(True)

        print()
        print('RELAX')
        print('Value:', val)
        print()

        # net = NN()
        # Surrogate control variate for RELAX; separate optimizers for the
        # distribution parameter (logits) and the surrogate network.
        surrogate = NN(input_size=1, output_size=1, n_residual_blocks=2)
        optim = torch.optim.Adam([logits], lr=.004)
        optim_NN = torch.optim.Adam(surrogate.parameters(), lr=.00005)

        # Training history buffers (steps, smoothed losses, samples).
        steps = []
        losses10 = []
        zs = []
        for step in range(total_steps + 1):

            # batch=[]

            optim_NN.zero_grad()

            # Accumulate loss over 10 Monte Carlo samples per step.
            losses = 0
            for ii in range(10):
                # Sample p(z) — loop body truncated in this excerpt.
    # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
    # dist_bernoulli = Bernoulli(bern_param)
    # Represent the Bernoulli as a 2-way Categorical with probs
    # [1 - p, p], so the same RELAX machinery for categoricals applies.
    C = 2
    n_components = C
    B = 1
    # probs = torch.ones(B,C)
    # bern_param is the Bernoulli success probability, defined outside this
    # excerpt; reshape it to the (B, 1) batch layout.
    bern_param = bern_param.view(B, 1)
    aa = 1 - bern_param
    probs = torch.cat([aa, bern_param], dim=1)
    # probs[:,0] = probs[:,0]* bern_param
    # probs[:,1] = probs[:,1] - bern_param

    cat = Categorical(probs=probs)
    # Surrogate control variate over the C-dimensional relaxed sample.
    surrogate = NN3(input_size=n_components,
                    output_size=1,
                    n_residual_blocks=2)  #.cuda()
    def sample_relax(probs):
        #Sample z
        u = torch.rand(B, C)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b)

        #Sample z_tilde
        u_b = torch.rand(B, 1)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B, C)