Example #1
0
	def eval_step(self, batch, best_k=10):
		"""Evaluate one batch with best-of-``best_k`` sampling.

		The batch is expanded with ``get_batch_k(batch, best_k)`` and passed
		through ``self.test``. For every scene ``(start, end)`` in
		``batch["seq_start_end"]``, the per-sample ADE/FDE errors are summed
		over the scene's agents and the smallest sum over the ``best_k``
		samples is kept. That value is then normalised by the number of agents
		(and by ``pred_len`` for ADE), and additionally by the scene's
		``"ratio"`` for the pixel variants.

		Args:
		    batch: Batch from the data loader.
		    best_k: Number of samples drawn per trajectory.

		Returns:
		    dict with:
		    - "ade", "fde", "ade_pixel", "fde_pixel": lists with one entry
		      per scene.
		    - "wall_crashes": the result of ``crashIntoWall``, or ``[0]``
		      when no occupancy is present.
		    - "modes_caught": a float tensor, 1 where the minimum FDE over
		      the samples is below ``self.hparams.mode_dist_threshold``.
		"""
		# One pixel ratio per scene image, read before the batch is expanded.
		scale = torch.stack(
			[torch.tensor(scene["ratio"]) for scene in batch["scene_img"]]
		).to(self.device)

		batch = get_batch_k(batch, best_k)
		n_agents = batch["size"]

		out = self.test(batch)

		# Only the first evaluated batch is visualised; the flag is cleared here.
		if self.plot_val:
			self.plot_val = False
			self.visualize_results(batch, out)

		# Raw per-trajectory errors, arranged as (best_k, n_agents).
		ade_error = cal_ade(batch["gt_xy"], out["out_xy"], mode='raw').view(best_k, n_agents)
		fde_error = cal_fde(batch["gt_xy"], out["out_xy"], mode='raw').view(best_k, n_agents)

		per_scene = {"ade": [], "fde": [], "ade_pixel": [], "fde_pixel": []}
		horizon = self.hparams.pred_len
		for scene_idx, (first, last) in enumerate(batch["seq_start_end"]):
			agents = last - first
			# Scene-level best-of-k: choose the sample with the lowest summed error.
			best_ade = ade_error[:, first:last].sum(dim=1).min(dim=0, keepdim=True).values
			best_fde = fde_error[:, first:last].sum(dim=1).min(dim=0, keepdim=True).values

			per_scene["ade"].append(best_ade / (horizon * agents))
			per_scene["fde"].append(best_fde / agents)
			per_scene["ade_pixel"].append(best_ade / (horizon * agents * scale[scene_idx]))
			per_scene["fde_pixel"].append(best_fde / (scale[scene_idx] * agents))

		# A mode is "caught" when any sample's final displacement is under the threshold.
		closest_fde = fde_error.min(dim=0).values
		modes_caught = (closest_fde < self.hparams.mode_dist_threshold).float()

		wall_crashes = (
			crashIntoWall(out["out_xy"].cpu(), batch["occupancy"])
			if any(batch["occupancy"])
			else [0]
		)

		return {**per_scene, "wall_crashes": wall_crashes, "modes_caught": modes_caught}
Example #2
0
	def generator_step(self, batch):
		"""Run one generator optimisation step and log the training scalars.

		The batch is expanded ``best_k`` times with ``get_batch_k``, and the
		generator runs once on the expanded batch. The returned loss is the
		weighted sum of:

		* "L2": a best-of-k L2 loss, i.e. the minimum over the ``best_k``
		  samples of each trajectory, averaged.
		* Only when ``self.generator.global_vis_type == "goal"``:
		  - "GCE": a goal classification loss on ``generator_out["y_scores"]``
		    against the argmax of ``batch["prob_mask"]``.
		  - "G": a loss between the (detached) predicted goal
		    ``generator_out["final_pos"]`` and the end point of the generated
		    trajectory. It is evaluated on the sample whose predicted goal
		    is closest to the ground-truth end point.
		* "ADV": the adversarial loss of the discriminator on the first
		  ``batch_size`` generated trajectories, clamped at 0.

		Scene-level best-of-k ADE/FDE (metric and pixel units) and a
		feasibility score (``1 - mean(crashIntoWall(...))``) are logged as
		well; they do not contribute to the loss.

		Args:
		    batch: Batch from the data loader.

		Returns:
		    dict ``{"loss": total_loss}`` holding the weighted generator loss.
		"""
		tqdm_dict = {}
		total_loss = 0.

		best_k = self.hparams.best_k
		# get k times batch; the .view(best_k, ...) calls below assume the k
		# copies are stacked sample-major along the batch dimension.
		batch = get_batch_k(batch, best_k)
		batch_size = batch["size"].item()

		generator_out = self.generator(batch)

		# Per-trajectory L2 loss, reduced further down with a best-of-k minimum.
		# NOTE(review): the absolute branch requests mode='average' while the
		# relative branch requests mode='raw'. If 'average' reduces to a scalar,
		# l2.view(best_k, -1) below fails for best_k > 1. Confirm against the
		# L2 loss implementation.
		if self.hparams.absolute:
			l2 = self.loss_fns["L2"](
				batch["gt_xy"],
				generator_out["out_xy"],
				mode='average',
				type="mse")
		else:
			l2 = self.loss_fns["L2"](
				batch["gt_dxdy"],
				generator_out["out_dxdy"],
				mode='raw',
				type="mse")

		# ADE/FDE are only logged, so keep them out of the autograd graph.
		with torch.no_grad():
			ade_error = cal_ade(
				batch["gt_xy"], generator_out["out_xy"], mode='raw'
			).view(best_k, batch_size)
			fde_error = cal_fde(
				batch["gt_xy"], generator_out["out_xy"], mode='raw'
			).view(best_k, batch_size)

			# get pixel ratios (one per scene image)
			ratios = torch.stack(
				[torch.tensor(img["ratio"]) for img in batch["scene_img"]]
			).to(self.device)

			ade_sum, fde_sum = [], []
			ade_sum_pixel, fde_sum_pixel = [], []
			for idx, (start, end) in enumerate(batch["seq_start_end"]):
				# Scene-level best-of-k: keep the sample with the lowest summed error.
				ade_sum_scene, _ = torch.sum(ade_error[:, start:end], dim=1).min(dim=0, keepdim=True)
				fde_sum_scene, _ = torch.sum(fde_error[:, start:end], dim=1).min(dim=0, keepdim=True)

				ade_sum.append(ade_sum_scene / (self.hparams.pred_len * (end - start)))
				fde_sum.append(fde_sum_scene / (end - start))

				ade_sum_pixel.append(ade_sum_scene / (self.hparams.pred_len * (end - start) * ratios[idx]))
				fde_sum_pixel.append(fde_sum_scene / (ratios[idx] * (end - start)))

			tqdm_dict["ADE_train"] = torch.mean(torch.stack(ade_sum))
			tqdm_dict["FDE_train"] = torch.mean(torch.stack(fde_sum))
			tqdm_dict["ADE_pixel_train"] = torch.mean(torch.stack(ade_sum_pixel))
			tqdm_dict["FDE_pixel_train"] = torch.mean(torch.stack(fde_sum_pixel))

		# count trajectories crashing into the 'wall'; feasibility = 1 - mean(crashes)
		if any(batch["occupancy"]):
			wall_crashes = crashIntoWall(generator_out["out_xy"].detach().cpu(), batch["occupancy"])
		else:
			wall_crashes = [0]
		tqdm_dict["feasibility_train"] = torch.tensor(1 - np.mean(wall_crashes))

		# Best-of-k L2: minimum over the k samples of each trajectory, then averaged.
		loss_l2, _ = l2.view(best_k, -1).min(dim=0, keepdim=True)
		loss_l2 = self.loss_weights["L2"] * torch.mean(loss_l2)
		tqdm_dict["L2_train"] = loss_l2
		total_loss += loss_l2

		if self.generator.global_vis_type == "goal":
			# Goal classification: the target class is the argmax of each prob_mask.
			target_reshaped = batch["prob_mask"][:batch_size].view(batch_size, -1)
			output_reshaped = generator_out["y_scores"][:batch_size].view(batch_size, -1)
			_, targets = target_reshaped.max(dim=1)

			loss_gce = self.loss_weights["GCE"] * self.loss_fns["GCE"](output_reshaped, targets)
			total_loss += loss_gce
			tqdm_dict["GCE_train"] = loss_gce

			# End points as the sum of the step displacements over the time axis.
			final_end = torch.sum(generator_out["out_dxdy"], dim=0, keepdim=True)
			final_end_gt = torch.sum(batch["gt_dxdy"], dim=0, keepdim=True)
			final_pos = generator_out["final_pos"]

			# For each trajectory, pick the sample whose predicted goal is
			# closest to the ground-truth end point.
			goal_error = self.loss_fns["G"](final_pos.detach(), final_end_gt)
			_, id_min = goal_error.view(best_k, -1).min(dim=0)
			# Pair sample id_min[j] with trajectory j for j in range(batch_size).
			rows = id_min[:batch_size]
			cols = torch.arange(batch_size, device=rows.device)

			# Gather with advanced indexing: (best_k, batch_size, D) -> (1, batch_size, D).
			final_pos = final_pos.view(best_k, batch_size, -1)[rows, cols].unsqueeze(0)
			final_end = final_end.view(best_k, batch_size, -1)[rows, cols].unsqueeze(0)

			# The goal is detached, so this term only trains the trajectory towards it.
			loss_G = self.loss_weights["G"] * torch.mean(self.loss_fns["G"](final_pos.detach(), final_end, mode='raw'))
			total_loss += loss_G
			tqdm_dict["G_train"] = loss_G

		# The discriminator only sees the first copy (first batch_size trajectories).
		traj_fake = generator_out["out_xy"][:, :batch_size]
		traj_fake_rel = generator_out["out_dxdy"][:, :batch_size]

		if self.generator.rm_vis_type == "attention":
			image_patches = generator_out["image_patches"][:, :batch_size]
		else:
			image_patches = None

		fake_scores = self.discriminator(
			in_xy=batch["in_xy"][:, :batch_size],
			in_dxdy=batch["in_dxdy"][:, :batch_size],
			out_xy=traj_fake,
			out_dxdy=traj_fake_rel,
			images_patches=image_patches,
		)

		loss_adv = self.loss_weights["ADV"] * self.loss_fns["ADV"](fake_scores, True).clamp(min=0)
		total_loss += loss_adv
		tqdm_dict["ADV_train"] = loss_adv

		tqdm_dict["G_loss"] = total_loss
		for key, loss in tqdm_dict.items():
			self.logger.experiment.add_scalar('train/{}'.format(key), loss, self.global_step)

		return {"loss": total_loss}