Example #1
def cal_l2_losses(pred_traj_gt, pred_traj_gt_rel, pred_traj_fake,
                  pred_traj_fake_rel, loss_mask):
    g_l2_loss_abs = l2_loss(pred_traj_fake,
                            pred_traj_gt,
                            loss_mask,
                            mode='sum')
    g_l2_loss_rel = l2_loss(pred_traj_fake_rel,
                            pred_traj_gt_rel,
                            loss_mask,
                            mode='sum')
    return g_l2_loss_abs, g_l2_loss_rel
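
All of these examples rely on the same small set of SGAN-style helpers. Below is a minimal sketch of the l2_loss helper as the calls above assume it, with trajectories shaped (seq_len, batch, 2) and loss_mask shaped (batch, seq_len); the exact implementation in each project may differ slightly. mode='sum' and mode='average' reduce over the whole batch, while mode='raw' keeps one value per trajectory so the caller can later take a minimum over sampled predictions.

import torch

def l2_loss(pred_traj, pred_traj_gt, loss_mask, mode='average'):
    # Sketch of the assumed helper, not the verbatim project code.
    # pred_traj, pred_traj_gt: (seq_len, batch, 2); loss_mask: (batch, seq_len)
    seq_len, batch, _ = pred_traj.size()
    loss = (loss_mask.unsqueeze(dim=2) *
            (pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2)) ** 2)
    if mode == 'sum':
        return torch.sum(loss)
    elif mode == 'average':
        return torch.sum(loss) / torch.numel(loss_mask.data)
    elif mode == 'raw':
        # per-trajectory loss, summed over coordinates and time: shape (batch,)
        return loss.sum(dim=2).sum(dim=1)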
Example #2
def generator_step(args, batch, generator, discriminator, g_loss_fn,
                   optimizer_g):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)

    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)

        print("g_l2_loss_rel : ", g_l2_loss_rel.shape)
        for start, end in seq_start_end.data:
            #             print("loss mask: ", loss_mask)
            print("seq_start_end: ", seq_start_end)
            print("start: ", start, "end: ", end)
            #             print("loss mask[start:end]: ", loss_mask[start:end])
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            print("_g_l2_loss_rel __1__ : ", _g_l2_loss_rel.shape)
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel

        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)
    optimizer_g.step()

    return losses
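
relative_to_abs converts the predicted per-step displacements back into absolute coordinates by cumulatively summing them from the last observed position, obs_traj[-1]. A minimal sketch of the helper these examples appear to assume, again using the (seq_len, batch, 2) layout:

import torch

def relative_to_abs(rel_traj, start_pos):
    # Sketch of the assumed helper.
    # rel_traj: (seq_len, batch, 2) displacements; start_pos: (batch, 2)
    rel_traj = rel_traj.permute(1, 0, 2)          # (batch, seq_len, 2)
    displacement = torch.cumsum(rel_traj, dim=1)  # running sum of displacements
    start_pos = torch.unsqueeze(start_pos, dim=1)
    abs_traj = displacement + start_pos
    return abs_traj.permute(1, 0, 2)              # back to (seq_len, batch, 2)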
Example #3
def generator_step(args, batch, generator, optimizer_g, epoch):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end,
                                   epoch)

    if args.l2_loss_weight > 0:
        g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
            pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:

        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    loss.backward()
    optimizer_g.step()
    optimizer_g.zero_grad()

    return losses
Example #4
def generator_step(args, batch, generator, discriminator, g_loss_fn, optimizer_g):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask, seq_start_end, _) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]           # obj x sample
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)   # sample
            valid_num_frames = torch.sum(loss_mask[start:end])
            
            # skip the case where every object's GT is masked in this time window
            if valid_num_frames == 0:
                continue
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / valid_num_frames
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()

    return losses
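
The g_loss_fn argument is the adversarial half of the generator objective: it rewards the generator when the discriminator scores the fake trajectories as real. A minimal sketch, assuming a logits-producing discriminator and SGAN-style soft labels; the exact labels and reduction each project uses may differ.

import random
import torch

def bce_loss(input, target):
    # numerically stable binary cross entropy on raw logits
    neg_abs = -input.abs()
    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
    return loss.mean()

def gan_g_loss(scores_fake):
    # soft "real" labels in [0.7, 1.2] act as one-sided label smoothing
    y_fake = torch.ones_like(scores_fake) * random.uniform(0.7, 1.2)
    return bce_loss(scores_fake, y_fake)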
Example #5
def regressor_step(args, batch, regressor, optimizer_r):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, obs_team_vec,
     obs_pos_vec, pred_team_vec, pred_pos_vec, non_linear_ped, loss_mask,
     seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    r_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = regressor(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out

        if args.l2_loss_weight > 0:
            r_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel,
                pred_traj_gt_rel,
                loss_mask,
                mode=args.l2_loss_mode  # default:"raw"
            ))

    r_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        r_l2_loss_rel = torch.stack(r_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _r_l2_loss_rel = r_l2_loss_rel[start:end]
            _r_l2_loss_rel = torch.sum(_r_l2_loss_rel, dim=0)
            _r_l2_loss_rel = torch.min(_r_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            r_l2_loss_sum_rel += _r_l2_loss_rel
        losses['R_l2_loss_rel'] = r_l2_loss_sum_rel.item()
        loss += r_l2_loss_sum_rel

    losses['R_total_loss'] = loss.item()

    optimizer_r.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(regressor.parameters(),
                                 args.clipping_threshold_g)
    optimizer_r.step()

    return losses
Example #6
File: train.py  Project: ACoTAI/CODE
def generator_step(
    args, batch, generatorSO, generatorST, discriminator, netH, g_loss_fn, optimizer_gso, optimizer_gst, optimizer_h
):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    g_mi_loss_rel = []
    loss_mask = loss_mask[:, args.obs_len:]
    label = torch.zeros((seq_start_end[-1][-1]-seq_start_end[0][0]).item()).cuda()
    label[: int(len(label)/2)].data.fill_(1)

    for _ in range(args.best_k):
        noise_input, noise_shape = generatorSO(obs_traj, obs_traj_rel, seq_start_end)
        z_noise = MALA_corrected_sampler(generatorST, discriminator, args, noise_shape, noise_input, seq_start_end, obs_traj, obs_traj_rel)
        decoder_h = torch.cat([noise_input, z_noise], dim=1)
        decoder_h = torch.unsqueeze(decoder_h, 0)
        generator_out = generatorST(decoder_h, seq_start_end, obs_traj, obs_traj_rel)
        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel,
                pred_traj_gt_rel,
                loss_mask,
                mode='raw'))
        z_noise_bar = z_noise[torch.randperm(len(z_noise))]
        concat_x_pred = torch.cat([pred_traj_fake, pred_traj_fake], 0)
        concat_z_noise = torch.cat([z_noise, z_noise_bar], -1)
        mi_input = concat_x_pred.permute(1, 0, 2).reshape(len(concat_z_noise), -1).squeeze()
        mi_estimate = nn.BCEWithLogitsLoss()(
            netH(mi_input, concat_z_noise).squeeze(), label)
        g_mi_loss_rel.append(mi_estimate)

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += (g_l2_loss_sum_rel * (1 - args.lamdba_l2))
    g_mi_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    g_mi_loss_rel = torch.stack(g_mi_loss_rel, dim=0)
    g_mi_loss_sum_rel = torch.sum(g_mi_loss_rel)
    losses['G_mi_loss_rel'] = g_mi_loss_sum_rel.item()
    loss += (g_mi_loss_sum_rel * args.lamdba_l2)
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)
    loss += (discriminator_loss * args.lamdba_l2)
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()
    optimizer_gso.zero_grad()
    optimizer_gst.zero_grad()
    optimizer_h.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        # clip the gradients of both generators jointly under a single norm threshold
        nn.utils.clip_grad_norm_(
            list(generatorSO.parameters()) + list(generatorST.parameters()),
            args.clipping_threshold_g
        )
    optimizer_gso.step()
    optimizer_gst.step()
    optimizer_h.step()
    return losses
Example #7
def generator_step(args, batch, generator, discriminator, g_loss_fn,
                   optimizer_g):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += 2 * args.l2_loss_weight * g_l2_loss_sum_rel

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += 2 * (1 - args.l2_loss_weight) * discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)
    optimizer_g.step()

    return losses
Example #8
File: train.py  Project: chetu181/sgan
def generator_step(args, batch, generator, discriminator, g_loss_fn,
                   optimizer_g):
    if args.use_gpu:
        batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    if args.cosine_loss_weight > 0:  # TODO: what to do when the curvature loss also needs to be included?
        # Need to obtain the curvature for all data points in the batch and save their average
        # to the curvature_loss tensor, given pred_traj_fake_rel and the last two observed
        # inputs (maybe not needed for the first iteration); pred_traj_gt_rel is not needed.
        # print("COSINE: pred_traj_fake_rel.shape", pred_traj_fake_rel.shape)
        # print("pred_traj_fake_rel", pred_traj_fake_rel) # just use this tensor for now, then you can add other stuff later.
        #pred_traj_fake_rel_shifted = torch.roll(pred_traj_fake_rel, 1, [1])
        pred_traj_fake_rel_shifted = torch.cat(
            (pred_traj_fake_rel[1:, :, :], pred_traj_fake_rel[:1, :, :]),
            dim=0)  # simulating rolling
        # print("pred_traj_fake_rel_shifted", pred_traj_fake_rel_shifted)
        cosfunc = nn.CosineSimilarity(dim=2, eps=1e-6)
        similarity = cosfunc(pred_traj_fake_rel, pred_traj_fake_rel_shifted)
        # print("similarity", similarity)
        cosine_loss = torch.mean(similarity[1:, :])
        losses['G_cosine_loss'] = cosine_loss.item()
        print("cosine_loss", cosine_loss)
        loss += -args.cosine_loss_weight * cosine_loss
        # sys.exit()
    curv_debug = False
    print("loss before adding curvature", loss)
    if args.curvature_loss_weight > 0:
        print("CURVATURE: pred_traj_fake_rel.shape", pred_traj_fake_rel.shape)
        # print("pred_traj_fake_rel", pred_traj_fake_rel) # just use this tensor for now, then you can add other stuff later.
        dists = torch.norm(pred_traj_fake_rel, p=2, dim=2)
        # print("dists", dists)
        print("dists.shape", dists.shape)

        # get three sets of points aas, bbs and ccs corresponding to three consecutive points
        aas = torch.cat((obs_traj[-2:, :, :], pred_traj_fake[:-2, :, :]),
                        dim=0)
        print("aas.dtype", aas.dtype)
        bbs = torch.cat((obs_traj[-1:, :, :], pred_traj_fake[:-1, :, :]),
                        dim=0)
        ccs = pred_traj_fake
        # print("obs_traj[-3:,:3,:]", obs_traj[-3:,:3,:])
        if (curv_debug):
            print("aas.shape", aas.shape)
            print("bbs.shape", bbs.shape)
            print("ccs.shape", ccs.shape)
            print("aas:\n", aas)
            print("bbs:\n", bbs)
            print("ccs:\n", ccs)
        cside = torch.norm(aas - bbs, p=2, dim=2)
        bside = torch.norm(aas - ccs, p=2, dim=2)
        aside = torch.norm(ccs - bbs, p=2, dim=2)
        # print("cside.shape", cside.shape)
        s = (aside + bside + cside) / 2
        areas = torch.sqrt(s * (s - aside) * (s - bside) * (s - cside))

        if curv_debug:
            onne = torch.ones(areas.shape).cuda()
            zeero = torch.zeros(areas.shape).cuda()
            print("onne.shape", onne.shape)
            print("zeero.shape", zeero.shape)
            debug_ars = torch.where(areas < 0.00001, onne, zeero)
            areas = torch.where(areas < 0.0001, zeero, areas)
            debug_ars_sum = torch.sum(debug_ars)
            print("debug_ars_sum: ", debug_ars_sum)
        # print("areas : ", areas)
        curvatures = areas / aside / bside / cside
        exp = (curvatures != curvatures)
        curvatures[exp] = 0
        # curvatures = curvatures * exp
        # print("curvatures.shape", curvatures.shape)
        # print("aside", aside)
        # print("bside", bside)
        # print("cside", cside)

        # print("areas", areas)
        # print("curvatures", curvatures)
        # curvatures[curvatures != curvatures] = 0
        curvature_loss = torch.mean(curvatures)
        # curvature_loss[curvature_loss != curvature_loss] = 0  # setting NaN as zero
        print("curvature_loss", curvature_loss)
        losses['G_curvature_loss'] = curvature_loss.item()
        loss += args.curvature_loss_weight * curvature_loss
        # sys.exit()

    if (loss != loss):
        print("[id]", "is where you see NaN first")
        print("loss", loss)
        print("curvature_loss Nan", curvature_loss)
        print("\n\n==================NAN================\n\n")
        nan_loc = 0
        nan_id = 7
        print("\n=====  curvatures[nan_loc,nan_id:nan_id+5]\n",
              curvatures[nan_loc, nan_id:nan_id + 5])
        print("\n=====  areas[nan_loc,nan_id:nan_id+5]\n",
              areas[nan_loc, nan_id:nan_id + 5])
        print("\n=====  aside[nan_loc,nan_id:nan_id+5]\n",
              aside[nan_loc, nan_id:nan_id + 5])
        print("\n=====  bside[nan_loc,nan_id:nan_id+5]\n",
              bside[nan_loc, nan_id:nan_id + 5])
        print("\n=====  cside[nan_loc,nan_id:nan_id+5]\n",
              cside[nan_loc, nan_id:nan_id + 5])

        print("aas.shape", aas.shape)
        print("bbs.shape", bbs.shape)
        print("ccs.shape", ccs.shape)
        print("aas:\n", aas[nan_loc, nan_id:nan_id + 5])
        print("bbs:\n", bbs[nan_loc, nan_id:nan_id + 5])
        print("ccs:\n", ccs[nan_loc, nan_id:nan_id + 5])

        exp = (curvatures == curvatures)
        print("obs_traj", obs_traj[:, 6:11, :])
        print("pred_traj_fake", pred_traj_fake[:, 6:11, :])
        print("obs_traj.shape", obs_traj.shape)
        print("pred_traj_fake.shape", pred_traj_fake.shape)
        print("exp", exp[:, 6:11])
        # print ("exp", exp[nan_loc,nan_id :nan_id+5])
        # print("prevgrads", prevgrads)
        # print("generator weights: ", [x.data for x in generator.decoder.parameters()])
        # print("generator grad vals: ", [x.grad for x in generator.decoder.parameters()])
        sys.exit()

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    if args.use_discriminator:
        print("using discriminator loss")
        loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    if curv_debug:
        print("generator grad vals BEFORE: ",
              [x.grad for x in generator.decoder.parameters()])
    loss.backward(retain_graph=True)

    for x in generator.decoder.parameters():
        nanmask = (x.grad != x.grad)
        # print(nanmask)
        x.grad.data[nanmask] = 0
        if torch.isnan(x.grad).any():
            print("NaN generated during backprop")
            print(x.grad)
            sys.exit()

    if curv_debug:
        print("generator grad vals: ",
              [x.grad for x in generator.decoder.parameters()])
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(generator.parameters(),
                                 args.clipping_threshold_g)

    optimizer_g.step()
    #print some grads here to see what's going wrong:
    if curv_debug:
        print("autograd.grad(loss, curvature_loss)",
              autograd.grad(loss, curvature_loss))
        # print("autograd.grad(loss, areas)", autograd.grad(loss, areas))
        # print("autograd.grad(loss, aside)", autograd.grad(loss, aside))
        # print("autograd.grad(curvature_loss, ccs)", autograd.grad(curvature_loss, ccs))
        print("autograd.grad(curvature_loss, pred_traj_fake_rel)",
              autograd.grad(curvature_loss, pred_traj_fake_rel))

    # sys.exit()
    return losses
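
The curvature term in Example #8 is, up to a constant factor, the Menger curvature of three consecutive points a, b, c: kappa = 4 * Area(a, b, c) / (|ab| * |bc| * |ca|), with the triangle area obtained from Heron's formula; the code above drops the factor of 4 and zeroes out the NaNs produced by degenerate (collinear or coincident) triples. A small self-contained sketch of the same computation for batched 2-D points, using an epsilon instead of NaN masking (the function name and epsilon handling are illustrative, not taken from the project):

import torch

def menger_curvature(a, b, c, eps=1e-8):
    # a, b, c: (..., 2) tensors of three consecutive trajectory points
    ab = torch.norm(a - b, p=2, dim=-1)
    bc = torch.norm(b - c, p=2, dim=-1)
    ca = torch.norm(c - a, p=2, dim=-1)
    s = (ab + bc + ca) / 2                               # semi-perimeter
    area = torch.sqrt((s * (s - ab) * (s - bc) * (s - ca)).clamp(min=0))
    # curvature = 1 / circumradius = 4 * area / (|ab| * |bc| * |ca|)
    return 4 * area / (ab * bc * ca + eps)

Clamping Heron's product at zero and adding eps to the denominator sidesteps the NaNs that Example #8 has to mask out after the fact.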
Example #9
def generator_step(
    args, batch, generator, discriminator, g_loss_fn, optimizer_g
):
    if args.use_gpu:
        batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel,
                pred_traj_gt_rel,
                loss_mask,
                mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    if args.cosine_loss_weight > 0:  # TODO: what to do when the curvature loss also needs to be included?
        # Need to obtain the curvature for all data points in the batch and save their average
        # to the curvature_loss tensor, given pred_traj_fake_rel and the last two observed
        # inputs (maybe not needed for the first iteration); pred_traj_gt_rel is not needed.
        # print("COSINE: pred_traj_fake_rel.shape", pred_traj_fake_rel.shape)
        # print("pred_traj_fake_rel", pred_traj_fake_rel) # just use this tensor for now, then you can add other stuff later.
        # pred_traj_fake_rel_shifted = torch.roll(pred_traj_fake_rel, 1, [1])
        pred_traj_fake_rel_shifted = torch.cat( (pred_traj_fake_rel[1:, :, :],  pred_traj_fake_rel[:1, :, :]), dim=0) # simulating rolling
        # print("pred_traj_fake_rel_shifted", pred_traj_fake_rel_shifted)
        cosfunc = nn.CosineSimilarity(dim=2, eps=1e-6)
        similarity = cosfunc(pred_traj_fake_rel, pred_traj_fake_rel_shifted)
        # print("similarity", similarity)
        cosine_loss = torch.sum(similarity[1:,:])
        print("cosine_loss", cosine_loss)
        loss += -args.cosine_loss_weight * cosine_loss
        # sys.exit()

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    #loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()

    return losses
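
For context, these step functions are normally driven by a training loop that alternates discriminator and generator updates over the data loader. The sketch below is hypothetical: the discriminator_step helper and the argument names are assumptions, and each project schedules the two updates slightly differently.

def train_epoch(args, train_loader, generator, discriminator,
                g_loss_fn, d_loss_fn, optimizer_g, optimizer_d):
    generator.train()
    discriminator.train()
    epoch_losses = []
    for batch in train_loader:
        # alternate the two adversarial updates on every batch
        d_losses = discriminator_step(args, batch, generator, discriminator,
                                      d_loss_fn, optimizer_d)
        g_losses = generator_step(args, batch, generator, discriminator,
                                  g_loss_fn, optimizer_g)
        epoch_losses.append({**d_losses, **g_losses})
    return epoch_losses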