def cal_l2_losses(pred_traj_gt, pred_traj_gt_rel, pred_traj_fake, pred_traj_fake_rel, loss_mask):
    """Compute summed L2 losses for the predicted trajectory.

    Returns a pair ``(abs_loss, rel_loss)``: the masked, summed L2 error of
    the absolute predicted positions against ground truth, and the same for
    the relative (displacement) trajectories.
    """
    abs_term = l2_loss(pred_traj_fake, pred_traj_gt, loss_mask, mode='sum')
    rel_term = l2_loss(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='sum')
    return abs_term, rel_term
def generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g):
    """Run one generator update: best-of-k (variety) L2 loss plus adversarial loss.

    Samples ``args.best_k`` trajectories, keeps the minimum L2 error per
    sequence, adds the discriminator (generator-side) loss, backprops and
    steps ``optimizer_g``. Returns a dict of scalar losses for logging.

    Fixes vs. original: removed leftover debug ``print`` statements in the
    per-sequence loop, and added a guard against dividing by a zero loss-mask
    sum (mirrors the guarded sibling variant in this file).
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    # Only the prediction horizon contributes to the L2 loss.
    loss_mask = loss_mask[:, args.obs_len:]

    # Draw best_k samples; the variety loss keeps only the best one per sequence.
    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            valid_num_frames = torch.sum(loss_mask[start:end])
            # Guard: if every GT step in this sequence is masked out, the
            # division would yield NaN/inf — skip the sequence instead.
            if valid_num_frames == 0:
                continue
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / valid_num_frames
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    # Adversarial term uses the last sampled trajectory (obs + prediction).
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)
    optimizer_g.step()
    return losses
def generator_step(args, batch, generator, optimizer_g, epoch):
    """One supervised generator update (no adversarial term).

    Runs a single forward pass (the generator also receives ``epoch``),
    accumulates the per-sequence minimum L2 loss over the prediction horizon,
    then backprops and steps ``optimizer_g``. Returns a dict of scalar losses.
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch

    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    l2_terms = []
    # Restrict the mask to the prediction horizon.
    loss_mask = loss_mask[:, args.obs_len:]

    pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end, epoch)
    if args.l2_loss_weight > 0:
        l2_terms.append(args.l2_loss_weight * l2_loss(
            pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    l2_sum = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        stacked = torch.stack(l2_terms, dim=1)
        for start, end in seq_start_end.data:
            per_seq = torch.sum(stacked[start:end], dim=0)
            l2_sum += torch.min(per_seq) / torch.sum(loss_mask[start:end])
        losses['G_l2_loss_rel'] = l2_sum.item()
        loss += l2_sum

    # NOTE: grads are zeroed *after* stepping, preserving the original order.
    loss.backward()
    optimizer_g.step()
    optimizer_g.zero_grad()
    return losses
def generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g):
    """One GAN generator update with variety loss and a zero-mask guard.

    Draws ``args.best_k`` samples, keeps the minimum per-sequence L2 error
    (skipping sequences whose loss mask is entirely zero), adds the
    generator-side discriminator loss, then backprops and steps
    ``optimizer_g``. Returns a dict of scalar losses for logging.
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end, _) = batch

    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    raw_l2_samples = []
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            raw_l2_samples.append(
                args.l2_loss_weight
                * l2_loss(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask,
                          mode='raw'))

    variety_l2 = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        stacked = torch.stack(raw_l2_samples, dim=1)  # obj x sample
        for start, end in seq_start_end.data:
            per_sample = torch.sum(stacked[start:end], dim=0)  # sample
            valid_num_frames = torch.sum(loss_mask[start:end])
            # avoid all objects have masked GT in this time window
            if valid_num_frames == 0:
                continue
            variety_l2 += torch.min(per_sample) / valid_num_frames
        losses['G_l2_loss_rel'] = variety_l2.item()
        loss += variety_l2

    full_traj = torch.cat([obs_traj, pred_traj_fake], dim=0)
    full_traj_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(full_traj, full_traj_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()
    return losses
def regressor_step(args, batch, regressor, optimizer_r):
    """One regressor update with best-of-k (variety) L2 loss.

    Samples ``args.best_k`` trajectories from ``regressor``, keeps the
    minimum per-sequence L2 error, backprops and steps ``optimizer_r``.
    Returns a dict of scalar losses for logging.

    Fix vs. original: guard against dividing by a zero loss-mask sum
    (mirrors the guarded generator variant in this file).
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, obs_team_vec,
     obs_pos_vec, pred_team_vec, pred_pos_vec, non_linear_ped, loss_mask,
     seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    r_l2_loss_rel = []
    # Only the prediction horizon contributes to the loss.
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = regressor(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        if args.l2_loss_weight > 0:
            r_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask,
                mode=args.l2_loss_mode  # default:"raw"
            ))

    r_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        r_l2_loss_rel = torch.stack(r_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _r_l2_loss_rel = r_l2_loss_rel[start:end]
            _r_l2_loss_rel = torch.sum(_r_l2_loss_rel, dim=0)
            valid_num_frames = torch.sum(loss_mask[start:end])
            # Guard: a fully-masked sequence would make this division NaN/inf.
            if valid_num_frames == 0:
                continue
            _r_l2_loss_rel = torch.min(_r_l2_loss_rel) / valid_num_frames
            r_l2_loss_sum_rel += _r_l2_loss_rel
        losses['R_l2_loss_rel'] = r_l2_loss_sum_rel.item()
        loss += r_l2_loss_sum_rel

    losses['R_total_loss'] = loss.item()

    optimizer_r.zero_grad()
    loss.backward()
    # NOTE(review): reuses the generator clipping threshold for the regressor;
    # presumably intentional — confirm against the arg parser.
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(regressor.parameters(),
                                 args.clipping_threshold_g)
    optimizer_r.step()
    return losses
def generator_step(
    args, batch, generatorSO, generatorST, discriminator, netH, g_loss_fn,
    optimizer_gso, optimizer_gst, optimizer_h
):
    """One generator update for the two-stage (SO/ST) generator with an
    MI (mutual-information) estimator head.

    Per sample: encodes observations with ``generatorSO``, draws latent noise
    via ``MALA_corrected_sampler``, decodes with ``generatorST``, accumulates
    (a) weighted variety L2 loss, (b) an MI estimate from ``netH`` trained
    against shuffled noise, and (c) the adversarial loss. The three terms are
    mixed with ``args.lamdba_l2`` (sic — keep the misspelled attribute name,
    it is the external interface). Returns a dict of scalar losses.

    Fix vs. original: ``nn.utils.clip_grad_norm_`` was called as
    ``clip_grad_norm_(genSO.parameters(), genST.parameters(), threshold)``,
    which passes a parameter iterator as ``max_norm``; the parameters of both
    generators are now concatenated into one list and clipped together.
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    g_mi_loss_rel = []
    loss_mask = loss_mask[:, args.obs_len:]

    # BCE target for the MI head: first half "joint" (1), second half "marginal" (0).
    label = torch.zeros(
        (seq_start_end[-1][-1] - seq_start_end[0][0]).item()).cuda()
    label[: int(len(label) / 2)].data.fill_(1)

    for _ in range(args.best_k):
        noise_input, noise_shape = generatorSO(obs_traj, obs_traj_rel,
                                               seq_start_end)
        z_noise = MALA_corrected_sampler(generatorST, discriminator, args,
                                         noise_shape, noise_input,
                                         seq_start_end, obs_traj, obs_traj_rel)
        decoder_h = torch.cat([noise_input, z_noise], dim=1)
        decoder_h = torch.unsqueeze(decoder_h, 0)
        generator_out = generatorST(decoder_h, seq_start_end, obs_traj,
                                    obs_traj_rel)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

        # Shuffled noise provides the "marginal" samples for the MI estimate.
        z_noise_bar = z_noise[torch.randperm(len(z_noise))]
        concat_x_pred = torch.cat([pred_traj_fake, pred_traj_fake], 0)
        concat_z_noise = torch.cat([z_noise, z_noise_bar], -1)
        mi_estimate = nn.BCEWithLogitsLoss()(
            netH(concat_x_pred.permute(1, 0, 2)
                 .reshape(len(concat_z_noise), -1).squeeze(),
                 concat_z_noise).squeeze(),
            label)
        g_mi_loss_rel.append(mi_estimate)

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += (g_l2_loss_sum_rel * (1 - args.lamdba_l2))

    g_mi_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    g_mi_loss_rel = torch.stack(g_mi_loss_rel, dim=0)
    g_mi_loss_sum_rel = torch.sum(g_mi_loss_rel)
    losses['G_mi_loss_rel'] = g_mi_loss_sum_rel.item()
    loss += (g_mi_loss_sum_rel * args.lamdba_l2)

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)
    loss += (discriminator_loss * args.lamdba_l2)
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_gso.zero_grad()
    optimizer_gst.zero_grad()
    optimizer_h.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        # Fixed call: clip both generators' parameters jointly.
        nn.utils.clip_grad_norm_(
            list(generatorSO.parameters()) + list(generatorST.parameters()),
            args.clipping_threshold_g
        )
    optimizer_gso.step()
    optimizer_gst.step()
    optimizer_h.step()
    return losses
def generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g):
    """One generator update mixing L2 and adversarial terms by weight.

    The total loss is ``2*w*L2 + 2*(1-w)*D`` with ``w = args.l2_loss_weight``.
    Returns a dict of scalar losses for logging.

    Cleanup vs. original: removed large blocks of commented-out gradient/
    attention debugging code.

    NOTE(review): ``args.l2_loss_weight`` is applied twice to the L2 term —
    once inside the per-sample accumulation and again in the final mix.
    This may be intentional weighting; confirm before changing.
    """
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += 2 * args.l2_loss_weight * g_l2_loss_sum_rel

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += 2 * (1 - args.l2_loss_weight) * discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)
    optimizer_g.step()
    return losses
def generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g):
    """One generator update with optional cosine-similarity and curvature losses.

    In addition to the standard variety L2 loss and (optionally) the
    adversarial loss, this variant can penalize low smoothness via a
    cosine-similarity term over consecutive displacements and a curvature
    term computed from circumscribed-triangle geometry.

    NOTE(review): this function contains extensive inline debugging
    (unconditional prints, ``curv_debug`` branches, NaN dumps that call
    ``sys.exit``) and NaN-scrubbing of decoder gradients after backward.
    The original source had all newlines collapsed; the indentation below is
    a best-effort reconstruction — confirm against the original repository.
    """
    # Move the batch to GPU only when configured to.
    if (args.use_gpu):
        batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    # Only the prediction horizon contributes to the L2 loss.
    loss_mask = loss_mask[:, args.obs_len:]

    # Draw best_k samples for the variety loss.
    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            # NOTE(review): no guard against a zero mask sum here — a fully
            # masked sequence would produce NaN/inf.
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    if args.cosine_loss_weight > 0:
        # Cosine similarity between consecutive displacement vectors:
        # shift the prediction by one step and compare element-wise.
        # pred_traj_fake_rel_shifted = torch.roll(pred_traj_fake_rel, 1, [1])
        pred_traj_fake_rel_shifted = torch.cat(
            (pred_traj_fake_rel[1:, :, :], pred_traj_fake_rel[:1, :, :]),
            dim=0)  # simulating rolling
        cosfunc = nn.CosineSimilarity(dim=2, eps=1e-6)
        similarity = cosfunc(pred_traj_fake_rel, pred_traj_fake_rel_shifted)
        # Drop the wrapped-around first row; reward high similarity (hence
        # the negative sign when added to the loss below).
        cosine_loss = torch.mean(similarity[1:, :])
        losses['G_cosine_loss'] = cosine_loss.item()
        print("cosine_loss", cosine_loss)
        loss += -args.cosine_loss_weight * cosine_loss

    curv_debug = False  # flip on for verbose curvature debugging
    print("loss before adding curvature", loss)
    if args.curvature_loss_weight > 0:
        print("CURVATURE: pred_traj_fake_rel.shape", pred_traj_fake_rel.shape)
        dists = torch.norm(pred_traj_fake_rel, p=2, dim=2)
        print("dists.shape", dists.shape)
        # Three consecutive absolute points (a, b, c) per step, built by
        # prepending the last observed positions to the prediction.
        aas = torch.cat((obs_traj[-2:, :, :], pred_traj_fake[:-2, :, :]),
                        dim=0)
        print("aas.dtype", aas.dtype)
        bbs = torch.cat((obs_traj[-1:, :, :], pred_traj_fake[:-1, :, :]),
                        dim=0)
        ccs = pred_traj_fake
        if (curv_debug):
            print("aas.shape", aas.shape)
            print("bbs.shape", bbs.shape)
            print("ccs.shape", ccs.shape)
            print("aas:\n", aas)
            print("bbs:\n", bbs)
            print("ccs:\n", ccs)
        # Triangle side lengths; curvature of the circumscribed circle is
        # proportional to area / (product of side lengths).
        cside = torch.norm(aas - bbs, p=2, dim=2)
        bside = torch.norm(aas - ccs, p=2, dim=2)
        aside = torch.norm(ccs - bbs, p=2, dim=2)
        # Heron's formula for the triangle area.
        s = (aside + bside + cside) / 2
        areas = torch.sqrt(s * (s - aside) * (s - bside) * (s - cside))
        if curv_debug:
            onne = torch.ones(areas.shape).cuda()
            zeero = torch.zeros(areas.shape).cuda()
            print("onne.shape", onne.shape)
            print("zeero.shape", zeero.shape)
            debug_ars = torch.where(areas < 0.00001, onne, zeero)
            # Zero near-degenerate triangles to avoid NaNs downstream.
            areas = torch.where(areas < 0.0001, zeero, areas)
            debug_ars_sum = torch.sum(debug_ars)
            print("debug_ars_sum: ", debug_ars_sum)
        curvatures = areas / aside / bside / cside
        # Replace NaNs (x != x) with zero before averaging.
        exp = (curvatures != curvatures)
        curvatures[exp] = 0
        curvature_loss = torch.mean(curvatures)
        print("curvature_loss", curvature_loss)
        losses['G_curvature_loss'] = curvature_loss.item()
        loss += args.curvature_loss_weight * curvature_loss
        # NaN forensics: dump surrounding tensors and abort.
        if (loss != loss):
            print("[id]", "is where you see NaN first")
            print("loss", loss)
            print("curvature_loss Nan", curvature_loss)
            print("\n\n==================NAN================\n\n")
            nan_loc = 0
            nan_id = 7
            print("\n===== curvatures[nan_loc,nan_id:nan_id+5]\n",
                  curvatures[nan_loc, nan_id:nan_id + 5])
            print("\n===== areas[nan_loc,nan_id:nan_id+5]\n",
                  areas[nan_loc, nan_id:nan_id + 5])
            print("\n===== aside[nan_loc,nan_id:nan_id+5]\n",
                  aside[nan_loc, nan_id:nan_id + 5])
            print("\n===== bside[nan_loc,nan_id:nan_id+5]\n",
                  bside[nan_loc, nan_id:nan_id + 5])
            print("\n===== cside[nan_loc,nan_id:nan_id+5]\n",
                  cside[nan_loc, nan_id:nan_id + 5])
            print("aas.shape", aas.shape)
            print("bbs.shape", bbs.shape)
            print("ccs.shape", ccs.shape)
            print("aas:\n", aas[nan_loc, nan_id:nan_id + 5])
            print("bbs:\n", bbs[nan_loc, nan_id:nan_id + 5])
            print("ccs:\n", ccs[nan_loc, nan_id:nan_id + 5])
            exp = (curvatures == curvatures)
            print("obs_traj", obs_traj[:, 6:11, :])
            print("pred_traj_fake", pred_traj_fake[:, 6:11, :])
            print("obs_traj.shape", obs_traj.shape)
            print("pred_traj_fake.shape", pred_traj_fake.shape)
            print("exp", exp[:, 6:11])
            sys.exit()

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)
    # Adversarial term is opt-in via args.use_discriminator.
    if args.use_discriminator:
        print("using discriminator loss")
        loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    if curv_debug:
        print("generator grad vals BEFORE: ",
              [x.grad for x in generator.decoder.parameters()])
    loss.backward(retain_graph=True)
    # Scrub NaN gradients from the decoder; abort if any survive.
    for x in generator.decoder.parameters():
        nanmask = (x.grad != x.grad)
        x.grad.data[nanmask] = 0
        if torch.isnan(x.grad).any():
            print("NaN generated during backprop")
            print(x.grad)
            sys.exit()
    if curv_debug:
        print("generator grad vals: ",
              [x.grad for x in generator.decoder.parameters()])
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)
    optimizer_g.step()
    # Post-step gradient inspection (debug only).
    if curv_debug:
        print("autograd.grad(loss, curvature_loss)",
              autograd.grad(loss, curvature_loss))
        print("autograd.grad(curvature_loss, pred_traj_fake_rel)",
              autograd.grad(curvature_loss, pred_traj_fake_rel))
    return losses
def generator_step(
    args, batch, generator, discriminator, g_loss_fn, optimizer_g
):
    """One generator update with variety L2 loss and an optional cosine term.

    Samples ``args.best_k`` trajectories, keeps the minimum per-sequence L2
    error, and optionally subtracts a cosine-similarity reward over
    consecutive displacement vectors. The discriminator loss is computed and
    logged but deliberately NOT added to the training loss (as in the
    original). Returns a dict of scalar losses.

    Fixes vs. original: removed the debug ``print``, removed stale TODO
    commentary, added a zero-mask division guard (mirrors the guarded
    sibling variant in this file), and logged the cosine loss.
    """
    if(args.use_gpu):
        batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            valid_num_frames = torch.sum(loss_mask[start:end])
            # Guard: a fully-masked sequence would make this division NaN/inf.
            if valid_num_frames == 0:
                continue
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / valid_num_frames
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    if args.cosine_loss_weight > 0:
        # Compare each displacement with the next one (shift-by-one via cat).
        pred_traj_fake_rel_shifted = torch.cat(
            (pred_traj_fake_rel[1:, :, :], pred_traj_fake_rel[:1, :, :]),
            dim=0)  # simulating rolling
        cosfunc = nn.CosineSimilarity(dim=2, eps=1e-6)
        similarity = cosfunc(pred_traj_fake_rel, pred_traj_fake_rel_shifted)
        # Drop the wrapped-around first row; reward similarity, hence the
        # negative sign below.
        cosine_loss = torch.sum(similarity[1:, :])
        losses['G_cosine_loss'] = cosine_loss.item()
        loss += -args.cosine_loss_weight * cosine_loss

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    # Intentionally excluded from the training loss; logged only.
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()
    return losses