Example #1
0
    def forward_thru_model(self, model, inputs, actions, targets):
        """Roll ``model`` out over the action sequence and accumulate the prior loss.

        At every step ``model`` infers a latent ``z`` from the current context and
        the ground-truth target frame for that step. This network (``self``) is
        then scored with ``utils.mdn_loss_fn`` on its mixture ``(pi, mu, sigma)``,
        which it computes from the context alone, against that ``z``. The model's
        prediction for the step is appended to the context window (dropping the
        oldest entry) for the next step.

        Args:
            model: forward model exposing ``encoder``, ``y_encoder``, ``z_network``,
                ``z_expander``, ``a_encoder``, ``u_network``, ``decoder`` and
                ``opt`` (``nfeature``, ``h_height``, ``h_width``).
            inputs: ``(input_images, input_states)`` context. Dim 0 is the batch;
                dim 1 is presumably time (it is shifted via ``[:, 1:]`` and the
                last entry read via ``[:, -1]``).
            actions: action tensor whose dim 1 is the number of prediction steps.
            targets: ``(target_images, target_states, target_costs)``; only the
                images are used here.

        Returns:
            The prior loss summed over steps and divided by the number of steps.
        """
        input_images, input_states = inputs
        # Only the target images are needed; unpack once rather than every step.
        target_images, _, _ = targets
        bsize = input_images.size(0)
        npred = actions.size(1)
        hidden_shape = (bsize, model.opt.nfeature, model.opt.h_height, model.opt.h_width)
        # Allocate the accumulator next to the inputs instead of hard-coding CUDA.
        ploss = torch.zeros(1, device=input_images.device)

        for t in range(npred):
            h_x = model.encoder(input_images, input_states)
            h_y = model.y_encoder(target_images[:, t].unsqueeze(1).contiguous())
            z = model.z_network((h_x + h_y).view(bsize, -1))
            pi, mu, sigma = self(input_images, input_states)
            # prior loss: score this network's mixture against the inferred z
            ploss += utils.mdn_loss_fn(pi, sigma, mu, z)

            z_exp = model.z_expander(z).view(hidden_shape)
            h_x = h_x.view(hidden_shape)
            a_emb = model.a_encoder(actions[:, t]).view(h_x.size())
            h = h_x + z_exp + a_emb
            h = h + model.u_network(h)
            pred_image, pred_state, _ = model.decoder(h)

            # Tensor.detach() is not in-place, so the previous bare calls were
            # no-ops. Reassign so the rolled-out context carries no graph into
            # later steps.
            pred_image = pred_image.detach()
            pred_state = pred_state.detach()

            # The decoder output is added to the last context frame/state.
            pred_image = torch.sigmoid(pred_image + input_images[:, -1].unsqueeze(1))
            # States are not clamped here. NOTE(review): a clamp to [-6, 6]
            # (6 std devs if the normalized states are gaussian) was once considered.
            pred_state = pred_state + input_states[:, -1]

            # Slide the context window: drop the oldest entry, append the prediction.
            input_images = torch.cat((input_images[:, 1:], pred_image), 1)
            input_states = torch.cat((input_states[:, 1:], pred_state.unsqueeze(1)), 1)

        return ploss / npred
Example #2
0
def test(nbatches):
    """Evaluate the policy's MDN loss on ``nbatches`` validation batches.

    Each batch is drawn from ``dataloader.get_batch_fm('valid')`` and scored with
    ``utils.mdn_loss_fn``. Batches whose loss is NaN are skipped with a warning.

    Returns:
        Mean loss over the non-NaN batches, or ``float('nan')`` when no batch
        produced a usable loss (``nbatches == 0`` or every loss was NaN).
    """
    policy.eval()
    total_loss, nb = 0, 0
    # Evaluation only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for _ in range(nbatches):
            inputs, actions, _, _, _ = dataloader.get_batch_fm('valid')
            pi, mu, sigma, _ = policy(inputs[0], inputs[1])
            loss = utils.mdn_loss_fn(pi, sigma, mu, actions.view(opt.batch_size, -1))
            loss_value = loss.item()
            if math.isnan(loss_value):
                print('warning, NaN')
                continue
            total_loss += loss_value
            nb += 1
    # Guard against ZeroDivisionError when no batch contributed.
    return total_loss / nb if nb else float('nan')
Example #3
0
def train(nbatches):
    """Run up to ``nbatches`` optimisation steps of the policy on training batches.

    Each batch is drawn from ``dataloader.get_batch_fm('train')`` and scored with
    ``utils.mdn_loss_fn``. If the loss is NaN, the batch is skipped with a warning
    and no backward pass or optimizer step is taken. Gradients are clipped to
    ``opt.grad_clip`` unless it is ``-1``.

    Returns:
        Mean loss over the batches that were stepped on, or ``float('nan')``
        when none were (``nbatches == 0`` or every loss was NaN).
    """
    policy.train()
    total_loss, nb = 0, 0
    for _ in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, _, _, _ = dataloader.get_batch_fm('train')
        pi, mu, sigma, _ = policy(inputs[0], inputs[1])
        loss = utils.mdn_loss_fn(pi, sigma, mu, actions.view(opt.batch_size, -1))
        loss_value = loss.item()  # read the scalar once
        if math.isnan(loss_value):
            # Skip backward/step so a NaN loss never updates the parameters.
            print('warning, NaN')
            continue
        loss.backward()
        if opt.grad_clip != -1:
            torch.nn.utils.clip_grad_norm_(policy.parameters(), opt.grad_clip)
        optimizer.step()
        total_loss += loss_value
        nb += 1
    # Guard against ZeroDivisionError when no batch contributed.
    return total_loss / nb if nb else float('nan')