# Example #1
# 0
def transform_batch(src, trg):
    """Build teacher-forcing decoder inputs and attention masks for one batch.

    A trailing flag channel (0 for real steps) is appended to ``trg``. A
    start-of-sequence step whose flag channel is 1 is prepended, and the last
    target step is dropped, so the decoder input keeps ``trg``'s length.

    Args:
        src: source tensor indexed as ``(batch, src_len, ...)``.
        trg: rank-3 target tensor ``(batch, trg_len, dim)``.

    Returns:
        Tuple ``(src, src_mask, trg_in, trg_mask, trg_y)``:

        * ``src`` is returned unchanged.
        * ``src_mask`` is all ones, shape ``(batch, 1, src_len)``, on ``src``'s device.
        * ``trg_in`` is the shifted decoder input, shape ``(batch, trg_len, dim + 1)``.
        * ``trg_mask`` is ``subsequent_mask(trg_len)`` repeated over the batch,
          on ``trg``'s device.
        * ``trg_y`` is an unmodified copy of the original ``trg``.
    """
    trg_y = trg.clone()
    batch_size, trg_len = trg.shape[0], trg.shape[1]

    # Allocate helpers on the same device as the data so torch.cat works for
    # GPU batches too (default dtype kept as before).
    flag = torch.zeros((batch_size, trg_len, 1), device=trg.device)
    trg = torch.cat((trg, flag), 2)

    start_seq = torch.zeros((batch_size, 1, trg.shape[-1]), device=trg.device)
    start_seq[:, :, -1] = 1  # flag channel marks the start-of-sequence step
    trg = torch.cat((start_seq, trg[:, :-1, :]), 1)

    src_mask = torch.ones((src.shape[0], 1, src.shape[1]), device=src.device)
    trg_mask = subsequent_mask(trg.shape[1]).repeat((batch_size, 1, 1)).to(trg.device)

    return src, src_mask, trg, trg_mask, trg_y
# Example #2
# 0
    def forward(self, src, noise):
        """Decode a target trajectory from ``src`` in one decoder pass.

        ``noise`` is fed directly as the decoder input under a causal mask,
        and the decoder output is projected through ``self.generator.generator``.

        Args:
            src: source trajectory, ``(batch, self.src_len, dimensionality)``.
            noise: decoder input sequence; its length sets the number of
                output positions.

        Returns:
            The projected decoder output for every position of ``noise``.
        """
        n = src.shape[0]
        full_src_mask = torch.ones((n, 1, self.src_len)).to(self.device)
        causal_mask = subsequent_mask(noise.shape[1]).repeat(n, 1, 1).to(self.device)
        decoded = self.generator(src, noise, full_src_mask, causal_mask)
        return self.generator.generator(decoded)
# Example #3
# 0
    def forward(self, src, noise):
        """Autoregressively decode ``self.tgt_len`` steps from ``src``.

        ``noise`` seeds the decoder input; its first step is treated as the
        start-of-sequence token and is stripped from the result. At every
        step the whole prefix is re-decoded under a causal mask, and only the
        newest output step is appended.

        Args:
            src: source trajectory, ``(batch, self.src_len, dimensionality)``.
            noise: initial decoder input sequence.

        Returns:
            The decoded sequence without its first (start) step.
        """
        n = src.shape[0]
        full_src_mask = torch.ones((n, 1, self.src_len)).to(self.device)
        seq = noise

        for _ in range(self.tgt_len):
            causal_mask = subsequent_mask(seq.shape[1]).repeat(n, 1, 1).to(self.device)
            projected = self.generator.generator(
                self.generator(src, seq, full_src_mask, causal_mask))
            newest = projected[:, -1:, :]
            seq = torch.cat((seq, newest), 1)

        return seq[:, 1:, :]
def trajectory(inp, model_path='models/Individual/my_data_train/00099.pth',
               n_preds=12, device=None):
    """Greedily predict future steps for a batch of sequences with IndividualTF.

    Loads an ``individual_TF.IndividualTF(2, 3, 3, N=6, d_model=512, ...)``
    checkpoint and decodes ``n_preds`` steps, feeding each predicted step back
    into the decoder after the start token ``[0, 0, 1]``.

    Args:
        inp: array-like convertible to ``float32`` and indexed as
            ``(batch, seq_len, features)``. Presumably ``features == 2``.
            NOTE(review): no mean/std normalization is applied here, although
            the training loop feeds ``(speed - mean) / std``; confirm that
            callers pass already-normalized data.
        model_path: state-dict checkpoint to load.
        n_preds: number of future steps to decode.
        device: torch device (or device string) to run on; defaults to CPU.

    Returns:
        ``numpy.ndarray`` of shape ``(batch, n_preds, 2)`` holding the first
        two output channels of every decoded step. These are raw model
        outputs: they are not de-normalized and not accumulated into positions.
    """
    import individual_TF

    if device is None:
        device = torch.device("cpu")

    model = individual_TF.IndividualTF(2, 3, 3, N=6, d_model=512, d_ff=2048,
                                       h=8, dropout=0.1,
                                       mean=[0, 0], std=[0, 0]).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    src = torch.from_numpy(np.array(inp, dtype=np.float32)).to(device)
    n_batch = src.shape[0]
    src_att = torch.ones((n_batch, 1, src.shape[1])).to(device)
    dec_inp = torch.Tensor([0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(
        n_batch, 1, 1).to(device)

    # Inference only: do not build an autograd graph across the decode loop.
    with torch.no_grad():
        for _ in range(n_preds):
            trg_att = subsequent_mask(dec_inp.shape[1]).repeat(
                n_batch, 1, 1).to(device)
            out = model(src, dec_inp, src_att, trg_att)
            dec_inp = torch.cat((dec_inp, out[:, -1:, :]), 1)

    # Drop the start token and keep the first two output channels.
    return dec_inp[:, 1:, 0:2].cpu().numpy()
def main():
    """Evaluate a trained QuantizedTF checkpoint on the test split.

    Observed speeds (``src[:, 1:, 2:4]``) are quantized to the index of the
    nearest centroid in ``<dataset_folder>/<dataset_name>/clusters.mat`` and
    fed to the model as tokens. Token ``len(centroids)`` is the
    start-of-sequence symbol. Two evaluations are run:

    * deterministic: greedy (argmax) decoding; MAD/FAD are printed and the
      results are saved to ``output/QuantizedTF/<name>/MM_deterministic.mat``;
    * multimodal: ``--num_samples`` decodings drawn with ``torch.multinomial``
      from ``model.predict``; the best-of-N MAD/FAD are printed and all
      samples are saved to ``output/QuantizedTF/<name>/MM_<num_samples>.mat``.

    Predicted tokens are mapped back to centroid speeds, cumulatively summed,
    and offset by the last observed position to obtain positions.
    """
    parser = argparse.ArgumentParser(description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=512)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")
    parser.add_argument('--epoch', type=str, default="00001")
    parser.add_argument('--num_samples', type=int, default=20)
    args = parser.parse_args()

    for path in ('models', 'output', 'output/QuantizedTF', 'models/QuantizedTF',
                 f'output/QuantizedTF/{args.name}', f'models/QuantizedTF/{args.name}'):
        os.makedirs(path, exist_ok=True)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    # NOTE(review): this unconditionally overrides the --verbose flag.
    args.verbose = True

    test_dataset, _ = baselineUtils.create_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        delim=args.delim, train=False, eval=True, verbose=args.verbose)
    dt_names = test_dataset.data['dataset_name']

    mat = scipy.io.loadmat(os.path.join(args.dataset_folder, args.dataset_name, "clusters.mat"))
    clusters = mat['centroids']
    n_clusters = clusters.shape[0]

    model = quantized_TF.QuantizedTF(n_clusters, n_clusters + 1, n_clusters, N=args.layers,
                                     d_model=args.emb_size, d_ff=1024, h=args.heads).to(device)
    # map_location lets a GPU-saved checkpoint load when running on CPU.
    model.load_state_dict(torch.load(f'models/QuantizedTF/{args.name}/{args.epoch}.pth',
                                     map_location=device))

    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                          shuffle=False, num_workers=0)

    def encode_batch(batch):
        """Quantize a batch's observed speeds; return ``(tokens, src_att, start_of_seq)``."""
        n_in_batch = batch['src'].shape[0]
        speeds_inp = batch['src'][:, 1:, 2:4]
        nearest = scipy.spatial.distance.cdist(speeds_inp.reshape(-1, 2), clusters).argmin(axis=1)
        tokens = torch.tensor(nearest.reshape(n_in_batch, -1)).to(device)
        src_att = torch.ones((tokens.shape[0], 1, tokens.shape[1])).to(device)
        start_of_seq = torch.tensor([n_clusters]).repeat(n_in_batch).unsqueeze(1).to(device)
        return tokens, src_att, start_of_seq

    def decode(tokens, src_att, start_of_seq, next_token):
        """Run ``args.preds`` autoregressive steps; ``next_token`` chooses each token."""
        dec_inp = start_of_seq
        for _ in range(args.preds):
            trg_att = subsequent_mask(dec_inp.shape[1]).repeat(tokens.shape[0], 1, 1).to(device)
            dec_inp = torch.cat((dec_inp, next_token(tokens, dec_inp, src_att, trg_att)), 1)
        return dec_inp

    def greedy(tokens, dec_inp, src_att, trg_att):
        """Most likely next token, shape ``(batch, 1)``."""
        return model(tokens, dec_inp, src_att, trg_att)[:, -1:].argmax(dim=2)

    def sample(tokens, dec_inp, src_att, trg_att):
        """Next token sampled from ``model.predict``'s last-step weights, ``(batch, 1)``."""
        return torch.multinomial(model.predict(tokens, dec_inp, src_att, trg_att)[:, -1], 1)

    def to_positions(dec_inp, batch):
        """Map tokens (start token dropped) to speeds and integrate from the last observed position."""
        return (clusters[dec_inp[:, 1:].cpu().numpy()].cumsum(1)
                + batch['src'][:, -1:, 0:2].cpu().numpy())

    with torch.no_grad():
        model.eval()

        # ---- Deterministic (greedy) decoding ----
        gt, pr, inp_, peds, frames, dt = [], [], [], [], [], []
        for id_b, batch in enumerate(test_dl):
            print(f"batch {id_b:03d}/{len(test_dl)}")
            peds.append(batch['peds'])
            frames.append(batch['frames'])
            dt.append(batch['dataset'])
            inp_.append(batch['src'][:, :, 0:2])
            gt.append(batch['trg'][:, :, 0:2])
            dec_inp = decode(*encode_batch(batch), greedy)
            pr.append(to_positions(dec_inp, batch))

        peds = np.concatenate(peds, 0)
        frames = np.concatenate(frames, 0)
        dt = np.concatenate(dt, 0)
        gt = np.concatenate(gt, 0)
        pr = np.concatenate(pr, 0)
        # Observed positions of every sample, as host-side arrays (savemat
        # cannot serialize a CUDA tensor).
        inp = np.concatenate(inp_, 0)
        mad, fad, _ = baselineUtils.distance_metrics(gt, pr)

        scipy.io.savemat(f"output/QuantizedTF/{args.name}/MM_deterministic.mat",
                         {'input': inp, 'gt': gt, 'pr': pr, 'peds': peds,
                          'frames': frames, 'dt': dt, 'dt_names': dt_names})

        print("Determinitic:")
        print("mad: %6.3f" % mad)
        print("fad: %6.3f" % fad)

        # ---- Multimodal (sampled) decoding ----
        num_samples = args.num_samples
        gt, inp_, peds, frames, dt = [], [], [], [], []
        pr_all = [[] for _ in range(num_samples)]
        for id_b, batch in enumerate(test_dl):
            print(f"batch {id_b:03d}/{len(test_dl)}")
            peds.append(batch['peds'])
            frames.append(batch['frames'])
            dt.append(batch['dataset'])
            gt.append(batch['trg'][:, :, 0:2])
            inp_.append(batch['src'][:, :, 0:2])
            encoded = encode_batch(batch)
            for sam in range(num_samples):
                dec_inp = decode(*encoded, sample)
                pr_all[sam].append(to_positions(dec_inp, batch))

        peds = np.concatenate(peds, 0)
        frames = np.concatenate(frames, 0)
        dt = np.concatenate(dt, 0)
        gt = np.concatenate(gt, 0)
        inp = np.concatenate(inp_, 0)
        samples_pr = [np.concatenate(p, 0) for p in pr_all]

        # Per-sample error arrays stacked on a trailing axis; best-of-N takes
        # the minimum over that axis before averaging.
        e20 = np.stack([baselineUtils.distance_metrics(gt, p)[2] for p in samples_pr], -1)
        mad_samp = e20.mean(1).min(-1).mean()
        fad_samp = e20[:, -1].min(-1).mean()

        preds_all_fin = np.stack(samples_pr, -1)
        scipy.io.savemat(f"output/QuantizedTF/{args.name}/MM_{num_samples}.mat",
                         {'input': inp, 'gt': gt, 'pr': preds_all_fin, 'peds': peds,
                          'frames': frames, 'dt': dt, 'dt_names': dt_names})

        print("Determinitic:")
        print("mad: %6.3f" % mad)
        print("fad: %6.3f" % fad)

        print("Multimodality:")
        print("mad: %6.3f" % mad_samp)
        print("fad: %6.3f" % fad_samp)
def main():
    """Train an IndividualTF on 2-D speeds, validating (and optionally testing) each epoch.

    Inputs are the speed channels ``src[:, 1:, 2:4]``. They are normalized
    with per-dataset mean/std averaged across datasets, and the statistics are
    saved to ``models/Individual/<name>/norm.mat``. The decoder is trained
    with teacher forcing on ``[x, y, flag]`` steps preceded by the start token
    ``[0, 0, 1]``.

    After every epoch the model is greedily decoded on the validation split.
    With ``--evaluate`` it is also decoded on the test split, and those
    predictions are saved to ``output/Individual/<name>/det_<epoch>.mat``.
    MAD/FAD are logged to TensorBoard. A checkpoint is written every
    ``--save_step`` epochs.

    ``--resume_train`` and ``--validation_epoch_start`` are parsed but not
    read by this function.
    """
    def str2bool(value):
        # argparse's type=bool treats every non-empty string (even "False") as True.
        return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=512)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--val_size', type=int, default=0)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=1500)
    parser.add_argument('--batch_size', type=int, default=70)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")
    parser.add_argument('--factor', type=float, default=1.)
    parser.add_argument('--save_step', type=int, default=1)
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--evaluate', type=str2bool, default=True)
    args = parser.parse_args()

    for path in ('models', 'output', 'output/Individual', 'models/Individual',
                 f'output/Individual/{args.name}', f'models/Individual/{args.name}'):
        os.makedirs(path, exist_ok=True)

    log = SummaryWriter('logs/Ind_%s' % args.name)
    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    # NOTE(review): this unconditionally overrides the --verbose flag.
    args.verbose = True

    # Train/validation datasets: a separate split when val_size == 0,
    # otherwise carved out of the training data.
    if args.val_size == 0:
        train_dataset, _ = baselineUtils.create_dataset(
            args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
            delim=args.delim, train=True, verbose=args.verbose)
        val_dataset, _ = baselineUtils.create_dataset(
            args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
            delim=args.delim, train=False, verbose=args.verbose)
    else:
        train_dataset, val_dataset = baselineUtils.create_dataset(
            args.dataset_folder, args.dataset_name, args.val_size, args.obs,
            args.preds, delim=args.delim, train=True, verbose=args.verbose)

    test_dataset, _ = baselineUtils.create_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        delim=args.delim, train=False, eval=True, verbose=args.verbose)

    import individual_TF
    model = individual_TF.IndividualTF(2, 3, 3, N=args.layers, d_model=args.emb_size,
                                       d_ff=2048, h=args.heads, dropout=args.dropout,
                                       mean=[0, 0], std=[0, 0]).to(device)

    tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                          shuffle=False, num_workers=0)

    # Adam wrapped in NoamOpt; warmup is expressed in epochs (len(tr_dl) steps each).
    optim = NoamOpt(
        args.emb_size, args.factor, len(tr_dl) * args.warmup,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    epoch = 0

    # Normalization statistics: mean/std of the speed channels (src steps 1..
    # and all trg steps) per dataset id, then averaged across datasets.
    train_all = train_dataset[:]
    means, stds = [], []
    for ds_id in np.unique(train_all['dataset']):
        ind = train_all['dataset'] == ds_id
        speeds = torch.cat((train_all['src'][ind, 1:, 2:4],
                            train_all['trg'][ind, :, 2:4]), 1)
        means.append(speeds.mean((0, 1)))
        stds.append(speeds.std((0, 1)))
    mean = torch.stack(means).mean(0)
    std = torch.stack(stds).mean(0)

    scipy.io.savemat(f'models/Individual/{args.name}/norm.mat', {
        'mean': mean.cpu().numpy(),
        'std': std.cpu().numpy()
    })
    mean_d, std_d = mean.to(device), std.to(device)

    def start_token(n):
        """Start-of-sequence decoder step ``[0, 0, 1]`` repeated for ``n`` samples."""
        return torch.Tensor([0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(n, 1, 1).to(device)

    def predict_split(dl, tag):
        """Greedily decode every batch of ``dl``; return stacked inputs/gt/predictions/metadata.

        Predicted normalized speeds are de-normalized, cumulatively summed and
        offset by the last observed position, giving positions comparable to
        ``trg[:, :, 0:2]``.
        """
        gt, pr, inp_, peds, frames, dt = [], [], [], [], [], []
        for id_b, batch in enumerate(dl):
            inp_.append(batch['src'])
            gt.append(batch['trg'][:, :, 0:2])
            frames.append(batch['frames'])
            peds.append(batch['peds'])
            dt.append(batch['dataset'])

            inp = (batch['src'][:, 1:, 2:4].to(device) - mean_d) / std_d
            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            dec_inp = start_token(inp.shape[0])
            for _ in range(args.preds):
                trg_att = subsequent_mask(dec_inp.shape[1]).repeat(
                    dec_inp.shape[0], 1, 1).to(device)
                out = model(inp, dec_inp, src_att, trg_att)
                dec_inp = torch.cat((dec_inp, out[:, -1:, :]), 1)

            speeds = (dec_inp[:, 1:, 0:2] * std_d + mean_d).cpu().numpy()
            pr.append(speeds.cumsum(1) + batch['src'][:, -1:, 0:2].cpu().numpy())
            print("%s epoch %03i/%03i  batch %04i / %04i" %
                  (tag, epoch, args.max_epoch, id_b, len(dl)))
        return {
            'input': np.concatenate(inp_, 0),
            'gt': np.concatenate(gt, 0),
            'pr': np.concatenate(pr, 0),
            'peds': np.concatenate(peds, 0),
            'frames': np.concatenate(frames, 0),
            'dt': np.concatenate(dt, 0),
        }

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):
            optim.optimizer.zero_grad()  # reset the gradients of all parameters
            inp = (batch['src'][:, 1:, 2:4].to(device) - mean_d) / std_d
            # Teacher forcing: the decoder sees the true steps shifted right by
            # one (each with a zero flag channel), after the start token.
            target = (batch['trg'][:, :-1, 2:4].to(device) - mean_d) / std_d
            target_c = torch.zeros((target.shape[0], target.shape[1], 1)).to(device)
            target = torch.cat((target, target_c), -1)
            dec_inp = torch.cat((start_token(target.shape[0]), target), 1)

            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            trg_att = subsequent_mask(dec_inp.shape[1]).repeat(
                dec_inp.shape[0], 1, 1).to(device)

            pred = model(inp, dec_inp, src_att, trg_att)

            trg_norm = (batch['trg'][:, :, 2:4].to(device) - mean_d) / std_d
            # Mean Euclidean error on the 2-D channels plus an L1 penalty on
            # the flag channel.
            loss = F.pairwise_distance(
                pred[:, :, 0:2].contiguous().view(-1, 2),
                trg_norm.contiguous().view(-1, 2)).mean() + torch.mean(
                    torch.abs(pred[:, :, 2]))
            loss.backward()
            optim.step()  # optimizer step through the NoamOpt wrapper
            print("train epoch %03i/%03i  batch %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()

        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)

        with torch.no_grad():
            model.eval()  # evaluation mode (e.g. dropout off)

            val = predict_split(val_dl, "val")
            mad, fad, _ = baselineUtils.distance_metrics(val['gt'], val['pr'])
            log.add_scalar('validation/MAD', mad, epoch)
            log.add_scalar('validation/FAD', fad, epoch)

            if args.evaluate:
                test = predict_split(test_dl, "test")
                mad, fad, _ = baselineUtils.distance_metrics(test['gt'], test['pr'])
                log.add_scalar('eval/DET_mad', mad, epoch)
                log.add_scalar('eval/DET_fad', fad, epoch)

                # 'input' holds the full src of every test sample as host arrays.
                scipy.io.savemat(
                    f"output/Individual/{args.name}/det_{epoch}.mat",
                    dict(test, dt_names=test_dataset.data['dataset_name']))

        if epoch % args.save_step == 0:
            torch.save(model.state_dict(),
                       f'models/Individual/{args.name}/{epoch:05d}.pth')

        epoch += 1
def main():
    """Run one training step of the 3-D IndividualTF and inspect the model.

    Builds the 3-D dataset via ``baselineUtils.create_old_3dim_dataset`` and
    trains ``individual_TF_3D.IndividualTF(3, 4, 4, ...)`` on the FIRST
    training batch only. It then renders the autograd graph of that batch's
    prediction with ``baselineUtils.make_dot(...).view()`` and prints the
    shape and element count of every parameter tensor, followed by the total.

    Inputs are channels ``src[:, 1:, 3:6]``, normalized by their global
    mean/std over the training set (saved to
    ``models/Individual/<name>/norm.mat``).

    Raises:
        RuntimeError: if the training dataloader yields no batches.
    """
    from itertools import islice

    def str2bool(value):
        # argparse's type=bool treats every non-empty string (even "False") as True.
        return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model | 训练独立Transformer模型'
    )
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='radar')
    # TODO: observed points, i.e. the length of one training sample
    parser.add_argument('--obs', type=int, default=20)
    # TODO: number of predicted points; meant to match the dataset
    parser.add_argument('--preds', type=int, default=20)
    parser.add_argument('--emb_size', type=int, default=512)  # embedding size
    parser.add_argument('--heads', type=int, default=8)  # TODO
    parser.add_argument('--layers', type=int, default=6)  # TODO
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')  # force CPU
    parser.add_argument('--val_size', type=int, default=0)  # not read below
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=500)  # not read below
    parser.add_argument('--batch_size', type=int, default=70)
    # TODO: possibly the epoch at which validation starts? (not read below)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')  # not read below
    parser.add_argument('--delim', type=str, default='\t')  # not read below
    parser.add_argument('--name', type=str, default="radar")  # model name
    parser.add_argument('--factor', type=float, default=1.)  # NoamOpt factor
    parser.add_argument('--save_step', type=int, default=1)  # not read below
    # Warmup length in epochs (multiplied by the number of batches below).
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--evaluate', type=str2bool, default=True)  # not read below
    args = parser.parse_args()

    for path in ('models', 'output', 'output/Individual', 'models/Individual',
                 f'output/Individual/{args.name}', f'models/Individual/{args.name}'):
        os.makedirs(path, exist_ok=True)

    log = SummaryWriter('logs/Ind_%s' % args.name)
    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    device = torch.device("cuda")
    # Fall back to CPU when --cpu is given OR no CUDA device is available.
    if args.cpu or not torch.cuda.is_available():
        print("\033[1;31;40m Training with cpu... \033[0m")
        device = torch.device("cpu")

    # NOTE(review): this unconditionally overrides the --verbose flag.
    args.verbose = True

    train_dataset, _ = baselineUtils.create_old_3dim_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        train=True, verbose=args.verbose)

    import individual_TF_3D
    model = individual_TF_3D.IndividualTF(3, 4, 4, N=args.layers, d_model=args.emb_size,
                                          d_ff=2048, h=args.heads, dropout=args.dropout,
                                          mean=[0, 0, 0], std=[0, 0, 0]).to(device)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=0)

    # Adam wrapped in the project's NoamOpt scheduler.
    # FIXME: check whether the betas need further tuning.
    optim = NoamOpt(
        args.emb_size, args.factor, len(train_dataloader) * args.warmup,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    # Global normalization statistics of channels 3:6 (src steps 1.. and all trg steps).
    train_all = train_dataset[:]
    channels = torch.cat((train_all['src'][:, 1:, 3:6],
                          train_all['trg'][:, :, 3:6]), 1)
    mean = channels.mean((0, 1))
    std = channels.std((0, 1))

    scipy.io.savemat(f'models/Individual/{args.name}/norm.mat', {
        'mean': mean.cpu().numpy(),
        'std': std.cpu().numpy()
    })
    mean_d, std_d = mean.to(device), std.to(device)

    model.train()

    # Only the first batch is used; islice avoids loading the whole epoch.
    train_batch_bar = tqdm(islice(enumerate(train_dataloader), 1),
                           total=min(1, len(train_dataloader)))

    prediction = None
    for id_b, batch in train_batch_bar:
        optim.optimizer.zero_grad()
        # (batch, src_len - 1, 3): normalized channels 3:6 from step 1 on
        inp = (batch['src'][:, 1:, 3:6].to(device) - mean_d) / std_d
        # (batch, trg_len - 1, 3): teacher-forcing targets (last step dropped)
        target = (batch['trg'][:, :-1, 3:6].to(device) - mean_d) / std_d
        # Append a zero flag channel -> (batch, trg_len - 1, 4)
        target_c = torch.zeros((target.shape[0], target.shape[1], 1)).to(device)
        target = torch.cat((target, target_c), -1)
        start_of_seq = torch.Tensor([0, 0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(
            target.shape[0], 1, 1).to(device)

        # (batch, trg_len, 4): start token followed by the shifted targets
        decoder_input = torch.cat((start_of_seq, target), 1)

        # (batch, 1, src_len - 1): attend to every source step
        src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
        # (batch, trg_len, trg_len) causal mask
        trg_att = subsequent_mask(decoder_input.shape[1]).repeat(
            decoder_input.shape[0], 1, 1).to(device)

        prediction = model(inp, decoder_input, src_att, trg_att)

        target_full = (batch['trg'][:, :, 3:6].to(device) - mean_d) / std_d
        # Mean Euclidean distance between predicted and true 3-D vectors per
        # step (view(-1, 3) keeps each step's three channels together), plus
        # an L1 penalty on the flag channel.
        loss = F.pairwise_distance(
            prediction[:, :, 0:3].contiguous().view(-1, 3),
            target_full.contiguous().view(-1, 3)).mean() + torch.mean(
                torch.abs(prediction[:, :, 3]))
        loss.backward()
        optim.step()
        train_batch_bar.set_description("loss: %7.4f" % loss.item())

    if prediction is None:
        raise RuntimeError("training dataloader yielded no batches")

    graph = baselineUtils.make_dot(prediction)
    graph.view()

    # Print each parameter tensor's shape and element count, then the total.
    total = 0
    for param in model.parameters():
        n_elem = param.numel()
        print("该层的结构:" + str(list(param.size())))
        print("该层参数和:" + str(n_elem))
        total += n_elem
    print("总参数数量和:" + str(total))
# Example #8
# 0
def main():
    """Evaluate a trained 2-D IndividualTF model on the test split and draw its predictions.

    The checkpoint is read from ``--model_pth`` when given, otherwise from the
    historical default ``models/Individual/my_data_train/00600.pth``. Every test
    sample is decoded autoregressively for ``--preds`` steps, MAD/FAD are printed,
    and one image per sample is written with ground-truth points in green and
    predicted points in red.
    """
    parser = argparse.ArgumentParser(description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=512)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--val_size', type=int, default=0)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=1500)
    parser.add_argument('--batch_size', type=int, default=70)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")
    parser.add_argument('--factor', type=float, default=1.)
    parser.add_argument('--save_step', type=int, default=1)
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--evaluate', type=bool, default=True)
    parser.add_argument('--model_pth', type=str)

    args = parser.parse_args()

    # Create the model/output directory tree (parents included); existing dirs are fine.
    for path in (f'output/Individual/{args.name}', f'models/Individual/{args.name}'):
        os.makedirs(path, exist_ok=True)

    use_cpu = args.cpu or not torch.cuda.is_available()
    device = torch.device("cpu") if use_cpu else torch.device("cuda")

    args.verbose = True
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
                                                   delim=args.delim, train=False, eval=True, verbose=args.verbose)

    import individual_TF
    model = individual_TF.IndividualTF(2, 3, 3, N=args.layers,
                                       d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,
                                       mean=[0, 0], std=[0, 0]).to(device)

    checkpoint = args.model_pth or 'models/Individual/my_data_train/00600.pth'
    # map_location allows loading a checkpoint saved on a GPU when running on the CPU.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    model.eval()
    gt = []
    pr = []
    frames = []

    # Pure inference: no autograd graph is needed across the decoding steps.
    with torch.no_grad():
        for batch in test_dl:
            gt.append(batch['trg'][:, :, 0:2])
            frames.append(batch['frames'])

            # Columns 2:4 of every observed step but the first are the model input.
            inp = batch['src'][:, 1:, 2:4].to(device)
            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            # Start-of-sequence token: zero in the two value channels, flag channel set to 1.
            dec_inp = torch.Tensor([0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(inp.shape[0], 1, 1).to(device)

            for _ in range(args.preds):
                trg_att = subsequent_mask(dec_inp.shape[1]).repeat(dec_inp.shape[0], 1, 1).to(device)
                out = model(inp, dec_inp, src_att, trg_att)
                dec_inp = torch.cat((dec_inp, out[:, -1:, :]), 1)

            # Predicted columns are cumulatively summed onto the last observed position (columns 0:2).
            preds_tr_b = dec_inp[:, 1:, 0:2].cpu().numpy().cumsum(1) + batch['src'][:, -1:, 0:2].cpu().numpy()
            pr.append(preds_tr_b)

    frames = np.concatenate(frames, 0)
    gt = np.concatenate(gt, 0)
    pr = np.concatenate(pr, 0)
    mad, fad, _ = baselineUtils.distance_metrics(gt, pr)
    print("mad %f fad %f" % (mad, fad))

    pathin = 'c_1 frames/c_1_'
    pathout = '5_1 frames_out/'
    os.makedirs(pathout, exist_ok=True)  # cv2.imwrite fails silently if the directory is missing
    cg = (0, 255, 0)  # green (OpenCV colours are BGR)
    cp = (0, 0, 255)  # red
    frame_idx = 8  # NOTE(review): fixed frame column used to pick the background image; confirm vs --obs
    for i in range(pr.shape[0]):
        frame_name = str(frames[i][frame_idx]) + '.jpg'
        img = cv2.imread(pathin + frame_name)
        if img is None:
            # cv2.imread returns None for missing/unreadable files; drawing on it would raise.
            print("could not read %s, skipping" % (pathin + frame_name))
            continue
        for j in range(pr.shape[1]):
            # Coordinates are scaled to a 1920x1080 image (presumably normalised to [0, 1] — confirm).
            gp = (int(gt[i, j, 0] * 1920), int(gt[i, j, 1] * 1080))
            pp = (int(pr[i, j, 0] * 1920), int(pr[i, j, 1] * 1080))
            img = cv2.circle(img, gp, 3, cg, -1)
            img = cv2.circle(img, pp, 3, cp, -1)
        cv2.imwrite(pathout + frame_name, img)
def main():
    """
    Train a 3-D IndividualTF trajectory model, validating (and optionally testing) after every epoch.

    To switch datasets, change both the ``--name`` and ``--dataset_name`` arguments.
    To keep the history of an earlier run instead of deleting it, pass a different
    ``--name`` so logs, outputs and checkpoints are written to a new location.
    """
    def str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string (including "False") to True.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='Train the individual Transformer model | 训练独立Transformer模型')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='radar_10_2')
    parser.add_argument('--obs', type=int, default=20)  # TODO observed points: length of one training sample
    parser.add_argument('--preds', type=int, default=10)  # TODO number of predicted points; should match the dataset
    parser.add_argument('--emb_size', type=int, default=512)  # embedding (model) size
    parser.add_argument('--heads', type=int, default=8)  # TODO
    parser.add_argument('--layers', type=int, default=6)  # TODO
    parser.add_argument('--dropout', type=float, default=0.1)  # dropout rate
    parser.add_argument('--cpu', action='store_true')  # force the CPU
    parser.add_argument('--val_size', type=int, default=0)  # validation set size
    parser.add_argument('--verbose', action='store_true')  # verbose logging
    parser.add_argument('--max_epoch', type=int, default=400)  # maximum number of epochs
    parser.add_argument('--batch_size', type=int, default=64)  # batch size
    parser.add_argument('--validation_epoch_start', type=int, default=10)  # TODO possibly the first validation epoch? (unused here)
    parser.add_argument('--resume_train', action='store_true')  # resume training (unused here)
    parser.add_argument('--delim', type=str, default='\t')  # dataset field delimiter
    parser.add_argument('--name', type=str, default="radar_obs-15_pred-5")  # model name
    parser.add_argument('--factor', type=float, default=1.)  # NoamOpt factor
    parser.add_argument('--save_step', type=int, default=10)  # save a checkpoint every N epochs
    parser.add_argument('--warmup', type=int, default=10)  # warm-up length, in epochs
    parser.add_argument('--evaluate', type=str2bool, default=True)  # also evaluate on the test set each epoch
    parser.add_argument('--del_hist', action='store_true', default=False)  # delete the history of a previous run

    args = parser.parse_args()
    model_name = args.name

    log_dir = 'logs/Ind_%s' % model_name
    output_dir = f'output/Individual/{args.name}'
    models_dir = f'models/Individual/{args.name}'

    if args.del_hist:
        history_dirs = (log_dir, output_dir, models_dir)
        for path in history_dirs:
            if os.path.exists(path):
                shutil.rmtree(path, ignore_errors=True)
        leftovers = [path for path in history_dirs if os.path.exists(path)]
        if leftovers:
            raise RuntimeError('could not remove history directories: %s' % leftovers)
        print("\033[0;32mHistory files removed. \033[0m")

    # Creates models/, output/ and their Individual/ parents as needed.
    for path in (output_dir, models_dir):
        os.makedirs(path, exist_ok=True)

    log = SummaryWriter(log_dir)
    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    # NOTE(review): hard-coded GPU index; only takes effect if CUDA has not been initialised yet.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        print("\033[1;31;40m Training with cpu... \033[0m")
        device = torch.device("cpu")

    args.verbose = True

    train_dataset, val_dataset, test_dataset = baselineUtils.create_new_3dim_dataset(
        args.dataset_folder, args.dataset_name, 0, gt=args.preds, horizon=args.obs,
        delim=args.delim, train=False, eval=True, verbose=args.verbose)

    import individual_TF_3D
    model = individual_TF_3D.IndividualTF(3, 4, 4, N=args.layers,
                                          d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,
                                          mean=[0, 0, 0], std=[0, 0, 0]).to(device)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=0)
    validate_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                      num_workers=0)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                                  num_workers=0)

    # Project-specific Noam learning-rate schedule wrapped around Adam.
    optim = NoamOpt(args.emb_size, args.factor, len(train_dataloader) * args.warmup,
                    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    epoch = 0

    # Per-channel statistics of columns 3:6 over all observed (except the first) and target steps.
    mean = torch.cat((train_dataset[:]['src'][:, 1:, 3:6], train_dataset[:]['trg'][:, :, 3:6]), 1).mean((0, 1))
    std = torch.cat((train_dataset[:]['src'][:, 1:, 3:6], train_dataset[:]['trg'][:, :, 3:6]), 1).std((0, 1))

    scipy.io.savemat(f'{models_dir}/norm.mat', {'mean': mean.cpu().numpy(), 'std': std.cpu().numpy()})

    mean_d = mean.to(device)
    std_d = std.to(device)

    def normalize(values):
        """Standardise already-sliced 3:6 columns with the training-set mean/std, on ``device``."""
        return (values.to(device) - mean_d) / std_d

    def start_token(n):
        """Return an (n, 1, 4) start-of-sequence token: zero values with the flag channel set."""
        return torch.Tensor([0, 0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(n, 1, 1).to(device)

    def trajectory_loss(pred_xyz, pred_flag, target_xyz):
        """Mean 3-D Euclidean distance between predicted and target steps plus mean |flag|."""
        # The coordinates are 3-D, so each row of the view must hold 3 components.
        dist = F.pairwise_distance(pred_xyz.contiguous().view(-1, 3), target_xyz.contiguous().view(-1, 3))
        return dist.mean() + torch.mean(torch.abs(pred_flag))

    def autoregressive_decode(inp, src_att):
        """Decode ``args.preds`` steps; return the decoder sequence (with start token) and last model output."""
        decoder_input = start_token(inp.shape[0])
        out = None
        for _ in range(args.preds):
            trg_att = subsequent_mask(decoder_input.shape[1]).repeat(decoder_input.shape[0], 1, 1).to(device)
            out = model(inp, decoder_input, src_att, trg_att)
            decoder_input = torch.cat((decoder_input, out[:, -1:, :]), 1)
        return decoder_input, out

    def to_positions(decoder_input, batch):
        """De-normalise predicted steps, cumulatively sum them and offset by the last observed position."""
        # The predicted trajectory is built on top of the last input point (columns 0:3);
        # `-1:` keeps the time axis so the offset broadcasts over every predicted step.
        steps = (decoder_input[:, 1:, 0:3] * std_d + mean_d).cpu().numpy()
        return steps.cumsum(1) + batch['src'][:, -1:, 0:3].cpu().numpy()

    print("\033[1;32mStart Training...\033[0m")

    while epoch < args.max_epoch:
        epoch_train_loss = 0
        epoch_validate_loss = 0
        model.train()

        # Iterate lazily: materialising every batch up front keeps the whole epoch in memory.
        train_batch_bar = tqdm(train_dataloader)
        all_batch_loss = []
        all_validate_loss = []

        for batch in train_batch_bar:
            optim.optimizer.zero_grad()
            # (batch_size, src_len - 1, 3)
            inp = normalize(batch['src'][:, 1:, 3:6])
            # (batch_size, trg_len - 1, 3), then a zero flag channel is appended along the last dim
            target = normalize(batch['trg'][:, :-1, 3:6])
            target_c = torch.zeros((target.shape[0], target.shape[1], 1)).to(device)
            target = torch.cat((target, target_c), -1)

            # Teacher forcing: start token followed by the shifted ground truth, (batch_size, trg_len, 4)
            decoder_input = torch.cat((start_token(target.shape[0]), target), 1)

            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            trg_att = subsequent_mask(decoder_input.shape[1]).repeat(decoder_input.shape[0], 1, 1).to(device)

            # (batch_size, trg_len, 4)
            prediction = model(inp, decoder_input, src_att, trg_att)

            loss = trajectory_loss(prediction[:, :, 0:3], prediction[:, :, 3],
                                   normalize(batch['trg'][:, :, 3:6]))
            loss.backward()
            optim.step()
            epoch_train_loss += loss.item()
            all_batch_loss.append(loss.item())
            train_batch_bar.set_description("train epoch %03i/%03i  loss: %7.4f  batch_loss: %7.4f" % (
                epoch + 1, args.max_epoch, loss.item(), sum(all_batch_loss) / len(all_batch_loss)))
        log.add_scalar('Loss/train', epoch_train_loss / len(train_dataloader), epoch)

        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            validate_batch_bar = tqdm(validate_dataloader)

            for batch in validate_batch_bar:
                gt.append(batch['trg'][:, :, 0:3])

                inp = normalize(batch['src'][:, 1:, 3:6])
                src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
                # Predict args.preds steps autoregressively.
                decoder_input, out = autoregressive_decode(inp, src_att)
                pr.append(to_positions(decoder_input, batch))

                # Accumulate plain floats rather than (possibly GPU) tensors.
                val_loss = trajectory_loss(decoder_input[:, 1:, 0:3], out[:, :, 3],
                                           normalize(batch['trg'][:, :, 3:6])).item()
                all_validate_loss.append(val_loss)
                epoch_validate_loss += val_loss
                validate_batch_bar.set_description("val epoch %03i/%03i    loss: %7.4f    avg_loss: %7.4f" % (
                    epoch + 1, args.max_epoch, val_loss, sum(all_validate_loss) / len(all_validate_loss)))

            log.add_scalar('Loss/validation', epoch_validate_loss / len(validate_dataloader), epoch)
            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            # MAD: mean displacement error; FAD: error at the final point.
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/MAD', mad, epoch)
            log.add_scalar('validation/FAD', fad, epoch)

            if args.evaluate:
                gt = []
                pr = []
                input_ = []

                test_batch_bar = tqdm(test_dataloader)

                for batch in test_batch_bar:
                    input_.append(batch['src'])
                    gt.append(batch['trg'][:, :, 0:3])

                    inp = normalize(batch['src'][:, 1:, 3:6])
                    src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
                    decoder_input, _ = autoregressive_decode(inp, src_att)

                    pr.append(to_positions(decoder_input, batch))
                    test_batch_bar.set_description("test epoch %03i/%03i" % (
                        epoch + 1, args.max_epoch))

                gt = np.concatenate(gt, 0)
                pr = np.concatenate(pr, 0)
                input_ = np.concatenate(input_, 0)
                mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

                log.add_scalar('eval/DET_mad', mad, epoch)
                log.add_scalar('eval/DET_fad', fad, epoch)

                scipy.io.savemat(f"output/Individual/{args.name}/det_{epoch}.mat",
                                 {
                                     'input': input_,
                                     'gt': gt,
                                     'pr': pr,
                                 })

        if (epoch + 1) % args.save_step == 0:
            torch.save(model.state_dict(), f'{models_dir}/{epoch:05d}.pth')

        epoch += 1

    # Flush pending TensorBoard events.
    log.close()
Beispiel #10
0
def main():
    """Evaluate a trained 2-D IndividualTF checkpoint on the test split on the CPU and print MAD/FAD.

    The checkpoint is read from ``--model_pth`` when given, otherwise from the
    historical default ``models/Individual/my_data_train/00013.pth``.
    """
    parser = argparse.ArgumentParser(description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=512)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--val_size', type=int, default=0)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=1500)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")
    parser.add_argument('--factor', type=float, default=1.)
    parser.add_argument('--save_step', type=int, default=1)
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--evaluate', type=bool, default=True)
    parser.add_argument('--model_pth', type=str)

    args = parser.parse_args()

    # This script always runs inference on the CPU.
    device = torch.device("cpu")

    args.verbose = True

    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
                                                   delim=args.delim, train=False, eval=True, verbose=args.verbose)

    import individual_TF
    model = individual_TF.IndividualTF(2, 3, 3, N=args.layers,
                                       d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,
                                       mean=[0, 0], std=[0, 0]).to(device)

    checkpoint = args.model_pth or 'models/Individual/my_data_train/00013.pth'
    # map_location is required to load a checkpoint that was saved from a GPU onto the CPU.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    model.eval()
    gt = []
    pr = []

    # Pure inference: skip building an autograd graph across the decoding steps.
    with torch.no_grad():
        for batch in test_dl:
            gt.append(batch['trg'][:, :, 0:2])

            # Columns 2:4 of every observed step but the first are the model input.
            inp = batch['src'][:, 1:, 2:4].to(device)
            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            # Start-of-sequence token: zero in the two value channels, flag channel set to 1.
            dec_inp = torch.Tensor([0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(inp.shape[0], 1, 1).to(device)

            for _ in range(args.preds):
                trg_att = subsequent_mask(dec_inp.shape[1]).repeat(dec_inp.shape[0], 1, 1).to(device)
                out = model(inp, dec_inp, src_att, trg_att)
                dec_inp = torch.cat((dec_inp, out[:, -1:, :]), 1)

            # Predicted columns are cumulatively summed onto the last observed position (columns 0:2).
            preds_tr_b = dec_inp[:, 1:, 0:2].cpu().numpy().cumsum(1) + batch['src'][:, -1:, 0:2].cpu().numpy()
            pr.append(preds_tr_b)

    gt = np.concatenate(gt, 0)
    pr = np.concatenate(pr, 0)
    mad, fad, _ = baselineUtils.distance_metrics(gt, pr)
    # Report the metrics instead of silently discarding them.
    print("mad %f fad %f" % (mad, fad))
Beispiel #11
0
def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=512)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=0)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")
    parser.add_argument('--factor', type=float, default=1.)
    parser.add_argument('--evaluate', type=bool, default=True)
    parser.add_argument('--save_step', type=int, default=1)

    args = parser.parse_args()
    model_name = args.name

    try:
        os.mkdir('models')
    except:
        pass
    try:
        os.mkdir('output')
    except:
        pass
    try:
        os.mkdir('output/QuantizedTF')
    except:
        pass
    try:
        os.mkdir(f'models/QuantizedTF')
    except:
        pass

    try:
        os.mkdir(f'output/QuantizedTF/{args.name}')
    except:
        pass

    try:
        os.mkdir(f'models/QuantizedTF/{args.name}')
    except:
        pass

    log = SummaryWriter('logs/%s' % model_name)

    #os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
    device = torch.device("cuda")

    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    args.verbose = True

    ## creation of the dataloaders for train and validation
    if args.val_size == 0:
        train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                        args.dataset_name,
                                                        0,
                                                        args.obs,
                                                        args.preds,
                                                        delim=args.delim,
                                                        train=True,
                                                        verbose=args.verbose)
        val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                      args.dataset_name,
                                                      0,
                                                      args.obs,
                                                      args.preds,
                                                      delim=args.delim,
                                                      train=False,
                                                      verbose=args.verbose)
    else:
        train_dataset, val_dataset = baselineUtils.create_dataset(
            args.dataset_folder,
            args.dataset_name,
            args.val_size,
            args.obs,
            args.preds,
            delim=args.delim,
            train=True,
            verbose=args.verbose)

    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    mat = scipy.io.loadmat(
        os.path.join(args.dataset_folder, args.dataset_name, "clusters.mat"))
    clusters = mat['centroids']

    import quantized_TF
    model = quantized_TF.QuantizedTF(clusters.shape[0],
                                     clusters.shape[0] + 1,
                                     clusters.shape[0],
                                     N=args.layers,
                                     d_model=args.emb_size,
                                     d_ff=1024,
                                     h=args.heads,
                                     dropout=args.dropout).to(device)

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    #optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01)
    #sched=torch.optim.lr_scheduler.StepLR(optim,0.0005)
    optim = NoamOpt(
        args.emb_size, args.factor,
        len(tr_dl) * 5,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))
    #optim=Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01,lr_decay=0.001)
    epoch = 0

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):

            optim.optimizer.zero_grad()
            scale = np.random.uniform(0.5, 4)
            #rot_mat = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])
            n_in_batch = batch['src'].shape[0]
            speeds_inp = batch['src'][:, 1:, 2:4] * scale
            inp = torch.tensor(
                scipy.spatial.distance.cdist(speeds_inp.reshape(-1, 2),
                                             clusters).argmin(axis=1).reshape(
                                                 n_in_batch, -1)).to(device)
            speeds_trg = batch['trg'][:, :, 2:4] * scale
            target = torch.tensor(
                scipy.spatial.distance.cdist(speeds_trg.reshape(-1, 2),
                                             clusters).argmin(axis=1).reshape(
                                                 n_in_batch, -1)).to(device)
            src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device)
            trg_att = subsequent_mask(target.shape[1]).repeat(
                n_in_batch, 1, 1).to(device)
            start_of_seq = torch.tensor(
                [clusters.shape[0]]).repeat(n_in_batch).unsqueeze(1).to(device)
            dec_inp = torch.cat((start_of_seq, target[:, :-1]), 1)

            out = model(inp, dec_inp, src_att, trg_att)

            loss = F.cross_entropy(out.view(-1, out.shape[-1]),
                                   target.view(-1),
                                   reduction='mean')
            loss.backward()
            optim.step()
            print("epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        #sched.step()
        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)
        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            val_loss = 0
            step = 0
            for batch in val_dl:
                # rot_mat = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])
                n_in_batch = batch['src'].shape[0]
                speeds_inp = batch['src'][:, 1:, 2:4]
                inp = torch.tensor(
                    scipy.spatial.distance.cdist(
                        speeds_inp.contiguous().reshape(-1, 2),
                        clusters).argmin(axis=1).reshape(n_in_batch,
                                                         -1)).to(device)
                speeds_trg = batch['trg'][:, :, 2:4]
                target = torch.tensor(
                    scipy.spatial.distance.cdist(
                        speeds_trg.contiguous().reshape(-1, 2),
                        clusters).argmin(axis=1).reshape(n_in_batch,
                                                         -1)).to(device)
                src_att = torch.ones(
                    (inp.shape[0], 1, inp.shape[1])).to(device)
                trg_att = subsequent_mask(target.shape[1]).repeat(
                    n_in_batch, 1, 1).to(device)
                start_of_seq = torch.tensor([
                    clusters.shape[0]
                ]).repeat(n_in_batch).unsqueeze(1).to(device)
                dec_inp = torch.cat((start_of_seq, target[:, :-1]), 1)

                out = model(inp, dec_inp, src_att, trg_att)

                loss = F.cross_entropy(out.contiguous().view(
                    -1, out.shape[-1]),
                                       target.contiguous().view(-1),
                                       reduction='mean')

                print("val epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
                      (epoch, args.max_epoch, step, len(val_dl), loss.item()))
                val_loss += loss.item()
                step += 1

            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)

            if args.evaluate:
                # DETERMINISTIC MODE
                model.eval()
                model.eval()
                gt = []
                pr = []
                inp_ = []
                peds = []
                frames = []
                dt = []
                for batch in test_dl:

                    inp_.append(batch['src'][:, :, 0:2])
                    gt.append(batch['trg'][:, :, 0:2])
                    frames.append(batch['frames'])
                    peds.append(batch['peds'])
                    dt.append(batch['dataset'])

                    n_in_batch = batch['src'].shape[0]
                    speeds_inp = batch['src'][:, 1:, 2:4]
                    gt_b = batch['trg'][:, :, 0:2]
                    inp = torch.tensor(
                        scipy.spatial.distance.cdist(
                            speeds_inp.reshape(-1, 2),
                            clusters).argmin(axis=1).reshape(n_in_batch,
                                                             -1)).to(device)
                    src_att = torch.ones(
                        (inp.shape[0], 1, inp.shape[1])).to(device)
                    trg_att = subsequent_mask(target.shape[1]).repeat(
                        n_in_batch, 1, 1).to(device)
                    start_of_seq = torch.tensor([
                        clusters.shape[0]
                    ]).repeat(n_in_batch).unsqueeze(1).to(device)
                    dec_inp = start_of_seq

                    for i in range(args.preds):
                        trg_att = subsequent_mask(dec_inp.shape[1]).repeat(
                            n_in_batch, 1, 1).to(device)
                        out = model(inp, dec_inp, src_att, trg_att)
                        dec_inp = torch.cat(
                            (dec_inp, out[:, -1:].argmax(dim=2)), 1)

                    preds_tr_b = clusters[dec_inp[:, 1:].cpu().numpy()].cumsum(
                        1) + batch['src'][:, -1:, 0:2].cpu().numpy()
                    pr.append(preds_tr_b)

                peds = np.concatenate(peds, 0)
                frames = np.concatenate(frames, 0)
                dt = np.concatenate(dt, 0)
                gt = np.concatenate(gt, 0)
                dt_names = test_dataset.data['dataset_name']
                pr = np.concatenate(pr, 0)
                mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

                log.add_scalar('eval/DET_mad', mad, epoch)
                log.add_scalar('eval/DET_fad', fad, epoch)

                scipy.io.savemat(
                    f"output/QuantizedTF/{args.name}/{epoch:05d}.mat", {
                        'input': inp,
                        'gt': gt,
                        'pr': pr,
                        'peds': peds,
                        'frames': frames,
                        'dt': dt,
                        'dt_names': dt_names
                    })

                # MULTI MODALITY

                if False:
                    num_samples = 20

                    model.eval()
                    gt = []
                    pr_all = {}
                    for sam in range(num_samples):
                        pr_all[sam] = []
                    for batch in test_dl:
                        # rot_mat = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])
                        n_in_batch = batch['src'].shape[0]
                        speeds_inp = batch['src'][:, 1:, 2:4]
                        gt_b = batch['trg'][:, :, 0:2]
                        gt.append(gt_b)
                        inp = torch.tensor(
                            scipy.spatial.distance.cdist(
                                speeds_inp.reshape(-1, 2),
                                clusters).argmin(axis=1).reshape(
                                    n_in_batch, -1)).to(device)
                        src_att = torch.ones(
                            (inp.shape[0], 1, inp.shape[1])).to(device)
                        trg_att = subsequent_mask(target.shape[1]).repeat(
                            n_in_batch, 1, 1).to(device)
                        start_of_seq = torch.tensor([
                            clusters.shape[0]
                        ]).repeat(n_in_batch).unsqueeze(1).to(device)

                        for sam in range(num_samples):
                            dec_inp = start_of_seq

                            for i in range(args.preds):
                                trg_att = subsequent_mask(
                                    dec_inp.shape[1]).repeat(n_in_batch, 1,
                                                             1).to(device)
                                out = model.predict(inp, dec_inp, src_att,
                                                    trg_att)
                                h = out[:, -1]
                                dec_inp = torch.cat(
                                    (dec_inp, torch.multinomial(h, 1)), 1)

                            preds_tr_b = clusters[dec_inp[:, 1:].cpu().numpy(
                            )].cumsum(1) + batch['src'][:, -1:,
                                                        0:2].cpu().numpy()

                            pr_all[sam].append(preds_tr_b)

                    gt = np.concatenate(gt, 0)
                    #pr=np.concatenate(pr,0)
                    samp = {}
                    for k in pr_all.keys():
                        samp[k] = {}
                        samp[k]['pr'] = np.concatenate(pr_all[k], 0)
                        samp[k]['mad'], samp[k]['fad'], samp[k][
                            'err'] = baselineUtils.distance_metrics(
                                gt, samp[k]['pr'])

                    ev = [samp[i]['err'] for i in range(num_samples)]
                    e20 = np.stack(ev, -1)
                    mad_samp = e20.mean(1).min(-1).mean()
                    fad_samp = e20[:, -1].min(-1).mean()
                    #mad,fad,errs=baselineUtils.distance_metrics(gt,pr)

                    log.add_scalar('eval/MM_mad', mad_samp, epoch)
                    log.add_scalar('eval/MM_fad', fad_samp, epoch)

            if epoch % args.save_step == 0:
                torch.save(model.state_dict(),
                           f'models/QuantizedTF/{args.name}/{epoch:05d}.pth')

        epoch += 1

    ab = 1
Beispiel #12
0
def _quantize_speeds(speeds, clusters, device):
    """Return, for every 2-component speed vector, the index of its nearest centroid.

    Args:
        speeds: batch-first tensor/array whose last axis holds 2 components;
            every other non-batch axis is flattened into one token axis.
        clusters: (K, 2) array of centroids.
        device: torch device the result is moved to.

    Returns:
        Integer tensor of shape (batch, tokens) with values in [0, K).
    """
    n_in_batch = speeds.shape[0]
    nearest = scipy.spatial.distance.cdist(speeds.reshape(-1, 2),
                                           clusters).argmin(axis=1)
    return torch.tensor(nearest.reshape(n_in_batch, -1)).to(device)


def _bert_quantized_logits(model, gen, inp, n_preds, placeholder, device):
    """Run BERT on `inp` followed by `n_preds` placeholder tokens.

    Args:
        model: the BERT encoder; element 0 of its output is the last hidden state.
        gen: linear head mapping hidden states to per-cluster scores.
        inp: (batch, obs_tokens) cluster ids of the observed speeds.
        n_preds: number of future positions to predict.
        placeholder: token id filled into the positions to be predicted
            (the extra vocabulary id, equal to the number of clusters).
        device: torch device for the inputs.

    Returns:
        Unnormalised scores of shape (batch, obs_tokens + n_preds, K).
    """
    n_in_batch = inp.shape[0]
    dec_inp = torch.tensor([placeholder]).repeat(n_in_batch, n_preds).to(device)
    bert_inp = torch.cat((inp, dec_inp), 1)
    attention_mask = torch.ones(bert_inp.shape[0], bert_inp.shape[1]).to(device)
    return gen(model(bert_inp, attention_mask=attention_mask)[0])


def _bert_quantized_train_epoch(model, gen, optim, loader, clusters, n_preds,
                                device, epoch, max_epoch):
    """Run one training epoch and return the mean per-batch loss."""
    model.train()
    epoch_loss = 0
    for id_b, batch in enumerate(loader):
        optim.optimizer.zero_grad()
        # Random speed-scale augmentation, shared by input and target.
        scale = np.random.uniform(0.5, 2)
        inp = _quantize_speeds(batch['src'][:, 1:, 2:4] * scale, clusters,
                               device)
        target = _quantize_speeds(batch['trg'][:, :, 2:4] * scale, clusters,
                                  device)
        out = _bert_quantized_logits(model, gen, inp, n_preds,
                                     clusters.shape[0], device)
        # The loss covers both the reconstructed observed tokens and the
        # predicted future tokens.
        loss = F.cross_entropy(out.view(-1, out.shape[-1]),
                               torch.cat((inp, target), 1).view(-1),
                               reduction='mean')
        loss.backward()
        optim.step()
        print("epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
              (epoch, max_epoch, id_b, len(loader), loss.item()))
        epoch_loss += loss.item()
    return epoch_loss / len(loader)


def _bert_quantized_predict(model, gen, loader, clusters, n_preds, device):
    """Greedy single-pass prediction over `loader`.

    Must be called under `torch.no_grad()` with the model in eval mode.

    Returns:
        (gt, pr) numpy arrays: ground-truth positions (batch['trg'][..., 0:2])
        and predicted positions. Predictions are the cumulative sum of the
        argmax centroids, added to the last observed position.
    """
    gt = []
    pr = []
    for batch in loader:
        inp = _quantize_speeds(batch['src'][:, 1:, 2:4], clusters, device)
        out = _bert_quantized_logits(model, gen, inp, n_preds,
                                     clusters.shape[0], device)
        # argmax over logits equals argmax over softmax (monotonic), so the
        # softmax is skipped; only the predicted tail is decoded.
        tokens = out[:, -n_preds:].argmax(dim=-1).cpu().numpy()
        preds_tr_b = clusters[tokens].cumsum(
            axis=1) + batch['src'][:, -1:, 0:2].cpu().numpy()
        gt.append(batch['trg'][:, :, 0:2])
        pr.append(preds_tr_b)
    return np.concatenate(gt, 0), np.concatenate(pr, 0)


def main():
    """Train a BERT encoder on cluster-quantised trajectory speeds.

    Speeds (columns 2:4 of each sample) are mapped to their nearest centroid
    from `<dataset_folder>/<dataset_name>/clusters.mat` (key 'centroids').
    BERT is given the observed tokens plus `--preds` placeholder tokens and
    learns to emit the cluster id at every position.

    Each epoch it:
      - logs the training loss;
      - logs validation MAD/FAD;
      - logs test ("eval/DET_*") MAD/FAD;
      - every `--save_step` epochs, saves the model and head weights under
        models/BERT_quantized/<name>/.
    """
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='eth')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="eth_0.1")
    parser.add_argument('--factor', type=float, default=0.1)
    parser.add_argument('--save_step', type=int, default=1)

    args = parser.parse_args()
    model_name = args.name

    # Ensure output directories exist; only "already exists" is tolerated
    # (the previous bare excepts also hid real errors and Ctrl-C).
    for path in ('models', 'output', 'output/BERT_quantized',
                 'models/BERT_quantized',
                 f'output/BERT_quantized/{args.name}',
                 f'models/BERT_quantized/{args.name}', args.name):
        os.makedirs(path, exist_ok=True)

    log = SummaryWriter('logs/BERT_quant_%s' % model_name)

    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    # Forced on regardless of the --verbose flag.
    args.verbose = True

    ## creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                    args.dataset_name,
                                                    0,
                                                    args.obs,
                                                    args.preds,
                                                    delim=args.delim,
                                                    train=True,
                                                    verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                  args.dataset_name,
                                                  0,
                                                  args.obs,
                                                  args.preds,
                                                  delim=args.delim,
                                                  train=False,
                                                  verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    mat = scipy.io.loadmat(
        os.path.join(args.dataset_folder, args.dataset_name, "clusters.mat"))
    clusters = mat['centroids']
    # One extra vocabulary id (== clusters.shape[0]) is used as the
    # placeholder for the positions to be predicted.
    config = transformers.BertConfig(vocab_size=clusters.shape[0] + 1)
    model = transformers.BertModel(config).to(device)
    gen = nn.Linear(config.hidden_size, clusters.shape[0]).to(device)
    # NOTE(review): the Noam schedule is sized with --emb_size (default 1024)
    # while the BERT hidden size comes from BertConfig; confirm this is intended.
    optim = NoamOpt(
        args.emb_size, args.factor,
        len(tr_dl) * 5,
        torch.optim.Adam(list(model.parameters()) + list(gen.parameters()),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))

    try:
        for epoch in range(args.max_epoch):
            train_loss = _bert_quantized_train_epoch(model, gen, optim, tr_dl,
                                                     clusters, args.preds,
                                                     device, epoch,
                                                     args.max_epoch)
            log.add_scalar('Loss/train', train_loss, epoch)

            with torch.no_grad():
                model.eval()
                gt, pr = _bert_quantized_predict(model, gen, val_dl, clusters,
                                                 args.preds, device)
                mad, fad, _ = baselineUtils.distance_metrics(gt, pr)
                log.add_scalar('validation/mad', mad, epoch)
                log.add_scalar('validation/fad', fad, epoch)

                gt, pr = _bert_quantized_predict(model, gen, test_dl,
                                                 clusters, args.preds, device)
                mad, fad, _ = baselineUtils.distance_metrics(gt, pr)

            if epoch % args.save_step == 0:
                torch.save(
                    model.state_dict(),
                    "models/BERT_quantized/%s/model_%03i.pth" %
                    (args.name, epoch))
                torch.save(
                    gen.state_dict(), "models/BERT_quantized/%s/gen_%03i.pth" %
                    (args.name, epoch))

            log.add_scalar('eval/DET_mad', mad, epoch)
            log.add_scalar('eval/DET_fad', fad, epoch)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        log.close()