def forward_thru_model(self, model, inputs, actions, targets):
    input_images, input_states = inputs
    bsize = input_images.size(0)
    npred = actions.size(1)
    ploss = torch.zeros(1).cuda()
    target_images, target_states, target_costs = targets

    for t in range(npred):
        # encode the conditioning frames/states and the ground-truth frame at step t
        h_x = model.encoder(input_images, input_states)
        h_y = model.y_encoder(target_images[:, t].unsqueeze(1).contiguous())
        # posterior latent from the forward model; prior mixture from this network
        z = model.z_network((h_x + h_y).view(bsize, -1))
        pi, mu, sigma = self(input_images, input_states)
        # prior loss: negative log-likelihood of the posterior z under the prior MDN
        ploss += utils.mdn_loss_fn(pi, sigma, mu, z)
        z_exp = model.z_expander(z).view(bsize, model.opt.nfeature, model.opt.h_height, model.opt.h_width)
        h_x = h_x.view(bsize, model.opt.nfeature, model.opt.h_height, model.opt.h_width)
        a_emb = model.a_encoder(actions[:, t]).view(h_x.size())
        h = h_x + z_exp + a_emb
        h = h + model.u_network(h)
        pred_image, pred_state, pred_cost = model.decoder(h)
        # detach the predictions so gradients do not propagate through the rollout
        # (note: Tensor.detach() is not in-place, so the result must be reassigned)
        pred_image = pred_image.detach()
        pred_state = pred_state.detach()
        pred_cost = pred_cost.detach()
        pred_image = torch.sigmoid(pred_image + input_images[:, -1].unsqueeze(1))
        pred_state = pred_state + input_states[:, -1]
        # since states are normalized, we could clamp to 6 standard deviations (if gaussian):
        # pred_state = torch.clamp(pred_state + input_states[:, -1], min=-6, max=6)
        # feed the prediction back in as the newest frame/state for the next step
        input_images = torch.cat((input_images[:, 1:], pred_image), 1)
        input_states = torch.cat((input_states[:, 1:], pred_state.unsqueeze(1)), 1)

    return ploss / npred
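
# The calls to utils.mdn_loss_fn above and below compute a mixture-density-network
# negative log-likelihood. The sketch here is a minimal illustration of that loss,
# assuming diagonal Gaussian components and pre-softmaxed mixture weights; the
# actual utils.mdn_loss_fn in this codebase may differ in parameterization and
# numerical details. It relies on the torch and math modules already imported here.
def _mdn_loss_sketch(pi, sigma, mu, target):
    # pi:     (bsize, n_mixtures)        mixture weights (sum to 1 along dim 1)
    # sigma:  (bsize, n_mixtures, dim)   per-component standard deviations (> 0)
    # mu:     (bsize, n_mixtures, dim)   per-component means
    # target: (bsize, dim)               observed vector
    target = target.unsqueeze(1)  # broadcast against the mixture dimension
    # log N(target | mu, sigma^2) per component, summed over the data dimensions
    log_prob = -0.5 * (((target - mu) / sigma) ** 2
                       + 2 * torch.log(sigma)
                       + math.log(2 * math.pi)).sum(2)
    # -log sum_k pi_k N_k(target), computed stably with logsumexp
    nll = -torch.logsumexp(torch.log(pi) + log_prob, dim=1)
    return nll.mean()
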
def test(nbatches):
    policy.eval()
    total_loss, nb = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for i in range(nbatches):
            inputs, actions, targets, _, _ = dataloader.get_batch_fm('valid')
            pi, mu, sigma, _ = policy(inputs[0], inputs[1])
            loss = utils.mdn_loss_fn(pi, sigma, mu, actions.view(opt.batch_size, -1))
            if not math.isnan(loss.item()):
                total_loss += loss.item()
                nb += 1
            else:
                print('warning, NaN')
    return total_loss / nb
def train(nbatches):
    policy.train()
    total_loss, nb = 0, 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm('train')
        pi, mu, sigma, _ = policy(inputs[0], inputs[1])
        loss = utils.mdn_loss_fn(pi, sigma, mu, actions.view(opt.batch_size, -1))
        if not math.isnan(loss.item()):
            loss.backward()
            if opt.grad_clip != -1:
                torch.nn.utils.clip_grad_norm_(policy.parameters(), opt.grad_clip)
            optimizer.step()
            total_loss += loss.item()
            nb += 1
        else:
            print('warning, NaN')
    return total_loss / nb
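
# A minimal driver showing how train() and test() above might be called per
# epoch. This is a sketch: the epoch count, the use of opt.epoch_size, the
# validation batch count, and the checkpoint path are illustrative assumptions,
# not taken from this codebase.
if __name__ == '__main__':
    for epoch in range(100):
        train_loss = train(opt.epoch_size)          # assumes opt.epoch_size is defined
        valid_loss = test(int(opt.epoch_size / 2))  # fewer batches for validation
        print(f'epoch {epoch}: train {train_loss:.4f}, valid {valid_loss:.4f}')
        torch.save(policy.state_dict(), 'policy.ckpt')  # hypothetical checkpoint path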