def _print_eval_metrics(results, lg_num=None, traj_num=None):
    """Print the 12 evaluation metrics followed by a separator line.

    Args:
        results: sequence unpacked in the order returned by
            ``solver.all_evaluation`` / ``solver.evaluate_dist`` at the call
            sites: (ade_min, fde_min, ade_avg, fde_avg, ade_std, fde_std,
            sg_ade_min, sg_ade_avg, sg_ade_std,
            lg_fde_min, lg_fde_avg, lg_fde_std).
        lg_num: if not None, a header line with ``lg_num`` and ``traj_num``
            is printed before the metrics.
        traj_num: printed in the header alongside ``lg_num``.

    Raises:
        ValueError: if ``results`` does not contain exactly 12 items.
    """
    (ade_min, fde_min,
     ade_avg, fde_avg,
     ade_std, fde_std,
     sg_ade_min, sg_ade_avg, sg_ade_std,
     lg_fde_min, lg_fde_avg, lg_fde_std) = results
    if lg_num is not None:
        print('lg_num: ', lg_num, ' // traj_num: ', traj_num)
    # Printed grouped by metric (min/avg/std), which differs from the
    # order the values come back in.
    for label, value in (
        ('ade min: ', ade_min),
        ('ade avg: ', ade_avg),
        ('ade std: ', ade_std),
        ('fde min: ', fde_min),
        ('fde avg: ', fde_avg),
        ('fde std: ', fde_std),
        ('sg_ade_min: ', sg_ade_min),
        ('sg_ade_avg: ', sg_ade_avg),
        ('sg_ade_std: ', sg_ade_std),
        ('lg_fde_min: ', lg_fde_min),
        ('lg_fde_avg: ', lg_fde_avg),
        ('lg_fde_std: ', lg_fde_std),
    ):
        print(label, value)
    print('------------------------------------------')


def main(args):
    """Evaluate pretrained checkpoints, or train when not at ``max_iter``.

    When ``args.ckpt_load_iter == args.max_iter`` the model is evaluated on
    ``<dataset_dir>/<dataset_name>/test`` using the hard-coded trajectory
    and LG-CVAE checkpoints below. Otherwise a ``Solver`` is built and
    trained.

    NOTE(review): this module contains a second ``def main(args)`` later on,
    which rebinds the name. This earlier definition is therefore never
    reachable as ``main``. Consider deleting or renaming it.
    """
    if args.ckpt_load_iter != args.max_iter:
        solver = Solver(args)
        solver.train()
        return

    print("Initializing test dataset")
    solver = Solver(args)
    print('--------------------', args.dataset_name, '----------------------')
    test_path = os.path.join(args.dataset_dir, args.dataset_name, 'test')
    # Overridden after Solver(args) was constructed. Presumably this only
    # affects data_loader below -- TODO confirm Solver does not re-read it.
    args.batch_size = 5
    # NOTE(review): the later main() calls
    # data_loader(args, <dir>, 'test', shuffle=...) with a split argument.
    # Confirm this two-positional form still matches data_loader.
    _, test_loader = data_loader(args, test_path, shuffle=True)
    # NOTE(review): result discarded, and evaluate_dist is run again at the end.
    solver.evaluate_dist(test_loader, loss=False)
    solver.check_feat(test_loader)

    gh = True  # passed as generate_heat to all_evaluation
    print("GEN HEAT MAP: ", gh)

    traj_path = 'traj_zD_20_dr_mlp_0.3_dr_rnn_0.25_enc_hD_64_dec_hD_128_mlpD_256_map_featD_32_map_mlpD_256_lr_0.001_klw_50.0_ll_prior_w_1.0_zfb_0.07_run_4'
    traj_iter = '13000'
    traj_ckpt = {'ckpt_dir': os.path.join('ckpts', traj_path), 'iter': traj_iter}
    print('===== TRAJECTORY:', traj_ckpt)

    lg_path = 'lgcvae_enc_block_1_fcomb_block_2_wD_10_lr_0.001_lg_klw_1_a_0.25_r_2.0_fb_2.0_anneal_e_0_load_e_1_run_21'
    lg_iter = '26000'
    lg_ckpt = {'ckpt_dir': os.path.join('ckpts', lg_path), 'iter': lg_iter}
    # Print the full checkpoint spec, consistent with the TRAJECTORY line
    # (previously only the iteration string was printed).
    print('===== LG CVAE:', lg_ckpt)
    solver.pretrain_load_checkpoint(traj_ckpt, lg_ckpt)

    # (number of LG samples, trajectories per LG sample) configurations.
    for lg_num, traj_num in ((5, 4), (20, 1)):
        results = solver.all_evaluation(test_loader, lg_num=lg_num,
                                        traj_num=traj_num, generate_heat=gh)
        _print_eval_metrics(results, lg_num, traj_num)

    _print_eval_metrics(solver.evaluate_dist(test_loader, loss=False))
def _print_eval_metrics(results, lg_num=None, traj_num=None):
    """Print the 12 evaluation metrics followed by a separator line.

    Args:
        results: sequence unpacked in the order returned by
            ``solver.all_evaluation`` / ``solver.evaluate_dist`` at the call
            sites: (ade_min, fde_min, ade_avg, fde_avg, ade_std, fde_std,
            sg_ade_min, sg_ade_avg, sg_ade_std,
            lg_fde_min, lg_fde_avg, lg_fde_std).
        lg_num: if not None, a header line with ``lg_num`` and ``traj_num``
            is printed before the metrics.
        traj_num: printed in the header alongside ``lg_num``.

    Raises:
        ValueError: if ``results`` does not contain exactly 12 items.
    """
    (ade_min, fde_min,
     ade_avg, fde_avg,
     ade_std, fde_std,
     sg_ade_min, sg_ade_avg, sg_ade_std,
     lg_fde_min, lg_fde_avg, lg_fde_std) = results
    if lg_num is not None:
        print('lg_num: ', lg_num, ' // traj_num: ', traj_num)
    # Printed grouped by metric (min/avg/std), which differs from the
    # order the values come back in.
    for label, value in (
        ('ade min: ', ade_min),
        ('ade avg: ', ade_avg),
        ('ade std: ', ade_std),
        ('fde min: ', fde_min),
        ('fde avg: ', fde_avg),
        ('fde std: ', fde_std),
        ('sg_ade_min: ', sg_ade_min),
        ('sg_ade_avg: ', sg_ade_avg),
        ('sg_ade_std: ', sg_ade_std),
        ('lg_fde_min: ', lg_fde_min),
        ('lg_fde_avg: ', lg_fde_avg),
        ('lg_fde_std: ', lg_fde_std),
    ):
        print(label, value)
    print('------------------------------------------')


def main(args):
    """Evaluate pretrained checkpoints, or train when not at ``max_iter``.

    When ``args.ckpt_load_iter == args.max_iter`` the model is evaluated on
    the 'test' split of ``args.dataset_dir`` using the hard-coded trajectory
    and LG-CVAE checkpoints below. Otherwise a ``Solver`` is built and
    trained.
    """
    if args.ckpt_load_iter != args.max_iter:
        solver = Solver(args)
        solver.train()
        return

    print("Initializing test dataset")
    solver = Solver(args)
    print('--------------------', args.dataset_name, '----------------------')
    # Overridden after Solver(args) was constructed. Presumably this only
    # affects data_loader below -- TODO confirm Solver does not re-read it.
    args.batch_size = 4
    _, test_loader = data_loader(args, args.dataset_dir, 'test', shuffle=False)
    # NOTE(review): return value discarded. Kept for any side effects of the
    # loss=True evaluation pass.
    solver.evaluate_dist(test_loader, loss=True)

    gh = True  # passed as generate_heat to all_evaluation
    print("GEN HEAT MAP: ", gh)

    traj_path = 'sdd.traj_zD_20_dr_mlp_0.3_dr_rnn_0.25_enc_hD_64_dec_hD_128_mlpD_256_map_featD_32_map_mlpD_256_lr_0.001_klw_50.0_ll_prior_w_1.0_zfb_0.07_scale_100.0_num_sg_3_run_200'
    traj_iter = '15000'
    traj_ckpt = {'ckpt_dir': os.path.join('ckpts', traj_path), 'iter': traj_iter}
    print('===== TRAJECTORY:', traj_ckpt)

    lg_path = 'sdd.lgcvae_enc_block_1_fcomb_block_2_wD_20_lr_0.0001_lg_klw_1.0_a_0.25_r_2.0_fb_0.5_anneal_e_0_aug_1_run_181'
    lg_iter = '43000'
    lg_ckpt = {'ckpt_dir': os.path.join('ckpts', lg_path), 'iter': lg_iter}
    print('===== LG CVAE:', lg_ckpt)
    solver.pretrain_load_checkpoint(traj_ckpt, lg_ckpt)

    # (number of LG samples, trajectories per LG sample) configurations.
    for lg_num, traj_num in ((20, 1), (10, 2)):
        results = solver.all_evaluation(test_loader, lg_num=lg_num,
                                        traj_num=traj_num, generate_heat=gh)
        _print_eval_metrics(results, lg_num, traj_num)

    # NOTE(review): an earlier, since-removed call unpacked only three values
    # (fde_min, fde_avg, fde_std) from evaluate_dist(..., loss=False). Confirm
    # it now returns the 12-item result expected here.
    _print_eval_metrics(solver.evaluate_dist(test_loader, loss=False))