def test(nbatches, npred):
    # Per-timestep discount factors, shape (1, npred); kept for parity with train()
    gamma_mask = torch.Tensor([opt.gamma**t for t in range(npred)]).view(1, -1).cuda()
    model.eval()
    total_loss_i, total_loss_s, total_loss_c, total_loss_policy, total_loss_p, n_updates = 0, 0, 0, 0, 0, 0
    for i in range(nbatches):
        inputs, actions, targets, _, _ = dataloader.get_batch_fm('test', npred)
        # Roll out the policy through the forward model without dropout
        pred, pred_actions = planning.train_policy_net_mper(
            model, inputs, targets, targetprop=opt.targetprop, dropout=0.0, model_type=model_type)
        loss_i, loss_s, loss_c_, loss_p = compute_loss(targets, pred)
        loss_policy = loss_i + loss_s
        if opt.loss_c == 1:
            loss_policy += loss_c_
        # Skip batches whose loss is NaN so they do not corrupt the averages
        if not math.isnan(loss_policy.item()):
            total_loss_i += loss_i.item()
            total_loss_s += loss_s.item()
            total_loss_p += loss_p.item()
            total_loss_policy += loss_policy.item()
            n_updates += 1
        del inputs, actions, targets, pred
    # Note: total_loss_c is never accumulated above, so it is returned as 0
    total_loss_i /= n_updates
    total_loss_s /= n_updates
    total_loss_c /= n_updates
    total_loss_policy /= n_updates
    total_loss_p /= n_updates
    return total_loss_i, total_loss_s, total_loss_c, total_loss_policy, total_loss_p
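# Illustrative sketch (not part of the original script): how a gamma_mask like
# the one built above would typically be applied to per-timestep rollout costs,
# mirroring the discounting that is commented out inside train() below.
# `per_step_costs` is a hypothetical (batch, npred) tensor used for illustration.
def _discounted_cost_example(per_step_costs, gamma=0.99):
    npred = per_step_costs.size(1)
    gamma_mask = torch.Tensor([gamma**t for t in range(npred)]).view(1, -1)
    gamma_mask = gamma_mask.to(per_step_costs.device)
    # Broadcast the (1, npred) discount row over the batch dimension, then
    # average to obtain a single scalar cost for the rollout.
    return (per_step_costs * gamma_mask).mean()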
def train(nbatches, npred):
    # Per-timestep discount factors, shape (1, npred)
    gamma_mask = torch.Tensor([opt.gamma**t for t in range(npred)]).view(1, -1).cuda()
    # Freeze the forward model; only the policy network is updated
    model.eval()
    model.policy_net.train()
    total_loss_i, total_loss_s, total_loss_c, total_loss_policy, total_loss_p, n_updates = 0, 0, 0, 0, 0, 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm('train', npred)
        inputs = utils.make_variables(inputs)
        targets = utils.make_variables(targets)
        actions = Variable(actions)
        # Roll out the policy through the forward model with dropout enabled
        pred, _ = planning.train_policy_net_mper(
            model, inputs, targets, dropout=opt.p_dropout, model_type=model_type)
        loss_i, loss_s, loss_c_, loss_p = compute_loss(targets, pred)
        # proximity_cost, lane_cost = pred[2][:, :, 0], pred[2][:, :, 1]
        # proximity_cost = proximity_cost * Variable(gamma_mask)
        # lane_cost = lane_cost * Variable(gamma_mask)
        # loss_c = proximity_cost.mean() + opt.lambda_lane * lane_cost.mean()
        loss_policy = loss_i + loss_s + opt.lambda_h * loss_p
        if opt.loss_c == 1:
            loss_policy += loss_c_
        if not math.isnan(loss_policy.item()):
            loss_policy.backward()
            # clip_grad_norm_ is the in-place (non-deprecated) variant
            torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(), opt.grad_clip)
            optimizer.step()
            total_loss_i += loss_i.item()
            total_loss_s += loss_s.item()
            total_loss_p += loss_p.item()
            total_loss_policy += loss_policy.item()
            n_updates += 1
        else:
            print('warning, NaN')
        del inputs, actions, targets, pred
    # Note: total_loss_c is never accumulated above, so it is returned as 0
    total_loss_i /= n_updates
    total_loss_s /= n_updates
    total_loss_c /= n_updates
    total_loss_policy /= n_updates
    total_loss_p /= n_updates
    return total_loss_i, total_loss_s, total_loss_c, total_loss_policy, total_loss_p
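# Illustrative sketch (assumption, not from the original script): a typical
# outer loop that alternates train() and test() epochs and logs the averaged
# losses. `opt.n_epochs`, `opt.epoch_size`, `opt.npred`, and `opt.model_file`
# are assumed option names used for illustration only.
def _training_loop_example():
    for epoch in range(opt.n_epochs):
        t_i, t_s, t_c, t_pol, t_p = train(opt.epoch_size, opt.npred)
        v_i, v_s, v_c, v_pol, v_p = test(opt.epoch_size // 2, opt.npred)
        print(f'[epoch {epoch}] train policy loss: {t_pol:.5f} | '
              f'test policy loss: {v_pol:.5f}')
        # Checkpoint the policy network after each epoch.
        torch.save(model.policy_net.state_dict(), opt.model_file + '.policy')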