def eval_step(self, batch, best_k=10):
    ade_sum, fde_sum = [], []
    ade_sum_pixel, fde_sum_pixel = [], []

    # get pixel ratios for converting world coordinates to pixel space
    ratios = []
    for img in batch["scene_img"]:
        ratios.append(torch.tensor(img["ratio"]))
    ratios = torch.stack(ratios).to(self.device)

    # repeat the batch k times for best-of-k evaluation
    batch = get_batch_k(batch, best_k)
    batch_size = batch["size"]

    out = self.test(batch)

    if self.plot_val:
        self.plot_val = False
        self.visualize_results(batch, out)

    # FDE and ADE metrics
    ade_error = cal_ade(batch["gt_xy"], out["out_xy"], mode='raw')
    fde_error = cal_fde(batch["gt_xy"], out["out_xy"], mode='raw')

    ade_error = ade_error.view(best_k, batch_size)
    fde_error = fde_error.view(best_k, batch_size)

    # per scene, keep the error of the best of the k samples
    for idx, (start, end) in enumerate(batch["seq_start_end"]):
        ade_error_sum = torch.sum(ade_error[:, start:end], dim=1)
        fde_error_sum = torch.sum(fde_error[:, start:end], dim=1)

        ade_sum_scene, _ = ade_error_sum.min(dim=0, keepdim=True)
        fde_sum_scene, _ = fde_error_sum.min(dim=0, keepdim=True)

        ade_sum.append(ade_sum_scene / (self.hparams.pred_len * (end - start)))
        fde_sum.append(fde_sum_scene / (end - start))

        ade_sum_pixel.append(ade_sum_scene / (self.hparams.pred_len * (end - start) * ratios[idx]))
        fde_sum_pixel.append(fde_sum_scene / (ratios[idx] * (end - start)))

    # 'modes caught' metric: a pedestrian counts as caught if the best of
    # its k final displacement errors falls below the distance threshold
    fde_min, _ = fde_error.min(dim=0)
    modes_caught = (fde_min < self.hparams.mode_dist_threshold).float()

    # count predicted trajectories that cross into occupied space
    if any(batch["occupancy"]):
        wall_crashes = crashIntoWall(out["out_xy"].cpu(), batch["occupancy"])
    else:
        wall_crashes = [0]
    return {"ade": ade_sum, "fde": fde_sum,
            "ade_pixel": ade_sum_pixel, "fde_pixel": fde_sum_pixel,
            "wall_crashes": wall_crashes, "modes_caught": modes_caught}
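# `cal_ade` and `cal_fde` are defined elsewhere in the repository. For
# orientation, here is a minimal sketch of what the two helpers plausibly
# compute in mode='raw', inferred from how `eval_step` normalizes their
# outputs (dividing the scene-wise ADE sum by pred_len * num_peds and the
# FDE sum by num_peds). The shapes and reduction semantics are assumptions,
# not the repository's actual code; torch is assumed imported at module level.

def cal_ade_sketch(gt_xy, pred_xy, mode='raw'):
    """Average Displacement Error (sketch).

    Assumes gt_xy and pred_xy are shaped (pred_len, num_peds, 2).
    mode='raw' returns the per-pedestrian sum of Euclidean displacements
    over the prediction horizon, shaped (num_peds,).
    """
    error = torch.norm(gt_xy - pred_xy, p=2, dim=-1).sum(dim=0)
    return error if mode == 'raw' else error.mean()


def cal_fde_sketch(gt_xy, pred_xy, mode='raw'):
    """Final Displacement Error (sketch).

    mode='raw' returns the per-pedestrian distance between the last
    predicted and the last ground-truth position, shaped (num_peds,).
    """
    error = torch.norm(gt_xy[-1] - pred_xy[-1], p=2, dim=-1)
    return error if mode == 'raw' else error.mean()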
def generator_step(self, batch):
    """Generator optimization step.

    Computes the best-of-k L2 loss on the predicted trajectories, the
    goal cross-entropy (GCE) and goal achievement (G) losses when the
    global visual module predicts goals, and the adversarial loss from
    the discriminator; logs all training metrics.

    Args:
        batch: Batch from the data loader.

    Returns:
        Dict with the total generator loss under the key "loss".
    """
    # init loss and loss dict
    tqdm_dict = {}
    total_loss = 0.
    ade_sum, fde_sum = [], []
    ade_sum_pixel, fde_sum_pixel = [], []

    # repeat the batch k times for best-of-k sampling
    batch = get_batch_k(batch, self.hparams.best_k)
    batch_size = batch["size"].item()

    generator_out = self.generator(batch)

    # L2 loss either on absolute positions or on relative displacements
    if self.hparams.absolute:
        l2 = self.loss_fns["L2"](
            batch["gt_xy"],
            generator_out["out_xy"],
            mode='average',
            type="mse")
    else:
        l2 = self.loss_fns["L2"](
            batch["gt_dxdy"],
            generator_out["out_dxdy"],
            mode='raw',
            type="mse")

    ade_error = cal_ade(batch["gt_xy"], generator_out["out_xy"], mode='raw')
    fde_error = cal_fde(batch["gt_xy"], generator_out["out_xy"], mode='raw')

    ade_error = ade_error.view(self.hparams.best_k, batch_size)
    fde_error = fde_error.view(self.hparams.best_k, batch_size)

    # get pixel ratios for converting world coordinates to pixel space
    ratios = []
    for img in batch["scene_img"]:
        ratios.append(torch.tensor(img["ratio"]))
    ratios = torch.stack(ratios).to(self.device)

    # per scene, keep the error of the best of the k samples
    for idx, (start, end) in enumerate(batch["seq_start_end"]):
        ade_error_sum = torch.sum(ade_error[:, start:end], dim=1)
        fde_error_sum = torch.sum(fde_error[:, start:end], dim=1)

        ade_sum_scene, _ = ade_error_sum.min(dim=0, keepdim=True)
        fde_sum_scene, _ = fde_error_sum.min(dim=0, keepdim=True)

        ade_sum.append(ade_sum_scene / (self.hparams.pred_len * (end - start)))
        fde_sum.append(fde_sum_scene / (end - start))

        ade_sum_pixel.append(ade_sum_scene / (self.hparams.pred_len * (end - start) * ratios[idx]))
        fde_sum_pixel.append(fde_sum_scene / (ratios[idx] * (end - start)))

    tqdm_dict["ADE_train"] = torch.mean(torch.stack(ade_sum))
    tqdm_dict["FDE_train"] = torch.mean(torch.stack(fde_sum))
    tqdm_dict["ADE_pixel_train"] = torch.mean(torch.stack(ade_sum_pixel))
    tqdm_dict["FDE_pixel_train"] = torch.mean(torch.stack(fde_sum_pixel))

    # count trajectories crashing into the 'wall' (occupied space)
    if any(batch["occupancy"]):
        wall_crashes = crashIntoWall(generator_out["out_xy"].detach().cpu(), batch["occupancy"])
    else:
        wall_crashes = [0]
    tqdm_dict["feasibility_train"] = torch.tensor(1 - np.mean(wall_crashes))

    # best-of-k L2 loss: only the closest of the k samples is penalized
    l2 = l2.view(self.hparams.best_k, -1)
    loss_l2, _ = l2.min(dim=0, keepdim=True)
    loss_l2 = torch.mean(loss_l2)
    loss_l2 = self.loss_weights["L2"] * loss_l2
    tqdm_dict["L2_train"] = loss_l2
    total_loss += loss_l2

    if self.generator.global_vis_type == "goal":
        # GCE loss on the predicted goal probability map, computed on the
        # first of the k batch copies only
        target_reshaped = batch["prob_mask"][:batch_size].view(batch_size, -1)
        output_reshaped = generator_out["y_scores"][:batch_size].view(batch_size, -1)
        _, targets = target_reshaped.max(dim=1)

        loss_gce = self.loss_weights["GCE"] * self.loss_fns["GCE"](output_reshaped, targets)
        total_loss += loss_gce
        tqdm_dict["GCE_train"] = loss_gce

        # goal achievement loss: push the endpoint of the generated
        # trajectory towards the sampled goal position
        final_end = torch.sum(generator_out["out_dxdy"], dim=0, keepdim=True)
        final_end_gt = torch.sum(batch["gt_dxdy"], dim=0, keepdim=True)
        final_pos = generator_out["final_pos"]

        # for each pedestrian, pick the sample whose goal lies closest to
        # the ground-truth endpoint
        goal_error = self.loss_fns["G"](final_pos.detach(), final_end_gt)
        goal_error = goal_error.view(self.hparams.best_k, -1)
        _, id_min = goal_error.min(dim=0, keepdim=False)

        final_pos = final_pos.view(self.hparams.best_k, batch_size, -1)
        final_end = final_end.view(self.hparams.best_k, batch_size, -1)

        final_pos = torch.cat([final_pos[id_min[k], k].unsqueeze(0) for k in range(final_pos.size(1))]).unsqueeze(0)
        final_end = torch.cat([final_end[id_min[k], k].unsqueeze(0) for k in range(final_end.size(1))]).unsqueeze(0)

        # the goal is detached so the loss only moves the trajectory endpoint
        loss_G = self.loss_weights["G"] * torch.mean(
            self.loss_fns["G"](final_pos.detach(), final_end, mode='raw'))

        total_loss += loss_G
        tqdm_dict["G_train"] = loss_G

    # adversarial loss on the first of the k batch copies
    traj_fake = generator_out["out_xy"][:, :batch_size]
    traj_fake_rel = generator_out["out_dxdy"][:, :batch_size]

    if self.generator.rm_vis_type == "attention":
        image_patches = generator_out["image_patches"][:, :batch_size]
    else:
        image_patches = None

    fake_scores = self.discriminator(in_xy=batch["in_xy"][:, :batch_size],
                                     in_dxdy=batch["in_dxdy"][:, :batch_size],
                                     out_xy=traj_fake,
                                     out_dxdy=traj_fake_rel,
                                     images_patches=image_patches)

    loss_adv = self.loss_weights["ADV"] * self.loss_fns["ADV"](fake_scores, True).clamp(min=0)
    total_loss += loss_adv
    tqdm_dict["ADV_train"] = loss_adv

    tqdm_dict["G_loss"] = total_loss
    for key, loss in tqdm_dict.items():
        self.logger.experiment.add_scalar('train/{}'.format(key), loss, self.global_step)
    return {"loss": total_loss}