Beispiel #1
0
class VAERNN(torch.nn.Module):
    """Wrapper pairing a VAE with an RNN for joint training.

    ``forward`` only runs the VAE. ``when_train`` encodes the current
    observations, feeds the latent (concatenated with ``one``) to the RNN,
    encodes the next observations and stores the resulting losses on ``self``.
    """

    def __init__(self):
        # Zero-argument super() binds to this class object itself. The explicit
        # super(VAERNN, self) form looks up the global name VAERNN at call
        # time, which breaks whenever that name is later rebound.
        super().__init__()

        # Not read inside this class; presumably consumed by callers.
        self.z_size = 32
        self.kl_tolerance = 0.5

        self.vae = VAE()
        self.rnn = RNN()

        self.vae.train()
        self.rnn.train()
        self.init_()

        # When True, when_train moves both sub-models to the GPU.
        self.is_cuda = False

    def load(self):
        """Load VAE and RNN weights from ``vae_model_path`` / ``rnn_model_path``.

        The identity ``map_location`` loads every tensor onto CPU, so
        checkpoints saved from GPU tensors load without CUDA.
        """
        self.vae.load_state_dict(torch.load(vae_model_path, map_location=lambda storage, loc: storage))
        self.rnn.load_state_dict(torch.load(rnn_model_path, map_location=lambda storage, loc: storage))

    def init_(self):
        """Reset ``self.h`` to the value returned by ``self.rnn.init_()``."""
        self.h = self.rnn.init_()

    def forward(self, inputs):
        """Return the VAE output for ``inputs``."""
        z = self.vae(inputs)
        return z

    def when_train(self, inputs, one, outputs):
        """Run the forward passes of one training step and record the losses.

        Args:
            inputs: observations encoded by the VAE to give the RNN input latent.
            one: tensor concatenated to that latent along dim 2 before the RNN
                (presumably a one-hot action -- confirm with the caller).
            outputs: observations (presumably the next frames) whose VAE
                encoding is the target for the RNN losses.

        Sets ``next_kl_loss`` and ``next_r_loss`` (read from the VAE right
        after encoding ``outputs``), plus ``pred_loss`` and ``mdn_loss`` from
        the RNN.
        """
        if self.is_cuda:
            self.vae.is_cuda = True
            self.vae.cuda()
            self.rnn.is_cuda = True
            self.rnn.cuda()

        z = self.vae(inputs)
        # Leading dim added so the latent can be concatenated with ``one`` on dim 2.
        z = z.unsqueeze(0)

        z_a = torch.cat((z, one), dim=2)
        # Return value unused; the loss functions below are called on self.rnn
        # after this forward pass.
        self.rnn(z_a)

        z_next = self.vae(outputs)
        # Read immediately after encoding ``outputs`` so they refer to that call.
        self.next_kl_loss = self.vae.kl_loss
        self.next_r_loss = self.vae.r_loss

        z_next = z_next.unsqueeze(0)

        self.pred_loss = self.rnn.prediction_loss_f(z_next)
        self.mdn_loss = self.rnn.mdn_loss_f(z_next)
class VAERNN(torch.nn.Module):
    """Wrapper pairing a VAE with an RNN.

    In this version ``when_train`` only encodes ``inputs`` with the VAE and
    records the VAE's KL / reconstruction losses.
    """

    def __init__(self):
        # Zero-argument super() binds to this class object itself. The explicit
        # super(VAERNN, self) form looks up the global name VAERNN at call
        # time, which breaks whenever that name is later rebound.
        super().__init__()

        # Not read inside this class; presumably consumed by callers.
        self.z_size = 32
        self.kl_tolerance = 0.5

        self.vae = VAE()
        self.rnn = RNN()

        self.vae.train()
        self.rnn.train()
        self.init_()

        # When True, when_train moves both sub-models to the GPU.
        self.is_cuda = False

    def load(self):
        """Load VAE and RNN weights from ``vae_model_path`` / ``rnn_model_path``.

        The identity ``map_location`` loads every tensor onto CPU, so
        checkpoints saved from GPU tensors load without CUDA.
        """
        self.vae.load_state_dict(
            torch.load(vae_model_path,
                       map_location=lambda storage, loc: storage))
        self.rnn.load_state_dict(
            torch.load(rnn_model_path,
                       map_location=lambda storage, loc: storage))

    def init_(self):
        """Reset ``self.h`` to the value returned by ``self.rnn.init_()``."""
        self.h = self.rnn.init_()

    def forward(self, inputs):
        """Return the VAE output for ``inputs``."""
        z = self.vae(inputs)
        return z

    def when_train(self, inputs, one, outputs):
        """Encode ``inputs`` with the VAE and record its KL / reconstruction losses.

        ``one`` and ``outputs`` are not used by this version of the method.
        Sets ``next_kl_loss`` and ``next_r_loss`` from the VAE's ``kl_loss``
        and ``r_loss`` attributes after the forward pass.
        """
        if self.is_cuda:
            self.vae.is_cuda = True
            self.vae.cuda()
            self.rnn.is_cuda = True
            self.rnn.cuda()

        # The encoding itself is unused; only the loss attributes read below
        # matter (presumably populated by this forward pass).
        self.vae(inputs)
        # NOTE(review): despite the ``next_`` prefix these losses come from
        # ``inputs``, not ``outputs`` -- confirm this is intended.
        self.next_kl_loss = self.vae.kl_loss
        self.next_r_loss = self.vae.r_loss
class VAERNN(torch.nn.Module):
    """VAE + RNN wrapper whose constructor loads pretrained weights.

    ``forward`` encodes observations with the VAE and returns the result
    together with the stored RNN state ``self.h``. ``when_train`` runs the
    VAE/RNN forward passes of a training step and stores the losses on
    ``self``.
    """

    def __init__(self):
        # Zero-argument super() binds to this class object itself. The explicit
        # super(VAERNN, self) form looks up the global name VAERNN at call
        # time, which breaks whenever that name is later rebound.
        super().__init__()

        # Not read inside this class; presumably consumed by callers.
        self.z_size = 32
        self.kl_tolerance = 0.5

        self.vae = VAE()
        self.rnn = RNN()
        # The identity map_location loads every tensor onto CPU, so weights
        # saved from GPU tensors load without CUDA.
        self.vae.load_state_dict(
            torch.load(vae_model_path,
                       map_location=lambda storage, loc: storage))
        self.rnn.load_state_dict(
            torch.load(rnn_model_path,
                       map_location=lambda storage, loc: storage))
        self.vae.train()
        self.rnn.train()
        self.init_()

        # NOTE(review): when_train does not consult this flag; it always moves
        # the sub-models to CUDA.
        self.is_cuda = False

    def init_(self):
        """Reset ``self.h`` to the value returned by ``self.rnn.init_()``."""
        self.h = self.rnn.init_()

    def forward(self, inputs):
        """Encode ``inputs`` with the VAE.

        Returns:
            ``(z, h)``: the VAE output and the stored RNN state ``self.h``,
            which this method does not advance.
        """
        z = self.vae(inputs)
        return z, self.h

    def when_train(self, inputs, one, outputs):
        """Run the forward passes of one training step and record the losses.

        Args:
            inputs: observations encoded (without gradient) to give the RNN
                input latent.
            one: tensor concatenated to that latent along dim 2 before the RNN
                (presumably a one-hot action -- confirm with the caller).
            outputs: observations (presumably the next frames) whose VAE
                encoding is the target for the RNN losses.

        Sets ``next_kl_loss``, ``next_r_loss``, ``pred_loss``, ``mdn_loss`` and
        ``pred_recon_loss``. Requires CUDA: both sub-models are moved to the
        GPU unconditionally.
        """
        self.vae.is_cuda = True
        self.vae.cuda()
        self.rnn.is_cuda = True
        self.rnn.cuda()

        # No autograd graph is recorded through the VAE for the input latent.
        with torch.no_grad():
            z = self.vae(inputs)
        # Leading dim added so the latent can be concatenated with ``one`` on dim 2.
        z = z.unsqueeze(0)

        z_a = torch.cat((z, one), dim=2)
        # Return value unused; the loss functions and ``z_prediction`` below
        # are read from self.rnn after this forward pass.
        self.rnn(z_a)

        z_next = self.vae(outputs)
        # Read immediately after encoding ``outputs`` so they refer to that call.
        self.next_kl_loss = self.vae.kl_loss
        self.next_r_loss = self.vae.r_loss

        z_next = z_next.unsqueeze(0)
        self.pred_loss = self.rnn.prediction_loss_f(z_next)
        self.mdn_loss = self.rnn.mdn_loss_f(z_next)

        # squeeze(0) drops a leading size-1 dim, presumably mirroring the
        # unsqueeze(0) applied to the RNN input.
        z_next_hat = self.rnn.z_prediction.squeeze(0)
        # NOTE(review): the predicted *next* latent is scored against the
        # current ``inputs``; ``outputs`` may have been intended -- confirm.
        self.pred_recon_loss = self.vae.reconstruction_error_f(
            z_next_hat, inputs)

    def make_prediction(self, action):
        """Build the RNN input for ``action`` (appears unfinished).

        NOTE(review): ``z`` is not a local here. It resolves to a module-level
        global named ``z`` if one exists; otherwise the ``torch.cat`` call
        raises NameError. The method also returns nothing. Confirm the
        intended source of ``z`` and the return value before relying on it.
        """
        one = one_hot(action)
        one = torch.from_numpy(one)
        one = one.unsqueeze(0)
        one = one.type(torch.float)
        z_a = torch.cat((z, one), dim=1)
        z_a = z_a.unsqueeze(0)
Beispiel #4
0
env = gym.make('Pong-v0')

# The VAE and RNN are loaded from disk and put in eval mode. The identity
# map_location loads every tensor onto CPU, so checkpoints saved from GPU
# tensors still load on a CPU-only machine.
vae_model = VAE()
vae_model.load_state_dict(
    torch.load(vae_model_path, map_location=lambda storage, loc: storage))
vae_model.eval()

rnn_model = RNN()
rnn_model.load_state_dict(
    torch.load(rnn_model_path, map_location=lambda storage, loc: storage))
rnn_model.eval()

policy = Policy()
policy.train()
# Only the policy's parameters are handed to the optimizer; the VAE/RNN
# weights are not updated by it.
optimizer = optim.Adam(policy.parameters())
while True:
    state = env.reset()
    h = rnn_model.init_()
    value_s = []
    action_log_prob_s = []
    entropy_s = []
    reward_s = []
    reward_sum = 0
    while True:
        # env.render()
        state = tensor_state(state)
        z = vae_model(state)
        h = h.squeeze(0)
        z_h = torch.cat((z, h), dim=1)
        a = policy(z_h)

        one = one_hot(a)
        one = torch.from_numpy(one)