class VAERNN(torch.nn.Module):
    """Wrapper that owns a ``VAE`` and an ``RNN`` and stores their losses.

    ``when_train`` encodes a batch with the VAE, feeds the latent code
    (concatenated with ``one``) to the RNN, and records the resulting
    losses as attributes on this object.
    """

    def __init__(self):
        """Build both sub-models, put them in train mode, reset RNN state."""
        super(VAERNN, self).__init__()
        # Stored here but not read anywhere in this class.
        self.z_size = 32
        self.kl_tolerance = 0.5

        self.vae = VAE()
        self.rnn = RNN()
        self.vae.train()
        self.rnn.train()

        self.init_()
        # Callers flip this to True to have when_train() move models to GPU.
        self.is_cuda = False

    def load(self):
        """Load VAE and RNN weights from ``vae_model_path``/``rnn_model_path``."""

        def keep_storage(storage, loc):
            # Returning the storage unchanged leaves tensors where they
            # were deserialized instead of restoring the saved device.
            return storage

        vae_state = torch.load(vae_model_path, map_location=keep_storage)
        self.vae.load_state_dict(vae_state)
        rnn_state = torch.load(rnn_model_path, map_location=keep_storage)
        self.rnn.load_state_dict(rnn_state)

    def init_(self):
        """Reset ``self.h`` to whatever ``self.rnn.init_()`` returns."""
        self.h = self.rnn.init_()

    def forward(self, inputs):
        """Return the VAE output for ``inputs``."""
        return self.vae(inputs)

    def when_train(self, inputs, one, outputs):
        """Run one training-time pass and store the losses on ``self``.

        Sets ``next_kl_loss``, ``next_r_loss``, ``pred_loss`` and
        ``mdn_loss``; returns nothing.

        Args:
            inputs: batch passed to the VAE for the current step.
            one: tensor concatenated to the latent code along dim 2
                (presumably a one-hot action -- TODO confirm with caller).
            outputs: batch passed to the VAE as the prediction target.
        """
        if self.is_cuda:
            for model in (self.vae, self.rnn):
                model.is_cuda = True
                model.cuda()

        # Add a leading axis before joining with ``one`` on the last axis.
        latent = self.vae(inputs).unsqueeze(0)
        rnn_input = torch.cat((latent, one), dim=2)
        self.rnn(rnn_input)

        # The VAE losses are read right after encoding ``outputs``, so they
        # describe the target batch rather than ``inputs``.
        next_latent = self.vae(outputs)
        self.next_kl_loss = self.vae.kl_loss
        self.next_r_loss = self.vae.r_loss

        target = next_latent.unsqueeze(0)
        self.pred_loss = self.rnn.prediction_loss_f(target)
        self.mdn_loss = self.rnn.mdn_loss_f(target)
class VAERNN(torch.nn.Module):
    """Wrapper that owns a pre-trained ``VAE`` and ``RNN`` and stores losses.

    Both sub-models are restored from ``vae_model_path`` and
    ``rnn_model_path`` at construction time. ``when_train`` encodes a batch
    with the VAE, feeds the latent code (concatenated with ``one``) to the
    RNN, and records the resulting losses as attributes on this object.
    """

    def __init__(self):
        """Build both sub-models, load their weights, and reset RNN state."""
        super(VAERNN, self).__init__()
        # Stored here but not read anywhere in this class.
        self.z_size = 32
        self.kl_tolerance = 0.5
        self.vae = VAE()
        self.rnn = RNN()
        # A map_location that returns the storage unchanged keeps tensors
        # where they were deserialized instead of restoring the saved device.
        self.vae.load_state_dict(
            torch.load(vae_model_path,
                       map_location=lambda storage, loc: storage))
        self.rnn.load_state_dict(
            torch.load(rnn_model_path,
                       map_location=lambda storage, loc: storage))
        self.vae.train()
        self.rnn.train()
        self.init_()
        # Setting this to True forces the GPU move in when_train().
        self.is_cuda = False

    def init_(self):
        """Reset ``self.h`` to whatever ``self.rnn.init_()`` returns."""
        self.h = self.rnn.init_()

    def forward(self, inputs):
        """Return ``(z, h)``: the VAE output for ``inputs`` and ``self.h``.

        The RNN is not run here, so ``h`` is whatever ``init_`` last stored.
        """
        z = self.vae(inputs)
        return z, self.h

    def when_train(self, inputs, one, outputs):
        """Run one training-time pass and store the losses on ``self``.

        Sets ``next_kl_loss``, ``next_r_loss``, ``pred_loss``, ``mdn_loss``
        and ``pred_recon_loss``; returns nothing.

        Args:
            inputs: batch passed to the VAE for the current step.
            one: tensor concatenated to the latent code along dim 2
                (presumably the one-hot action -- TODO confirm with caller).
            outputs: batch passed to the VAE as the prediction target.
        """
        # Move to the GPU only when one can be used. The previous
        # unconditional .cuda() raised on CPU-only hosts; behaviour is
        # unchanged wherever CUDA is available or is_cuda is set.
        if self.is_cuda or torch.cuda.is_available():
            self.vae.is_cuda = True
            self.vae.cuda()
            self.rnn.is_cuda = True
            self.rnn.cuda()

        # No gradient is tracked through the encoding of the current batch.
        # NOTE(review): the extent of this no_grad block was reconstructed
        # from flattened source -- confirm it covers only this call.
        with torch.no_grad():
            z = self.vae(inputs)
        # Add a leading axis before joining with ``one`` on the last axis.
        z = z.unsqueeze(0)
        z_a = torch.cat((z, one), dim=2)
        self.rnn(z_a)

        # The VAE losses are read right after encoding ``outputs``, so they
        # describe the target batch rather than ``inputs``.
        z_next = self.vae(outputs)
        self.next_kl_loss = self.vae.kl_loss
        self.next_r_loss = self.vae.r_loss
        z_next = z_next.unsqueeze(0)

        self.pred_loss = self.rnn.prediction_loss_f(z_next)
        self.mdn_loss = self.rnn.mdn_loss_f(z_next)

        # Drop the leading axis from the RNN prediction before handing it
        # to the VAE's reconstruction error.
        z_next_hat = self.rnn.z_prediction
        z_next_hat = z_next_hat.squeeze(0)
        # NOTE(review): the predicted *next* latent is scored against
        # ``inputs``; ``outputs`` looks like the intended target. Left
        # unchanged because it alters the training objective -- confirm.
        self.pred_recon_loss = self.vae.reconstruction_error_f(
            z_next_hat, inputs)

    def make_prediction(self, action):
        """Build the RNN input for ``action``; currently returns nothing.

        NOTE(review): ``z`` is not defined in this method, so it resolves
        to a module-level name or raises NameError, and ``z_a`` is computed
        but never used or returned. The method looks unfinished -- confirm
        the intended source of ``z`` and the return value.
        """
        one = one_hot(action)
        one = torch.from_numpy(one)
        one = one.unsqueeze(0)
        one = one.type(torch.float)
        z_a = torch.cat((z, one), dim=1)
        z_a = z_a.unsqueeze(0)
# print(z.shape) # print(a.shape) inputs = np.concatenate((z[:-1, :], one[:-1, :]), axis=1) outputs = z[1:, :] # print(inputs.shape) inputs = tensor_rnn_inputs(inputs) outputs = tensor_rnn_inputs(outputs) if is_cuda: inputs = inputs.cuda() outputs = outputs.cuda() h = rnn_model(inputs) # print(h.shape) print(outputs.shape) input('hi') mdn_loss = rnn_model.mdn_loss_f(outputs) pred_loss = rnn_model.prediction_loss_f(outputs) loss = mdn_loss + pred_loss optimizer.zero_grad() loss.backward() optimizer.step() mdn_loss_s.append(mdn_loss.item()) pred_loss_s.append(pred_loss.item()) # print(l.shape) print('epoch: {}, mdn loss: {}, prediction loss: {}'.format( epoch, np.mean(mdn_loss_s), np.mean(pred_loss_s))) if (epoch + 1) % 20 == 0: torch.save(rnn_model.state_dict(), rnn_model_path) torch.save(rnn_model.state_dict(), rnn_model_path)