def main(args):
    """Evaluate every checkpoint under ``args.model_path`` (or the single
    checkpoint file) twice — once with the final generator weights
    (``best=0``) and once with the best-ADE weights (``best=1``) — printing
    each result and appending it to ``args.dest``.
    """
    if os.path.isdir(args.model_path):
        filenames = sorted(os.listdir(args.model_path))
        paths = [os.path.join(args.model_path, file_) for file_ in filenames]
    else:
        paths = [args.model_path]

    for path in paths:
        checkpoint = torch.load(path)
        # Final weights first, then the best-ADE snapshot.
        _evaluate_checkpoint(args, checkpoint, best=0, tag='')
        _evaluate_checkpoint(args, checkpoint, best=1, tag='BEST ')


def _evaluate_checkpoint(args, checkpoint, best, tag):
    """Run one evaluation pass for ``checkpoint`` and log the result.

    ``best`` selects which generator weights to load; ``tag`` is prefixed
    to the result line ('' for the final weights, 'BEST ' for the best ones).
    """
    generator = get_generator(checkpoint, best=best)
    _args = AttrDict(checkpoint['args'])
    if best == 0:
        # Only the first pass dumps the full training-args for reference.
        print(_args)
    dset_path = get_dset_path(_args.dataset_name, args.dset_type)
    _, loader = data_loader(_args, dset_path)
    ade, fde = evaluate(_args, loader, generator, args.num_samples)
    result_str = (
        '{}D_type: {}, Dataset: {}, Loss Weight: {:.2f} Pred Len: {}, '
        'ADE: {:.2f}, FDE: {:.2f} Samples: {} \n \n'.format(
            tag, args.D_type, _args.dataset_name, _args.l2_loss_weight,
            _args.pred_len, ade, fde, args.num_samples))

    print(result_str)
    with open(args.dest, "a") as myfile:
        myfile.write(result_str)
# ===== Example #2 (score 0) =====
def main(args):
    """Evaluate each checkpoint, report per-batch timing and ADE/FDE, and
    pickle the generated trajectories under ``trajs_dumped/``."""
    if os.path.isdir(args.model_path):
        filenames = sorted(os.listdir(args.model_path))
        paths = [os.path.join(args.model_path, file_) for file_ in filenames]
    else:
        paths = [args.model_path]

    for path in paths:
        checkpoint = torch.load(path)
        generator = get_generator(checkpoint)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.dset_type)
        _, loader = data_loader(_args, path)

        ade, fde, trajs, times = evaluate(_args, loader, generator,
                                          args.num_samples)

        # Per-batch inference times and their mean.
        print(times, np.mean(times))

        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            _args.dataset_name, _args.pred_len, ade, fde))

        # Create the dump directory up front so open() below cannot fail
        # with FileNotFoundError on a fresh workspace.
        os.makedirs("trajs_dumped", exist_ok=True)
        dump_name = _args.dataset_name + "_" + args.dset_type + "_trajs.pkl"
        with open("trajs_dumped/" + dump_name, 'wb') as f:
            pickle.dump(trajs, f)
        print("trajs dumped at ", dump_name)
# ===== Example #3 (score 0) =====
 def get_models():
     """Load the pretrained SGAN generator/discriminator pair plus a data
     loader, using a hard-coded hyper-parameter set instead of the args
     stored in the checkpoint.

     NOTE(review): relies on module-level ``paths`` and ``args`` being
     defined before this is called — confirm against the surrounding
     script. paths[1] is assumed to be the generator checkpoint and
     paths[0] the discriminator checkpoint; verify ordering.
     """
     # st.cache memoizes torch.load across Streamlit reruns; ignore_hash
     # because checkpoint objects are not hashable by Streamlit.
     fast_load = st.cache(torch.load, ignore_hash=True)
     checkpoint_gen = fast_load(paths[1])
     checkpoint_dis = fast_load(paths[0])
     # Frozen training configuration (must match how the models were trained).
     _args = dict([('clipping_threshold_d', 0), ('obs_len', 10),
                   ('batch_norm', False), ('timing', 0),
                   ('checkpoint_name', 'gan_test'),
                   ('num_samples_check', 5000), ('mlp_dim', 64),
                   ('use_gpu', 1), ('encoder_h_dim_d', 16),
                   ('num_epochs', 900), ('restore_from_checkpoint', 1),
                   ('g_learning_rate', 0.0005), ('pred_len', 20),
                   ('neighborhood_size', 2.0), ('delim', 'tab'),
                   ('d_learning_rate', 0.0002), ('d_steps', 2),
                   ('pool_every_timestep', False),
                   ('checkpoint_start_from', None), ('embedding_dim', 16),
                   ('d_type', 'local'), ('grid_size', 8), ('dropout', 0.0),
                   ('batch_size', 4), ('l2_loss_weight', 1.0),
                   ('encoder_h_dim_g', 16), ('print_every', 10),
                   ('best_k', 10), ('num_layers', 1), ('skip', 1),
                   ('bottleneck_dim', 32), ('noise_type', 'gaussian'),
                   ('clipping_threshold_g', 1.5), ('decoder_h_dim_g', 32),
                   ('gpu_num', '0'), ('loader_num_workers', 4),
                   ('pooling_type', 'pool_net'), ('noise_dim', (20, )),
                   ('g_steps', 1), ('checkpoint_every', 50),
                   ('noise_mix_type', 'global'), ('num_iterations', 80000)])
     # Wrap for attribute-style access (args.obs_len etc.).
     _args = AttrDict(_args)
     generator = get_generator(_args, checkpoint_gen)
     discriminator = get_discriminator(_args, checkpoint_dis)
     data_path = get_dset_path(args.dataset_name, args.dset_type)
     _, loader = data_loader(_args, data_path)
     return _args, generator, discriminator, data_path, loader
# ===== Example #4 (score 0) =====
def main(args):
    """Evaluate a hard-coded SGAN model on a hard-coded dataset, on CPU,
    with a single sample per sequence.

    NOTE: the checkpoint and dataset paths below deliberately override the
    command-line / checkpoint values (debug setup).
    """
    if os.path.isdir(args.model_path):
        filenames = os.listdir(args.model_path)
        filenames.sort()
        paths = [os.path.join(args.model_path, file_) for file_ in filenames]
    else:
        paths = [args.model_path]

    for path in paths:
        # set model file (hard-coded override of the discovered path)
        path = '/home/xiaotongfeng/Data/Pyhton/sgan-master/models/sgan-models/eth_12_model.pt'
        print(path)

        # load model onto CPU: map_location remaps CUDA tensors to CPU
        # checkpoint = torch.load(path)
        checkpoint = torch.load(path,
                                map_location=lambda storage, loc: storage)
        generator = get_generator(checkpoint)  # generator

        # get data set (again a hard-coded override)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.dset_type)
        dataset_path = '/home/yuanwang/Data/Pyhton/sgan-master/datasets/hotel/test'
        print(dataset_path)
        dset, loader = data_loader(_args, dataset_path)

        # evaluate one sample per sequence
        _args.batch_size = 1
        num_samples = 1
        print(_args.batch_size)
        ade, fde = evaluate(_args, loader, generator, num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            _args.dataset_name, _args.pred_len, ade, fde))
        # A stray unterminated ''' followed here in the original source; it
        # was a syntax error and has been removed.
# ===== Example #5 (score 0) =====
def main(args):
    """Evaluate every checkpoint under ``args.model_path`` on the sampled
    test split under ``args.dataset_dir`` and print ADE/FDE.

    Removed from the original: an unused ``count`` accumulator and dead
    commented-out debug code.
    """
    if os.path.isdir(args.model_path):
        filenames = sorted(os.listdir(args.model_path))
        paths = [os.path.join(args.model_path, file_) for file_ in filenames]
    else:
        paths = [args.model_path]

    for path in paths:
        model_name = path.split('/')[-1]
        print('Model: {}'.format(model_name))
        checkpoint = torch.load(path)
        generator = get_generator(checkpoint)
        _args = AttrDict(checkpoint['args'])

        print(_args.dataset_name)

        # Data comes from a fixed sampled split instead of get_dset_path.
        path = os.path.join(args.dataset_dir, _args.dataset_name,
                            'test_sample')  # 10 files:0-9
        _, loader = data_loader(_args, path)
        ade, fde = evaluate(_args, loader, generator, args.num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            _args.dataset_name, _args.pred_len, ade, fde))
# ===== Example #6 (score 0) =====
def main(args):
    """Evaluate a two-generator model pair, report timing and ADE/FDE, and
    pickle the generated trajectories under ``trajs_dumped/``."""
    path = args.model_path
    path2 = args.model_path2

    checkpoint = torch.load(path)
    generator = get_generator(checkpoint)
    checkpoint2 = torch.load(path2)
    generator2 = get_generator2(checkpoint2)

    # Dataset configuration is taken from the first checkpoint only.
    _args = AttrDict(checkpoint['args'])
    path = get_dset_path(_args.dataset_name, args.dset_type)
    _, loader = data_loader(_args, path)

    ade, fde, trajs, times = evaluate(_args, loader, generator, generator2,
                                      args.num_samples)

    # Per-batch inference times and their mean.
    print(times, np.mean(times))

    print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
        _args.dataset_name, _args.pred_len, ade, fde))

    # Create the dump directory up front so open() cannot fail with
    # FileNotFoundError on a fresh workspace.
    os.makedirs("trajs_dumped", exist_ok=True)
    dump_name = _args.dataset_name + "_" + args.dset_type + "_trajs.pkl"
    with open("trajs_dumped/" + dump_name, 'wb') as f:
        pickle.dump(trajs, f)
    print("trajs dumped at ", dump_name)
# ===== Example #7 (score 0) =====
def main(args):
    """Run evaluation for each checkpoint; any reporting is done inside
    ``evaluate`` itself (nothing is printed here)."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        generator = get_generator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        dset_path = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args, dset_path)
        evaluate(train_args, loader, generator, args.num_samples)
def main(args):
    """Evaluate each checkpoint and print its ADE/FDE."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        generator = get_generator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        dset_path = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args, dset_path)
        ade, fde = evaluate(train_args, loader, generator, args.num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            train_args.dataset_name, train_args.pred_len, ade, fde))
# ===== Example #9 (score 0) =====
def get_eth_gan_generator(load_mode="CUDA"):
    """Load the pretrained ETH SGAN generator and smoke-test it.

    :param load_mode: "CUDA" to load as-is, "CPU" to remap all tensors
        onto the CPU via ``map_location``.
    :return: the loaded generator.
    :raises ValueError: if ``load_mode`` is neither "CUDA" nor "CPU".
        (Previously an unknown mode left ``checkpoint`` unbound and the
        function crashed later with ``NameError``.)
    """
    model_path = "../../sgan/models/sgan-models/eth_12_model.pt"
    if load_mode == "CUDA":
        checkpoint = torch.load(model_path)
    elif load_mode == "CPU":
        # Remap CUDA tensors onto the CPU so the checkpoint loads without a GPU.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
    else:
        raise ValueError(
            "load_mode must be 'CUDA' or 'CPU', got %r" % (load_mode,))
    # Record the chosen mode for the rest of the module.
    global torch_mode
    torch_mode = load_mode
    generator = get_generator(checkpoint)
    _args = AttrDict(checkpoint['args'])
    path = get_dset_path(_args.dataset_name, 'test')

    _, loader = data_loader(_args, path)
    # Quick 20-sample evaluation as a sanity check of the loaded weights.
    ade, fde = evaluate(_args, loader, generator, 20)
    return generator
# ===== Example #10 (score 0) =====
def main(args):
    """Evaluate each checkpoint and report ADE/FDE together with the
    left/mid/right trajectory direction counts."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)

        generator = get_generator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        data_path = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args, data_path)
        ade, fde, count_left, count_mid, count_right = evaluate(
            train_args, loader, generator, args.num_samples)
        # Note: the dataset path (not the checkpoint path) is echoed here,
        # matching the original behavior.
        print(data_path, '\n', 'ADE: {:.2f}, FDE: {:.2f}'.format(ade, fde))
        print("count_left: ", count_left, " count_mid: ", count_mid,
              "  count_right: ", count_right)
# ===== Example #11 (score 0) =====
def main(args):
    """Evaluate each checkpoint; for 'mtm' models the metrics are rescaled
    from lat-lon degrees to kilometres before reporting."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        print("args for model_path", args.model_path, ":")
        generator = get_generator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        data_path = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args, data_path)
        ade, fde = evaluate(train_args, loader, generator, args.num_samples)
        if "mtm" in args.model_path:
            # converting from lat-lon to kms (~110 km per degree)
            ade, fde = ade * 110.0, fde * 110.0
            print("ade and fde in kms")
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            train_args.dataset_name, train_args.pred_len, ade, fde))
# ===== Example #12 (score 0) =====
def main(args):
    """Load a SEEM checkpoint, run one MALA-corrected sampling pass on the
    last batch of the loader, then plot losses and network structure.
    """

    # Checkpoint name is supplied without the .pt extension.
    checkpoint = torch.load(args.pt_name+'.pt')
    checkpoint_Gloss = checkpoint['G_losses']
    checkpoint_Dloss = checkpoint['D_losses']
    _args = AttrDict(checkpoint['args'])

    # Two-stage generator (SO encodes, ST decodes) plus discriminator.
    generatorSO = get_generatorSO(checkpoint)
    generatorST = get_generatorST(checkpoint)
    discriminator = get_dicriminator(checkpoint)

    path = get_dset_path(_args.dataset_name, args.dset_type)
    _, loader = data_loader(_args, path)

    # Iterates the whole loader just to keep the FINAL batch (moved to GPU).
    # NOTE(review): batch_final and the loop variable batch end up identical;
    # the unpack below uses batch.
    batch_final = ()
    for batch in loader:
        batch = [tensor.cuda() for tensor in batch]
        batch_final = batch
    
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, seq_start_end) = batch
    #ipdb.set_trace()
    # Stage 1 encodes the observations into a noise/latent input, MALA
    # refines the latent using the discriminator, and stage 2 decodes the
    # prediction from the concatenated hidden state.
    noise_input, noise_shape = generatorSO(obs_traj, obs_traj_rel, seq_start_end)
    z_noise = MALA_corrected_sampler(generatorST, discriminator, _args, noise_shape, noise_input, seq_start_end, obs_traj, obs_traj_rel)
    decoder_h = torch.cat([noise_input, z_noise], dim=1)
    decoder_h = torch.unsqueeze(decoder_h, 0)
    generator_out = generatorST(decoder_h, seq_start_end, obs_traj, obs_traj_rel)
    pred_traj_fake_rel = generator_out
    # Convert relative displacements back to absolute coordinates, anchored
    # at the last observed position.
    pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
    # NOTE(review): scores_fake is computed but never used below.
    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    
    plot_G_losses(args, checkpoint_Gloss)
    plot_D_losses(args, checkpoint_Dloss)
    plot_net_structure(args, generatorSO, generatorST, discriminator, obs_traj, obs_traj_rel, seq_start_end, decoder_h, traj_fake, traj_fake_rel)
# ===== Example #13 (score 0) =====
def main(args):
    """Evaluate each checkpoint on its test split without shuffling; the
    reported prediction length accounts for frame skip."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        generator = get_generator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        # convert model path to dataset path
        datapath = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args,
                                datapath,
                                shuffle=False,
                                phase='testing',
                                split=args.dset_type)
        # The checkpoint path is forwarded so evaluate() can tag its output.
        ade, fde = evaluate(train_args, loader, generator, args.num_samples,
                            ckpt_path)
        print('Dataset: {}, Pred Len: {}, ADE: {:.3f}, FDE: {:.3f}'.format(
            train_args.dataset_name, train_args.pred_len * train_args.skip,
            ade, fde))
# ===== Example #14 (score 0) =====
def main(args):
    """Evaluate one checkpoint using a fixed, hard-coded training
    configuration rather than the args stored in the checkpoint."""
    if os.path.isdir(args.model_path):
        file_ = os.listdir(args.model_path)
        print(file_)
        path = os.path.join(args.model_path, file_[0])
    else:
        path = args.model_path
    checkpoint = torch.load(path)
    # Hyper-parameters frozen to the values the model was trained with.
    _args = AttrDict({
        'clipping_threshold_d': 0,
        'obs_len': 10,
        'batch_norm': False,
        'timing': 0,
        'checkpoint_name': 'gan_test',
        'num_samples_check': 5000,
        'mlp_dim': 64,
        'use_gpu': 1,
        'encoder_h_dim_d': 16,
        'num_epochs': 900,
        'restore_from_checkpoint': 1,
        'g_learning_rate': 0.0005,
        'pred_len': 20,
        'neighborhood_size': 2.0,
        'delim': 'tab',
        'd_learning_rate': 0.0002,
        'd_steps': 2,
        'pool_every_timestep': False,
        'checkpoint_start_from': None,
        'embedding_dim': 16,
        'd_type': 'local',
        'grid_size': 8,
        'dropout': 0.0,
        'batch_size': 4,
        'l2_loss_weight': 1.0,
        'encoder_h_dim_g': 16,
        'print_every': 10,
        'best_k': 10,
        'num_layers': 1,
        'skip': 1,
        'bottleneck_dim': 32,
        'noise_type': 'gaussian',
        'clipping_threshold_g': 1.5,
        'decoder_h_dim_g': 32,
        'gpu_num': '0',
        'loader_num_workers': 4,
        'pooling_type': 'pool_net',
        'noise_dim': (20, ),
        'g_steps': 1,
        'checkpoint_every': 50,
        'noise_mix_type': 'global',
        'num_iterations': 80000,
    })
    generator = get_generator(_args, checkpoint)
    data_path = get_dset_path(args.dataset_name, args.dset_type)
    _, loader = data_loader(_args, data_path)
    ade, fde = evaluate(_args, loader, generator, args.num_samples)
    print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            args.dataset_name, _args.pred_len, ade, fde))
# ===== Example #15 (score 0) =====
# File: SEEM-test.py  Project: ACoTAI/CODE
def main(args):
    """Evaluate SEEM checkpoints: load the two-stage generator and the
    discriminator from each checkpoint, then print ADE/FDE plus the
    left/right direction counts."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        #TODO:gso&gst
        generatorSO = get_generatorSO(checkpoint)
        generatorST = get_generatorST(checkpoint)
        discriminator = get_dicriminator(checkpoint)
        train_args = AttrDict(checkpoint['args'])
        data_path = get_dset_path(train_args.dataset_name, args.dset_type)
        _, loader = data_loader(train_args, data_path)
        ade, fde, count_left, count_mid, count_right = evaluate(
            train_args, loader, generatorSO, generatorST, discriminator,
            args.num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            train_args.dataset_name, train_args.pred_len, ade, fde))
        # NOTE(review): count_mid is returned but not printed here — matches
        # the original behavior.
        print("count_left: ", count_left, "count_right: ", count_right)
# ===== Example #16 (score 0) =====
def main(args):
    """Evaluate checkpoint(s), dump generated trajectories to a pickle named
    after the model file, and return the last (ade, fde) pair as floats."""
    model_path = args.model_path
    if os.path.isdir(model_path):
        paths = [os.path.join(model_path, name)
                 for name in sorted(os.listdir(model_path))]
    else:
        paths = [model_path]

    for ckpt_path in paths:
        checkpoint = torch.load(ckpt_path)
        train_args = AttrDict(checkpoint['args'])
        data_path = get_dset_path(train_args.dataset_name, args.dset_type)

        generator = get_generator(checkpoint, args_=args)

        _, loader = data_loader(train_args, data_path)

        ade, fde, trajs = evaluate(train_args, loader, generator)

        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            train_args.dataset_name, train_args.pred_len, ade, fde))
        # For split_moving datasets the dump path mirrors the dataset tree.
        if train_args.dataset_name.split("/")[0] == "split_moving":
            subdir = "trajs_dumped/" + "/".join(
                train_args.dataset_name.split("/")[:-1])
            pathlib.Path(subdir).mkdir(parents=True, exist_ok=True)
        model_stem = args.model_path.split("/")[-1].split(".")[0]
        dump_name = model_stem + "_" + args.dset_type + "_trajs.pkl"
        with open("trajs_dumped/" + dump_name, 'wb+') as f:
            pickle.dump(trajs, f)
        print("trajs dumped at ", dump_name)

    return ade.item(), fde.item()
def main(args):
    """Train (or test) the trajectory GAN.

    Builds the generator/discriminator from the configured model variant,
    optionally restores a checkpoint, then alternates discriminator and
    generator steps over the training data, periodically plotting
    trajectories and saving full / weight-free checkpoints.
    """
    if args.mode == 'training':
        # Training-mode overrides of the parsed arguments.
        args.checkpoint_every = 100
        args.teacher_name = "default"
        args.restore_from_checkpoint = 0
        #args.l2_loss_weight = 0.0
        args.rollout_steps = 1
        args.rollout_rate = 1
        args.rollout_method = 'sgd'
        #print("HHHH"+str(args.l2_loss_weight))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    # One "iteration" consumes d_steps batches for D plus g_steps for G.
    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    # Select the model implementation at runtime; the imports rebind the
    # module-level names declared global here.
    global TrajectoryGenerator, TrajectoryDiscriminator
    if args.GAN_type == 'rnn':
        print("Default Social GAN")
        from sgan.models import TrajectoryGenerator, TrajectoryDiscriminator
    elif args.GAN_type == 'simple_rnn':
        print("Default Social GAN")
        from sgan.rnn_models import TrajectoryGenerator, TrajectoryDiscriminator
    else:
        print("Feedforward GAN")
        if (args.Encoder_type == 'MLP' and args.Decoder_type == 'MLP'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_MLP_D_MLP import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'MLP' and args.Decoder_type == 'LSTM'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_MLP_D_LSTM import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'LSTM' and args.Decoder_type == 'MLP'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_LSTM_D_MLP import TrajectoryGenerator, TrajectoryDiscriminator
        if (args.Encoder_type == 'LSTM' and args.Decoder_type == 'LSTM'):
            from sgan.cgs_integrated_model.cgs_ffd_models_E_LSTM_D_LSTM import TrajectoryGenerator, TrajectoryDiscriminator

    #image_dir = 'images/' + 'curve_5_traj_l2_0.5'
    #image_dir = 'images/5trajectory/' + 'havingplots'+ '2-layers-EN-' + args.Encoder_type +  '-DE-20-layers-' + args.Decoder_type + '-L2_' + str(args.l2_loss_weight)

    # str(*[x if cond else 1]) unpacks the one-element list into str():
    # it prints the MLP layer count for MLP parts and 1 otherwise.
    image_dir = 'images/' + str(args.dataset_name) + \
                '_EN_' + args.Encoder_type + '(' + str(*[args.mlp_encoder_layers if args.Encoder_type == 'MLP' else 1]) + ')' + \
                '_DE_' + args.Decoder_type + '(' + str(*[args.mlp_decoder_layers if args.Decoder_type == 'MLP' else 1]) + ')' + \
                '_DIS_' + args.GAN_type.upper() + '(' + str(args.mlp_discriminator_layers) + ')' + \
                '_L2_Weight' + '(' + str(args.l2_loss_weight) + ')'

    print("Image Dir: ", image_dir)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        num_mlp_decoder_layers=args.mlp_decoder_layers,
        num_mlp_encoder_layers=args.mlp_encoder_layers)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type,
        mlp_discriminator_layers=args.mlp_discriminator_layers,
        num_mlp_encoder_layers=args.mlp_encoder_layers)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    # build teacher
    print("[!] teacher_name: ", args.teacher_name)

    if args.teacher_name == 'default':
        teacher = None
    elif args.teacher_name == 'gpurollout':
        from teacher_gpu_rollout_torch import TeacherGPURollout
        teacher = TeacherGPURollout(args)
        teacher.set_env(discriminator, generator)
        print("GPU Rollout Teacher")
    else:
        raise NotImplementedError

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)

    # # Create D optimizer.
    # self.d_optim = tf.train.AdamOptimizer(self.disc_LR*config.D_LR, beta1=config.beta1)
    # # Compute the gradients for a list of variables.
    # self.grads_d_and_vars = self.d_optim.compute_gradients(self.d_loss, var_list=self.d_vars)
    # self.grad_default_real = self.d_optim.compute_gradients(self.d_loss_real, var_list=inputs)
    # # Ask the optimizer to apply the capped gradients.
    # self.update_d = self.d_optim.apply_gradients(self.grads_d_and_vars)
    # ## Get Saliency Map - Teacher
    # self.saliency_map = tf.gradients(self.d_loss, self.inputs)[0]

    # ###### G Optimizer ######
    # # Create G optimizer.
    # self.g_optim = tf.train.AdamOptimizer(config.learning_rate*config.G_LR, beta1=config.beta1)

    # # Compute the gradients for a list of variables.
    # ## With respect to Generator Weights - AutoLoss
    # self.grad_default = self.g_optim.compute_gradients(self.g_loss, var_list=[self.G, self.g_vars])
    # ## With Respect to Images given to D - Teacher
    # # self.grad_default = g_optim.compute_gradients(self.g_loss, var_list=)
    # if config.teacher_name == 'default':
    # self.optimal_grad = self.grad_default[0][0]
    # self.optimal_batch = self.G - self.optimal_grad
    # else:
    # self.optimal_grad, self.optimal_batch = self.teacher.build_teacher(self.G, self.D_, self.grad_default[0][0], self.inputs)

    # # Ask the optimizer to apply the manipulated gradients.
    # grads_collected = tf.gradients(self.G, self.g_vars, self.optimal_grad)
    # grads_and_vars_collected = list(zip(grads_collected, self.g_vars))

    # self.g_teach = self.g_optim.apply_gradients(grads_and_vars_collected)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }
    t0 = None
    # Reused figure/axes for the periodic trajectory plots.
    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.75, 0.75])

    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:

            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:

                if args.mode != 'testing':
                    step_type = 'd'
                    losses_d = discriminator_step(args, batch, generator,
                                                  discriminator, d_loss_fn,
                                                  optimizer_d, teacher,
                                                  args.mode)
                    checkpoint['norm_d'].append(
                        get_total_norm(discriminator.parameters()))

                d_steps_left -= 1

            elif g_steps_left > 0:

                if args.mode != 'testing':
                    step_type = 'g'
                    losses_g = generator_step(args, batch, generator,
                                              discriminator, g_loss_fn,
                                              optimizer_g, args.mode)
                    checkpoint['norm_g'].append(
                        get_total_norm(generator.parameters()))

                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Interation {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0 and args.mode != 'testing':
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                # # Check stats on the validation set
                # logger.info('Checking stats on val ...')
                # metrics_val = check_accuracy(
                #     args, val_loader, generator, discriminator, d_loss_fn
                # )
                # logger.info('Checking stats on train ...')
                # metrics_train = check_accuracy(
                #     args, train_loader, generator, discriminator,
                #     d_loss_fn, limit=True
                # )

                # for k, v in sorted(metrics_val.items()):
                #     logger.info('  [val] {}: {:.3f}'.format(k, v))
                #     checkpoint['metrics_val'][k].append(v)
                # for k, v in sorted(metrics_train.items()):
                #     logger.info('  [train] {}: {:.3f}'.format(k, v))
                #     checkpoint['metrics_train'][k].append(v)

                # min_ade = min(checkpoint['metrics_val']['ade'])
                # min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                # if metrics_val['ade'] == min_ade:
                #     logger.info('New low for avg_disp_error')
                #     checkpoint['best_t'] = t
                #     checkpoint['g_best_state'] = generator.state_dict()
                #     checkpoint['d_best_state'] = discriminator.state_dict()

                # if metrics_val['ade_nl'] == min_ade_nl:
                #     logger.info('New low for avg_disp_error_nl')
                #     checkpoint['best_t_nl'] = t
                #     checkpoint['g_best_nl_state'] = generator.state_dict()
                #     checkpoint['d_best_nl_state'] = discriminator.state_dict()

            # Plot sample trajectories every 50 iterations.
            if t % 50 == 0:
                # save = False
                # if t == 160:
                # save = True
                # print(t)
                plot_trajectory(fig,
                                ax,
                                args,
                                val_loader,
                                generator,
                                teacher,
                                args.mode,
                                t,
                                save=True,
                                image_dir=image_dir)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                print("Iteration: ", t)
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
# ===== Example #18 (score 0) =====
def extract_our_and_sgan_preds(dataset_name, hyperparams, args, data_precondition='all'):
    """Run the SGAN baseline and our trained model on one dataset's test data
    and return their per-agent predictions, aligned so both are keyed by the
    same nodes.

    Args:
        dataset_name: Dataset to evaluate (e.g. 'eth', 'hotel').
        hyperparams: Hyperparameter dict; several entries are filled in here
            from the loaded evaluation data before building the eval model.
        args: Namespace with model paths, devices, num_samples, num_runs, etc.
        data_precondition: Which data subset to check (forwarded to
            get_sgan_data_format as what_to_check).

    Returns:
        Tuple of (our_preds_most_likely_list, our_preds_list, sgan_preds_list,
        sgan_gt_list, eval_inputs, eval_data_dict, data_ids, t_predicts,
        random_scene_idxs, num_runs); the four *_list entries are per-run
        dicts of predictions keyed by agent/node.

    Raises:
        RuntimeError: If no trained model directory exists for this dataset;
            the evaluation below cannot proceed without it.
    """
    print('At %s dataset' % dataset_name)

    ### SGAN LOADING ###
    sgan_model_path = os.path.join(args.sgan_models_path, '_'.join([dataset_name, '12', 'model.pt']))

    checkpoint = torch.load(sgan_model_path, map_location='cpu')
    generator = get_generator(checkpoint)
    _args = AttrDict(checkpoint['args'])
    path = get_dset_path(_args.dataset_name, args.sgan_dset_type)
    print('Evaluating', sgan_model_path, 'on', _args.dataset_name, args.sgan_dset_type)

    _, sgan_data_loader = data_loader(_args, path)

    ### OUR METHOD LOADING ###
    data_dir = '../sgan-dataset/data'
    eval_data_dict_name = '%s_test.pkl' % dataset_name
    log_dir = '../sgan-dataset/logs/%s' % dataset_name
    have_our_model = False
    if os.path.isdir(log_dir):
        have_our_model = True

        trained_model_dir = os.path.join(log_dir, get_our_model_dir(dataset_name))
        eval_data_path = os.path.join(data_dir, eval_data_dict_name)
        with open(eval_data_path, 'rb') as f:
            eval_data_dict = pickle.load(f, encoding='latin1')
        eval_dt = eval_data_dict['dt']
        print('Loaded evaluation data from %s, eval_dt = %.2f' % (eval_data_path, eval_dt))

        # Loading weights from the trained model.
        specific_hyperparams = get_model_hyperparams(args, dataset_name)
        model_registrar = ModelRegistrar(trained_model_dir, args.device)
        model_registrar.load_models(specific_hyperparams['best_iter'])

        # Any STGNode key will do here; they all share the same state layout.
        for key in eval_data_dict['input_dict'].keys():
            if isinstance(key, STGNode):
                random_node = key
                break

        hyperparams['state_dim'] = eval_data_dict['input_dict'][random_node].shape[2]
        hyperparams['pred_dim'] = len(eval_data_dict['pred_indices'])
        hyperparams['pred_indices'] = eval_data_dict['pred_indices']
        hyperparams['dynamic_edges'] = args.dynamic_edges
        hyperparams['edge_state_combine_method'] = specific_hyperparams['edge_state_combine_method']
        hyperparams['edge_influence_combine_method'] = specific_hyperparams['edge_influence_combine_method']
        hyperparams['nodes_standardization'] = eval_data_dict['nodes_standardization']
        hyperparams['labels_standardization'] = eval_data_dict['labels_standardization']
        hyperparams['edge_radius'] = args.edge_radius

        eval_hyperparams = copy.deepcopy(hyperparams)
        eval_hyperparams['nodes_standardization'] = eval_data_dict["nodes_standardization"]
        eval_hyperparams['labels_standardization'] = eval_data_dict["labels_standardization"]

        kwargs_dict = {'dynamic_edges': hyperparams['dynamic_edges'],
                       'edge_state_combine_method': hyperparams['edge_state_combine_method'],
                       'edge_influence_combine_method': hyperparams['edge_influence_combine_method']}


        print('-------------------------')
        print('| EVALUATION PARAMETERS |')
        print('-------------------------')
        print('| checking: %s' % data_precondition)
        print('| device: %s' % args.device)
        print('| eval_device: %s' % args.eval_device)
        print('| edge_radius: %s' % hyperparams['edge_radius'])
        print('| EE state_combine_method: %s' % hyperparams['edge_state_combine_method'])
        print('| EIE scheme: %s' % hyperparams['edge_influence_combine_method'])
        print('| dynamic_edges: %s' % hyperparams['dynamic_edges'])
        print('| edge_addition_filter: %s' % args.edge_addition_filter)
        print('| edge_removal_filter: %s' % args.edge_removal_filter)
        print('| MHL: %s' % hyperparams['minimum_history_length'])
        print('| PH: %s' % hyperparams['prediction_horizon'])
        print('| # Samples: %s' % args.num_samples)
        print('| # Runs: %s' % args.num_runs)
        print('-------------------------')

        # It is important that eval_stg uses the same model_registrar as
        # the stg being trained, otherwise you're just repeatedly evaluating
        # randomly-initialized weights!
        eval_stg = SpatioTemporalGraphCVAEModel(None, model_registrar,
                                                eval_hyperparams, kwargs_dict,
                                                None, args.eval_device)
        print('Created evaluation STG model.')

        eval_agg_scene_graph = create_batch_scene_graph(eval_data_dict['input_dict'],
                                                        float(hyperparams['edge_radius']),
                                                        use_old_method=(args.dynamic_edges=='no'))
        print('Created aggregate evaluation scene graph.')

        if args.dynamic_edges == 'yes':
            eval_agg_scene_graph.compute_edge_scaling(args.edge_addition_filter, args.edge_removal_filter)
            eval_data_dict['input_dict']['edge_scaling_mask'] = eval_agg_scene_graph.edge_scaling_mask
            print('Computed edge scaling for the evaluation scene graph.')

        eval_stg.set_scene_graph(eval_agg_scene_graph)
        print('Set the aggregate scene graph.')

        eval_stg.set_annealing_params()
    else:
        # Everything below unconditionally uses eval_data_dict, eval_stg and
        # sgan_our_agent_map, which are only created when a trained model
        # directory exists; fail fast with a clear message instead of hitting
        # a NameError deep inside the evaluation loop.
        raise RuntimeError('No trained model directory found at %s' % log_dir)

    print('About to begin evaluation computation for %s.' % dataset_name)
    with torch.no_grad():
        eval_inputs, _ = sample_inputs_and_labels(eval_data_dict, device=args.eval_device)

        sgan_preds_list = list()
        sgan_gt_list = list()
        our_preds_list = list()
        our_preds_most_likely_list = list()

        (obs_traj, pred_traj_gt, obs_traj_rel,
         seq_start_end, data_ids, t_predicts) = get_sgan_data_format(eval_inputs, what_to_check=data_precondition)

        num_runs = args.num_runs
        print('num_runs, seq_start_end.shape[0]', args.num_runs, seq_start_end.shape[0])
        if args.num_runs > seq_start_end.shape[0]:
            print('num_runs (%d) > seq_start_end.shape[0] (%d), reducing num_runs to match.' % (num_runs, seq_start_end.shape[0]))
            num_runs = seq_start_end.shape[0]

        # Draw num_samples trajectory samples from SGAN for every agent at once.
        samples_list = list()
        for _ in range(args.num_samples):
            pred_traj_fake_rel = generator(
                obs_traj, obs_traj_rel, seq_start_end
            )
            pred_traj_fake = relative_to_abs(
                pred_traj_fake_rel, obs_traj[-1]
            )

            samples_list.append(pred_traj_fake)

        random_scene_idxs = np.random.choice(seq_start_end.shape[0],
                                             size=(num_runs,),
                                             replace=False).astype(int)

        # Per-run observed history for each SGAN agent, used below to match
        # SGAN agents with our model's nodes.
        sgan_history = defaultdict(dict)
        for run in range(num_runs):
            random_scene_idx = random_scene_idxs[run]
            seq_idx_range = seq_start_end[random_scene_idx]

            agent_preds = dict()
            agent_gt = dict()
            for seq_agent in range(seq_idx_range[0], seq_idx_range[1]):
                agent_preds[seq_agent] = torch.stack([x[:, seq_agent] for x in samples_list], dim=0)
                agent_gt[seq_agent] = torch.unsqueeze(pred_traj_gt[:, seq_agent], dim=0)
                sgan_history[run][seq_agent] = obs_traj[:, seq_agent]

            sgan_preds_list.append(agent_preds)
            sgan_gt_list.append(agent_gt)

        print('Done running SGAN')

        if have_our_model:
            sgan_our_agent_map = dict()
            our_sgan_agent_map = dict()
            for run in range(num_runs):
                print('At our run number', run)
                random_scene_idx = random_scene_idxs[run]
                data_id = data_ids[random_scene_idx]
                t_predict = t_predicts[random_scene_idx] - 1

                curr_inputs = {k: v[[data_id]] for k, v in eval_inputs.items()}
                curr_inputs['traj_lengths'] = torch.tensor([t_predict])

                with torch.no_grad():
                    preds_dict_most_likely = eval_stg.predict(curr_inputs, hyperparams['prediction_horizon'], args.num_samples, most_likely=True)
                    preds_dict_full = eval_stg.predict(curr_inputs, hyperparams['prediction_horizon'], args.num_samples, most_likely=False)

                our_preds_most_likely_list.append(preds_dict_most_likely)
                our_preds_list.append(preds_dict_full)

                for node, value in curr_inputs.items():
                    if isinstance(node, STGNode) and np.any(value[0, t_predict]):
                        # Last 8 observed timesteps of this node's state.
                        curr_prev = value[0, t_predict+1-8 : t_predict+1]
                        for seq_agent, sgan_val in sgan_history[run].items():
                            # Match our node to an SGAN agent by comparing
                            # observed histories (assumes the first two state
                            # dims are xy positions — TODO confirm).
                            if torch.norm(curr_prev[:, :2] - sgan_val) < 1e-4:
                                sgan_our_agent_map['%d/%d' % (run, seq_agent)] = node
                                our_sgan_agent_map['%d/%s' % (run, str(node))] = '%d/%d' % (run, seq_agent)

            print('Done running Our Method')

        # Pruning values that aren't in either.
        for run in range(num_runs):
            agent_preds = sgan_preds_list[run]
            agent_gt = sgan_gt_list[run]

            new_agent_preds = dict()
            new_agent_gts = dict()
            for agent in agent_preds.keys():
                run_agent_key = '%d/%d' % (run, agent)
                if run_agent_key in sgan_our_agent_map:
                    new_agent_preds[sgan_our_agent_map[run_agent_key]] = agent_preds[agent]
                    new_agent_gts[sgan_our_agent_map[run_agent_key]] = agent_gt[agent]

            sgan_preds_list[run] = new_agent_preds
            sgan_gt_list[run] = new_agent_gts

        for run in range(num_runs):
            agent_preds_ml = our_preds_most_likely_list[run]
            agent_preds_full = our_preds_list[run]

            new_agent_preds = dict()
            new_agent_preds_full = dict()
            for node in [x for x in agent_preds_ml.keys() if x.endswith('/y')]:
                node_key_list = node.split('/')
                node_obj = STGNode(node_key_list[1], node_key_list[0])
                node_obj_key = '%d/%s' % (run, str(node_obj))
                if node_obj_key in our_sgan_agent_map:
                    new_agent_preds[node_obj] = agent_preds_ml[node]
                    new_agent_preds_full[node_obj] = agent_preds_full[node]

            our_preds_most_likely_list[run] = new_agent_preds
            our_preds_list[run] = new_agent_preds_full

        # Guaranteeing the number of agents are the same.
        for run in range(num_runs):
            assert list_compare(our_preds_most_likely_list[run].keys(), sgan_preds_list[run].keys())
            assert list_compare(our_preds_list[run].keys(), sgan_preds_list[run].keys())
            assert list_compare(our_preds_most_likely_list[run].keys(), our_preds_list[run].keys())
            assert list_compare(sgan_preds_list[run].keys(), sgan_gt_list[run].keys())

    return (our_preds_most_likely_list, our_preds_list,
            sgan_preds_list, sgan_gt_list, eval_inputs, eval_data_dict,
            data_ids, t_predicts, random_scene_idxs, num_runs)
# 示例#19 ("Example #19") — scrape-artifact separator between pasted snippets
# 0
def objective(trial):
    """Optuna objective: train an SGAN for a short, fixed budget with
    trial-sampled hyperparameters and return the validation ADE.

    Args:
        trial: optuna.Trial used to sample hyperparameters and report
            intermediate results.

    Returns:
        Final validation average displacement error (ADE); lower is better.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    # NOTE(review): the misspelled parameter name 'discriminator_wight' is
    # kept so existing Optuna study databases remain compatible.
    discriminator_wight = trial.suggest_categorical('discriminator_wight',
                                                    [0, 1])
    optim_name = trial.suggest_categorical('optim_name',
                                           ['Adam', 'Adamax', 'RMSprop'])

    # args.batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])
    args.dropout = trial.suggest_categorical('drop_out', [0, 0.2, 0.5])
    args.batch_norm = trial.suggest_categorical('batch_norm', [0, 1])

    # Cap the amount of data seen per epoch so each trial stays fast.
    N_TRAIN_EXAMPLES = args.batch_size * 30
    N_VALID_EXAMPLES = args.batch_size * 10

    logger.info("Initializing train dataset")
    _, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm,
        use_cuda=args.use_gpu)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(obs_len=args.obs_len,
                                            pred_len=args.pred_len,
                                            embedding_dim=args.embedding_dim,
                                            h_dim=args.encoder_h_dim_d,
                                            mlp_dim=args.mlp_dim,
                                            num_layers=args.num_layers,
                                            dropout=args.dropout,
                                            batch_norm=args.batch_norm,
                                            d_type=args.d_type,
                                            use_cuda=args.use_gpu)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    # Build the trial-selected optimizer pair; 'initial_lr' is stored in the
    # param group so the StepLR schedulers below can resume correctly.
    if optim_name == 'Adam':
        optimizer_g = optim.Adam([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                 lr=args.g_learning_rate)
        optimizer_d = optim.Adam([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                 lr=args.d_learning_rate)

    elif optim_name == 'Adamax':
        optimizer_g = optim.Adamax([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                   lr=args.g_learning_rate)
        optimizer_d = optim.Adamax([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                   lr=args.d_learning_rate)
    else:
        optimizer_g = optim.RMSprop([{
            'params': generator.parameters(),
            'initial_lr': args.g_learning_rate
        }],
                                    lr=args.g_learning_rate)
        optimizer_d = optim.RMSprop([{
            'params': discriminator.parameters(),
            'initial_lr': args.d_learning_rate
        }],
                                    lr=args.d_learning_rate)

    scheduler_g = optim.lr_scheduler.StepLR(optimizer_g,
                                            step_size=100,
                                            gamma=0.5,
                                            last_epoch=-1)
    scheduler_d = optim.lr_scheduler.StepLR(optimizer_d,
                                            step_size=100,
                                            gamma=0.5,
                                            last_epoch=-1)

    t, epoch = 0, 0

    # Fixed budget of 50 iterations per trial (an iteration is d_steps
    # discriminator updates followed by g_steps generator updates).
    while t < 50:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps

        for batch_idx, batch in enumerate(train_loader):

            # Limiting training utils for faster epochs.
            if batch_idx * args.batch_size >= N_TRAIN_EXAMPLES:
                break

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                discriminator_step(args, batch, generator,
                                   discriminator, d_loss_fn,
                                   optimizer_d)
                d_steps_left -= 1
            elif g_steps_left > 0:
                generator_step(args, batch, generator,
                               discriminator, g_loss_fn,
                               optimizer_g, discriminator_wight)
                g_steps_left -= 1

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break

        scheduler_g.step()
        scheduler_d.step()

        metrics_val = check_accuracy(args, val_loader, generator,
                                     discriminator, d_loss_fn,
                                     N_VALID_EXAMPLES)

        ade = metrics_val['ade']

        # Report intermediate ADE so Optuna can track trial progress.
        # NOTE(review): no pruning check here (trial.should_prune()); add one
        # if early termination of unpromising trials is desired.
        trial.report(ade, t)

    return ade
# 示例#20 ("Example #20") — scrape-artifact separator between pasted snippets
# 0
def main(evalArgs):
    """Evaluate one or more SGAN checkpoints, printing ADE/FDE and collision
    statistics, optionally rendering bar-chart figures, and pickling the
    collected ADE/FDE dictionaries.

    Args:
        evalArgs: Namespace with model_path (file or directory of
            checkpoints), use_gpu, dset_type, num_samples, showStatistics,
            and related evaluation options.
    """
    if os.path.isdir(evalArgs.model_path):
        filenames = os.listdir(evalArgs.model_path)
        filenames.sort()
        paths = [
            os.path.join(evalArgs.model_path, file_) for file_ in filenames
        ]
    else:
        paths = [evalArgs.model_path]

    totalNumOfPedestrians = 0

    # ADE/FDE keyed by dataset name, split by prediction length (8 vs 12).
    ADE8, FDE8, ADE12, FDE12 = {}, {}, {}, {}

    for path in paths:
        print('\nStarting with evaluation of model:', path)

        if evalArgs.use_gpu:
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location='cpu')

        generator = get_generator(checkpoint, evalArgs)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, evalArgs.dset_type)
        _, loader = data_loader(_args, path)

        # Compute collision statistics for multiple thresholds
        #collisionThresholds = [0.05, 0.1, 0.2, 0.3, 0.5, 0.7, 1, 2]
        collisionThresholds = [0.1]
        for currCollisionThreshold in collisionThresholds:
            ade, fde, testSetStatistics, poolingStatistics, collisionStatistics = evaluate(
                _args, loader, generator, evalArgs.num_samples,
                currCollisionThreshold)
            print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
                _args.dataset_name, _args.pred_len, ade, fde))

            print('Collisions for threshold:', currCollisionThreshold)

        if (_args.pred_len == 8):
            ADE8[_args.dataset_name] = ade
            FDE8[_args.dataset_name] = fde
        elif (_args.pred_len == 12):
            ADE12[_args.dataset_name] = ade
            FDE12[_args.dataset_name] = fde
        else:
            print('Error while storing the evaluation result!')

        # name of directory to store the figures
        dirName = 'barCharts'

        if (evalArgs.showStatistics == 1):
            # Make sure the figure directory exists before any savefig call.
            os.makedirs(dirName, exist_ok=True)

            print('Test set statistics:', testSetStatistics)
            currNumOfScenes, pedestriansPerScene, currNumOfBatches = next(
                iter(testSetStatistics.values()))

            plt.clf()
            plt.bar(list(pedestriansPerScene.keys()),
                    pedestriansPerScene.values(),
                    color='g')
            plt.xlabel('Number of pedestrians')
            plt.ylabel('Number of situations')
            plt.xticks(range(max(pedestriansPerScene.keys()) + 2))
            plt.title('Dataset: {}, Pred Len: {}'.format(
                _args.dataset_name, _args.pred_len))
            plt.savefig(dirName +
                        '/howCrowded_Dataset_{}_PredictionLen_{}.png'.format(
                            _args.dataset_name, _args.pred_len))
            #plt.show()

            totalNumOfPedestrians += sum(
                k * v for k, v in pedestriansPerScene.items())

            if _args.pooling_type.lower() != 'none':
                print('Pooling vector statistics:', poolingStatistics)
                includedPedestrians, includedOtherPedestrians, includedSelf, ratioChosenAndClosest = poolingStatistics

                plt.clf()
                # histogram: x axis is % of included pedestrians, y axis is number of pooling vectors with that %
                plt.bar(list(includedPedestrians.keys()),
                        includedPedestrians.values(),
                        color='g',
                        width=0.02)
                plt.xlabel('% of included pedestrians')
                plt.ylabel('Number of pooling vectors')
                plt.title('Dataset: {}, Pred Len: {}'.format(
                    _args.dataset_name, _args.pred_len))
                plt.savefig(dirName +
                            '/percentIncluded_Dataset_{}_PredLen_{}.png'.
                            format(_args.dataset_name, _args.pred_len))
                #plt.show()

                plt.clf()
                plt.bar(list(includedOtherPedestrians.keys()),
                        includedOtherPedestrians.values(),
                        color='g',
                        width=0.02)
                plt.xlabel('% of included pedestrians (no self inclusions)')
                plt.ylabel('Number of pooling vectors')
                plt.title('Dataset: {}, Pred Len: {}'.format(
                    _args.dataset_name, _args.pred_len))
                plt.savefig(dirName +
                            '/percentIncludedOther_Dataset_{}_PredLen_{}.png'.
                            format(_args.dataset_name, _args.pred_len))
                #plt.show()

                plt.clf()
                plt.bar(list(includedSelf.keys()),
                        includedSelf.values(),
                        color='g',
                        width=0.02)
                plt.xlabel('% of self inclusions')
                plt.ylabel('Number of pooling vectors')
                plt.title('Dataset: {}, Pred Len: {}'.format(
                    _args.dataset_name, _args.pred_len))
                plt.savefig(dirName +
                            '/percentSelfInclusions_Dataset_{}_PredLen_{}.png'.
                            format(_args.dataset_name, _args.pred_len))
                #plt.show()

                plt.clf()
                plt.bar(list(ratioChosenAndClosest.keys()),
                        ratioChosenAndClosest.values(),
                        color='g',
                        width=0.02)
                plt.xlabel('Distance ratio between chosen and closest')
                plt.ylabel('Number of pooling vector values with that ratio')
                plt.title('Dataset: {}, Pred Len: {}'.format(
                    _args.dataset_name, _args.pred_len))
                plt.savefig(dirName +
                            '/chosenClosestRatio_Dataset_{}_PredLen_{}.png'.
                            format(_args.dataset_name, _args.pred_len))
                #plt.show()

                # same as ratio dict, just sums up y values starting from x = 1
                massRatioChosenAndClosest = collections.OrderedDict()
                massRatioChosenAndClosest[-1] = ratioChosenAndClosest[-1]
                acc = 0
                for currKey, currValue in sorted(
                        ratioChosenAndClosest.items())[1:]:
                    acc += currValue

                    massRatioChosenAndClosest[currKey] = acc

                plt.clf()
                # Interpretation: for a x value, how many pooling vector values come from pedestrians that are at most x times farther away than the closest pedestrian
                plt.bar(list(massRatioChosenAndClosest.keys()),
                        massRatioChosenAndClosest.values(),
                        color='g',
                        width=0.02)
                plt.xlabel('Distance ratio between chosen and closest')
                plt.ylabel(
                    'Pooling values with that ratio (sum from x=1 onwards)')
                plt.title('Dataset: {}, Pred Len: {}'.format(
                    _args.dataset_name, _args.pred_len))
                plt.savefig(dirName +
                            '/massChosenClosestRatio_Dataset_{}_PredLen_{}.png'
                            .format(_args.dataset_name, _args.pred_len))
                #plt.show()

            numOfCollisions, totalNumOfSituations, collisionSituations = next(
                iter(collisionStatistics.values()))
            print(
                'Total number of frames with collisions (all situations, all samples):',
                numOfCollisions)
            print(
                'Total number of situations (all samples, with and without collisions):',
                totalNumOfSituations)
            print(
                'Total number of situations with collisions (all samples): {}, that\'s {:.1%}'
                .format(len(collisionSituations),
                        len(collisionSituations) / totalNumOfSituations))

            # loops through and visualizes all situations for which a collision has been detected
            #for currSituation in collisionSituations:
            #obs_traj, pred_traj_fake, pred_traj_gt = currSituation
            #visualizeSituation(obs_traj, pred_traj_fake, pred_traj_gt)

            print('\n \n')

    # TODO(review): placeholder output name -- set a real file name before use.
    destination = 'evalResults/ERROR/SETNAMEFOREVALUATIONMANUALLYHERE.pkl'
    # Create the output directory so the pickle dump cannot fail on a
    # missing path.
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    with open(destination, 'wb') as f:
        pickle.dump((ADE8, FDE8, ADE12, FDE12), f)

    print('Evaluation is done.')
# 示例#21 ("Example #21") — scrape-artifact separator between pasted snippets
# 0
def main(args):
    print(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    # train_path = get_dset_path(args.dataset_name, 'train')
    # val_path = get_dset_path(args.dataset_name, 'val')

    train_path= os.path.join(data_dir,args.dataset_name,'train_small') # 10 files:0-9
    val_path= os.path.join(data_dir,args.dataset_name,'val_small') # 5 files: 10-14

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch)
    )

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type,
        activation=args.d_activation # default: relu
    )

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(
        discriminator.parameters(), lr=args.d_learning_rate
    )

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters())
                )
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Interation {} took {}'.format(
                        t - 1, time.time() - t0
                    ))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    # logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    # logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

                ## log scalars
                for k, v in sorted(losses_d.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)
                for k, v in sorted(losses_g.items()):
                    writer.add_scalar("loss/{}".format(k), v, t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(
                    args, val_loader, generator, discriminator, d_loss_fn
                )
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(
                    args, train_loader, generator, discriminator,
                    d_loss_fn, limit=True
                )

                for k, v in sorted(metrics_val.items()):
                    # logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    # logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                ## log scalars
                for k, v in sorted(metrics_val.items()):
                    writer.add_scalar("val/{}".format(k), v, t)
                for k, v in sorted(metrics_train.items()):
                    writer.add_scalar("train/{}".format(k), v, t)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_with_model_{:06d}.pt'.format(args.checkpoint_name,t)
                # )
                checkpoint_path = os.path.join(args.output_dir, '{}_with_mode.pt'.format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items

                # checkpoint_path = os.path.join(
                #     args.output_dir, '{}_no_model_{:06d}.pt' .format(args.checkpoint_name,t))

                checkpoint_path = os.path.join(args.output_dir, '{}_no_model.pt' .format(args.checkpoint_name))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
Example #22
0
# Checkpoint restore options.
parser.add_argument('--checkpoint_start_from', default=None)
parser.add_argument('--restore_from_checkpoint', default=0,
                    type=int)  # default:1
parser.add_argument('--num_samples_check', default=5000, type=int)

# Misc
parser.add_argument('--use_gpu', default=1, type=int)  # 1: use_gpu
parser.add_argument('--timing', default=0, type=int)
parser.add_argument('--gpu_num', default="0", type=str)

args = parser.parse_args()

# Load a sample dataset only to compute iterations-per-epoch below.
# NOTE(review): the game-id path component is hard-coded -- presumably a
# debugging shortcut; confirm before reusing with other datasets.
tmp_path = os.path.join(data_dir, '01.02.2016.DEN.at.GSW',
                        'tmp')  # 200 files:0-199
# tmp_path= os.path.join(data_dir,args.dataset_name,'train_sample') #
tmp_dset, tmp_loader = data_loader(args, tmp_path)

dataset_len = len(tmp_dset)
print(dataset_len)
# NOTE(review): batch size 128 is hard-coded here instead of args.batch_size
# -- verify they agree, otherwise args.num_iterations is mis-scaled.
iterations_per_epoch = dataset_len / 128 / args.d_steps
if args.num_epochs:
    # Derive the iteration budget from the requested number of epochs.
    args.num_iterations = int(iterations_per_epoch * args.num_epochs)

print(iterations_per_epoch)
print(args.num_iterations)

# traj_max=[]

# for batch in tmp_loader:
#     # batch = [tensor.cuda() for tensor in batch]
#     (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
Example #23
0
def collect_generated_samples(args,
                              generator1,
                              generator2,
                              data_dir,
                              data_set,
                              model_name,
                              selected_scene=None,
                              selected_batch=-1):
    """Pickle ground-truth data, scene metadata and generated trajectories.

    For each batch of the (unshuffled) loader built over ``data_dir``, dumps
    the observed/ground-truth trajectories, per-scene homographies, annotated
    boundary points and one reference video frame per sequence, then draws
    ``num_samples`` trajectory samples from each of the two generators and
    pickles those as well (via the module-level ``save_pickle`` helper).

    If ``selected_batch`` >= 0 only that batch index is processed; if
    ``selected_scene`` is given, batches containing any other scene are
    skipped.

    NOTE(review): relies on module-level helpers (data_loader, get_dset_name,
    get_path, get_sdd_dir, get_homography_and_map, get_trajectories,
    save_pickle) and the ``imageio`` package.
    """
    num_samples = 10  # args.best_k
    _, loader = data_loader(args, data_dir, shuffle=False)

    with torch.no_grad():
        for b, batch in enumerate(loader):
            print('batch = {}'.format(b))
            batch = [tensor.cuda() for tensor in batch]
            if b != selected_batch and selected_batch != -1:
                continue

            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, traj_frames, seq_start_end,
             seq_scene_ids) = batch

            # Map each sequence's scene id back to its dataset name.
            list_data_files = sorted([
                get_dset_name(os.path.join(data_dir, _path).split("/")[-1])
                for _path in os.listdir(data_dir)
            ])
            seq_scenes = [list_data_files[num] for num in seq_scene_ids]

            photo_list, homography_list, annotated_points_list, scene_name_list, scene_information = [], [], [], [], {}
            for i, (start, end) in enumerate(seq_start_end):
                dataset_name = seq_scenes[i]
                path = get_path(dataset_name)
                reader = imageio.get_reader(get_sdd_dir(dataset_name, 'video'),
                                            'ffmpeg')
                annotated_points, h = get_homography_and_map(
                    dataset_name, "/world_points_boundary.npy")
                homography_list.append(h)
                annotated_points_list.append(annotated_points)
                scene_name_list.append(dataset_name)
                scene_information[dataset_name] = annotated_points

                # Grab the video frame at the first prediction timestep of
                # this sequence.
                start = start.item()
                (obs_len, batch_size, _) = obs_traj.size()
                frame = traj_frames[obs_len][start][0].item()
                photo = reader.get_data(int(frame))
                photo_list.append(photo)
                # BUGFIX: close the ffmpeg reader; previously one reader
                # (subprocess + file handle) leaked per sequence.
                reader.close()

            scene_name = np.unique(scene_name_list)
            # BUGFIX (PEP 8): compare to None with `is not`, not `!=`.
            if selected_scene is not None and not (scene_name
                                                   == selected_scene).all():
                print(selected_scene, ' is not in current batch ', scene_name)
                continue

            save_pickle(obs_traj, 'obs_traj', selected_scene, b, data_set,
                        model_name)
            save_pickle(pred_traj_gt, 'pred_traj_gt', selected_scene, b,
                        data_set, model_name)
            save_pickle(seq_start_end, 'seq_start_end', selected_scene, b,
                        data_set, model_name)

            save_pickle(homography_list, 'homography_list', selected_scene, b,
                        data_set, model_name)
            save_pickle(annotated_points_list, 'annotated_points_list',
                        selected_scene, b, data_set, model_name)
            save_pickle(photo_list, 'photo_list', selected_scene, b, data_set,
                        model_name)
            save_pickle(scene_name_list, 'scene_name_list', selected_scene, b,
                        data_set, model_name)
            save_pickle(scene_information, 'scene_information', selected_scene,
                        b, data_set, model_name)

            # Draw num_samples stochastic predictions from each generator.
            pred_traj_fake1_list, pred_traj_fake2_list = [], []

            for sample in range(num_samples):
                pred_traj_fake1, _ = get_trajectories(generator1, obs_traj,
                                                      obs_traj_rel,
                                                      seq_start_end,
                                                      pred_traj_gt,
                                                      seq_scene_ids, data_dir)
                pred_traj_fake2, _ = get_trajectories(generator2, obs_traj,
                                                      obs_traj_rel,
                                                      seq_start_end,
                                                      pred_traj_gt,
                                                      seq_scene_ids, data_dir)

                pred_traj_fake1_list.append(pred_traj_fake1)
                pred_traj_fake2_list.append(pred_traj_fake2)

            save_pickle(pred_traj_fake1_list, 'pred_traj_fake1_list',
                        selected_scene, b, data_set, model_name)
            save_pickle(pred_traj_fake2_list, 'pred_traj_fake2_list',
                        selected_scene, b, data_set, model_name)
Example #24
0
File: train.py  Project: ACoTAI/CODE
def main(args):
    """Train the two-stage trajectory GAN and checkpoint periodically.

    Builds a stage-one generator (generatorSO), a stage-two generator
    (generatorST), a trajectory discriminator and a statistics network
    (netH), optionally restores them from a checkpoint, then alternates
    ``args.d_steps`` discriminator batches with ``args.g_steps`` generator
    batches until ``args.num_iterations`` is reached.  Every
    ``args.checkpoint_every`` iterations, val/train metrics are computed and
    two checkpoints are written: one with model weights and a lightweight
    one without.

    NOTE(review): relies on module-level helpers (get_dset_path, get_dtypes,
    data_loader, init_weights, discriminator_step, generator_step,
    check_accuracy, get_total_norm) and model classes defined elsewhere in
    this file.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')
    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)
    # One iteration consumes args.d_steps batches for D plus args.g_steps
    # for G; epochs are measured in d-step groups here.
    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch)
    )
    # Stage-one generator.
    generatorSO = TrajectoryGeneratorSO(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generatorSO.apply(init_weights)
    generatorSO.type(float_dtype).train()
    logger.info('Here is the generatorSO:')
    logger.info(generatorSO)
    # Stage-two generator: same hyper-parameters, different class.
    generatorST = TrajectoryGeneratorST(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)
    generatorST.apply(init_weights)
    generatorST.type(float_dtype).train()
    logger.info('Here is the generatorST:')
    logger.info(generatorST)
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type)
    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)
    # Statistics network; z_dim presumably concatenates two noise vectors
    # plus four values per predicted timestep -- TODO confirm against
    # StatisticsNetwork's expected input.
    netH = StatisticsNetwork(z_dim = 2*args.noise_dim[0] + 4*args.pred_len, dim=512)
    netH.apply(init_weights)
    netH.type(float_dtype).train()
    logger.info('Here is the netH:')
    logger.info(netH)
    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss
    optimizer_gso = optim.Adam(generatorSO.parameters(), lr=args.g_learning_rate)
    optimizer_gst = optim.Adam(generatorST.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.d_learning_rate)
    optimizer_h = optim.Adam(netH.parameters(), lr=args.h_learning_rate)

    # Maybe restore from checkpoint.
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generatorSO.load_state_dict(checkpoint['gso_state'])
        generatorST.load_state_dict(checkpoint['gst_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_gso.load_state_dict(checkpoint['gso_optim_state'])
        optimizer_gst.load_state_dict(checkpoint['gst_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize the checkpoint structure.
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_gso': [],
            'norm_gst': [],
            'norm_d': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'gso_state': None,
            'gst_state': None,
            'gso_optim_state': None,
            'gst_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'gso_best_state': None,
            'gst_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'gso_best_nl_state': None,
            'gst_best_nl_state': None,
            # BUGFIX: key was 'd_best_state_nl', which never matched the
            # 'd_best_nl_state' key written when a new NL low is found.
            'd_best_nl_state': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()  # BUGFIX: was `time.time().` (SyntaxError)
            # An iteration is args.d_steps D batches then args.g_steps G
            # batches.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generatorSO, generatorST,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generatorSO, generatorST,
                                          discriminator, netH, g_loss_fn,
                                          optimizer_gso, optimizer_gst, optimizer_h)
                checkpoint['norm_gso'].append(
                    get_total_norm(generatorSO.parameters())
                )
                checkpoint['norm_gst'].append(
                    get_total_norm(generatorST.parameters())
                )
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))
            # Skip the rest until the current iteration completes.
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    # BUGFIX: log message typo 'Interation' corrected.
                    logger.info('Iteration {} took {}'.format(
                        t - 1, time.time() - t0
                    ))
                t0 = time.time()
            # Maybe record losses.
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)
            # Maybe save a checkpoint.
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(
                    args, val_loader, generatorSO, generatorST, discriminator, d_loss_fn
                )
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(
                    args, train_loader, generatorSO, generatorST, discriminator,
                    d_loss_fn, limit=True
                )
                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)
                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])
                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['gso_best_state'] = generatorSO.state_dict()
                    checkpoint['gst_best_state'] = generatorST.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()
                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['gso_best_nl_state'] = generatorSO.state_dict()
                    checkpoint['gst_best_nl_state'] = generatorST.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()
                # Save latest model weights and optimizer states.
                checkpoint['gso_state'] = generatorSO.state_dict()
                checkpoint['gst_state'] = generatorST.state_dict()

                checkpoint['gso_optim_state'] = optimizer_gso.state_dict()
                checkpoint['gst_optim_state'] = optimizer_gst.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name
                )
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')
                # Also save a lightweight checkpoint without model weights.
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                # BUGFIX: the blacklist previously listed the non-existent
                # keys 'g_best_state'/'g_best_nl_state' (left over from the
                # single-generator version), so the large gso/gst best-state
                # weights leaked into the "no model" checkpoint.
                key_blacklist = [
                    'gso_state', 'gst_state', 'd_state',
                    'gso_best_state', 'gst_best_state',
                    'gso_best_nl_state', 'gst_best_nl_state',
                    'gso_optim_state', 'gst_optim_state', 'd_optim_state',
                    'd_best_state', 'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
Example #25
0
def main(args):
    if args.summary_writer_name is not None:
        writer = SummaryWriter(args.summary_writer_name)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_path, args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_path, args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing val dataset")
    val_dset, val_loader = data_loader(args, val_path, shuffle=False)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path, shuffle=True)
    print(len(train_loader))

    steps = max(args.g_steps, args.c_steps)
    steps = max(steps, args.d_steps)
    iterations_per_epoch = math.ceil(len(train_dset) / args.batch_size / steps)

    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info('There are {} iterations per epoch, prints {} plots {}'.format(
        iterations_per_epoch, args.print_every, args.checkpoint_every))

    generator = helper_get_generator(args, train_path)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)
    g_loss_fn = gan_g_loss
    optimizer_g = optim.Adam(filter(lambda x: x.requires_grad,
                                    generator.parameters()),
                             lr=args.g_learning_rate)

    # build trajectory
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        activation=args.activation,
        batch_norm=args.batch_norm,
        grid_size=args.grid_size,
        neighborhood_size=args.neighborhood_size)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)
    d_loss_fn = gan_d_loss
    optimizer_d = optim.Adam(discriminator.parameters(),
                             lr=args.d_learning_rate)

    critic = helper_get_critic(args, train_path)
    critic.apply(init_weights)
    critic.type(float_dtype).train()
    logger.info('Here is the critic:')
    logger.info(critic)
    c_loss_fn = gan_d_loss
    optimizer_c = optim.Adam(filter(lambda x: x.requires_grad,
                                    critic.parameters()),
                             lr=args.c_learning_rate)

    trajectory_evaluator = TrajectoryGeneratorEvaluator()
    if args.d_loss_weight > 0:
        logger.info('Discrimintor loss')
        trajectory_evaluator.add_module(discriminator, gan_g_loss,
                                        args.d_loss_weight)
    if args.c_loss_weight > 0:
        logger.info('Critic loss')
        trajectory_evaluator.add_module(critic, g_critic_loss_function,
                                        args.c_loss_weight)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = os.path.join(get_root_dir(), args.output_dir,
                                    args.checkpoint_start_from)
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(get_root_dir(), args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        # discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        # optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)

    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, -1
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'C_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'norm_c': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'c_state': None,
            'c_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'c_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }

    t0 = None

    # Number of times a generator, discriminator and critic steps are done in 1 epoch
    num_d_steps = ((len(train_dset) / args.batch_size) /
                   (args.g_steps + args.d_steps + args.c_steps)) * args.d_steps
    num_c_steps = ((len(train_dset) / args.batch_size) /
                   (args.g_steps + args.d_steps + args.c_steps)) * args.c_steps
    num_g_steps = ((len(train_dset) / args.batch_size) /
                   (args.g_steps + args.d_steps + args.c_steps)) * args.g_steps

    while t < args.num_iterations:
        if epoch == args.num_epochs:
            break

        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        c_steps_left = args.c_steps
        epoch += 1
        # Average losses over all batches in the training set for 1 epoch
        avg_losses_d = {}
        avg_losses_c = {}
        avg_losses_g = {}

        logger.info('Starting epoch {}  -  [{}]'.format(
            epoch, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
        for batch_num, batch in enumerate(train_loader):
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                              optimizer_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
                if len(avg_losses_d) == 0:
                    for k, v in sorted(losses_d.items()):
                        avg_losses_d[k] = v / num_d_steps
                else:
                    for k, v in sorted(losses_d.items()):
                        avg_losses_d[k] += v / num_d_steps

            elif c_steps_left > 0:
                step_type = 'c'
                losses_c = critic_step(args, batch, generator, critic,
                                       c_loss_fn, optimizer_c)
                checkpoint['norm_c'].append(get_total_norm(
                    critic.parameters()))
                c_steps_left -= 1
                if len(avg_losses_c) == 0:
                    for k, v in sorted(losses_c.items()):
                        avg_losses_c[k] = v / num_c_steps
                else:
                    for k, v in sorted(losses_c.items()):
                        avg_losses_c[k] += v / num_c_steps

            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator, optimizer_g,
                                          trajectory_evaluator)

                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters()))
                g_steps_left -= 1
                if len(avg_losses_g) == 0:
                    for k, v in sorted(losses_g.items()):
                        avg_losses_g[k] = v / num_g_steps
                else:
                    for k, v in sorted(losses_g.items()):
                        avg_losses_g[k] += v / num_g_steps

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0 or c_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            c_steps_left = args.c_steps

        if epoch % args.print_every == 0 and epoch > 0:
            # Save losses
            logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
            if args.d_steps > 0:
                for k, v in sorted(avg_losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                    if args.summary_writer_name is not None:
                        writer.add_scalar('Train/' + k, v, epoch)
            for k, v in sorted(avg_losses_g.items()):
                logger.info('  [G] {}: {:.3f}'.format(k, v))
                checkpoint['G_losses'][k].append(v)
                if args.summary_writer_name is not None:
                    writer.add_scalar('Train/' + k, v, epoch)
            if args.c_steps > 0:
                for k, v in sorted(avg_losses_c.items()):
                    logger.info('  [C] {}: {:.3f}'.format(k, v))
                    checkpoint['C_losses'][k].append(v)
                    if args.summary_writer_name is not None:
                        writer.add_scalar('Train/' + k, v, epoch)
            checkpoint['losses_ts'].append(t)

        if epoch % args.checkpoint_every == 0 and epoch > 0:
            # Maybe save a checkpoint
            if t > 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)
                metrics_train, metrics_val = {}, {}
                if args.g_steps > 0:
                    logger.info('Checking G stats on train ...')
                    metrics_train = check_accuracy_generator(
                        'train', epoch, args, train_loader, generator, True)

                    logger.info('Checking G stats on val ...')
                    metrics_val = check_accuracy_generator(
                        'val', epoch, args, val_loader, generator, True)

                if args.c_steps > 0:
                    logger.info('Checking C stats on train ...')
                    metrics_train_c = check_accuracy_critic(
                        args, train_loader, generator, critic, c_loss_fn, True)
                    metrics_train.update(metrics_train_c)

                    logger.info('Checking C stats on val ...')
                    metrics_val_c = check_accuracy_critic(
                        args, val_loader, generator, critic, c_loss_fn, True)
                    metrics_val.update(metrics_val_c)
                if args.d_steps > 0:
                    logger.info('Checking D stats on train ...')
                    metrics_train_d = check_accuracy_discriminator(
                        args, train_loader, generator, discriminator,
                        d_loss_fn, True)
                    metrics_train.update(metrics_train_d)

                    logger.info('Checking D stats on val ...')
                    metrics_val_d = check_accuracy_discriminator(
                        args, val_loader, generator, discriminator, d_loss_fn,
                        True)
                    metrics_val.update(metrics_val_d)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                    if args.summary_writer_name is not None:
                        writer.add_scalar('Validation/' + k, v, epoch)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)
                    if args.summary_writer_name is not None:
                        writer.add_scalar('Train/' + k, v, epoch)

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint['c_state'] = critic.state_dict()
                checkpoint['c_optim_state'] = optimizer_c.state_dict()
                checkpoint_path = os.path.join(
                    get_root_dir(), args.output_dir,
                    '{}_{}_with_model.pt'.format(args.checkpoint_name, epoch))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    get_root_dir(), args.output_dir,
                    '{}_{}_no_model.pt'.format(args.checkpoint_name, epoch))
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'c_state', 'g_best_state',
                    'g_best_nl_state', 'g_optim_state', 'd_optim_state',
                    'd_best_state', 'd_best_nl_state', 'c_optim_state',
                    'c_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

                if args.g_steps < 1:
                    continue

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()
                    if args.c_steps > 0:
                        checkpoint['c_best_state'] = critic.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

    if args.summary_writer_name is not None:
        writer.close()
def _run_validation_checkpoint(args, t, epoch, checkpoint, train_loader,
                               val_loader, classifier, loss_fn, optimizer):
    # Evaluate on val/train, remember the best-so-far weights, and save a
    # full checkpoint (model + optimizer state) to disk.
    checkpoint['counters']['t'] = t
    checkpoint['counters']['epoch'] = epoch
    checkpoint['sample_ts'].append(t)

    # Check stats on the validation set
    logger.info('Checking stats on val ...')
    metrics_val = check_accuracy(args, val_loader, classifier, loss_fn)
    logger.info('Checking stats on train ...')
    metrics_train = check_accuracy(args, train_loader, classifier, loss_fn)

    for k, v in sorted(metrics_val.items()):
        logger.info('  [val] {}: {:.3f}'.format(k, v))
        checkpoint['metrics_val'][k].append(v)
    for k, v in sorted(metrics_train.items()):
        logger.info('  [train] {}: {:.3f}'.format(k, v))
        checkpoint['metrics_train'][k].append(v)

    min_loss = min(checkpoint['metrics_val']['d_loss'])
    max_acc = max(checkpoint['metrics_val']['d_accuracy'])

    if metrics_val['d_loss'] == min_loss:
        logger.info('New low for data loss')
        checkpoint['best_t'] = t
        # Fixed key: the checkpoint dict is initialized with
        # 'classifier_best_state', but this used to write 'best_state',
        # so the best weights never landed in the declared slot.
        checkpoint['classifier_best_state'] = classifier.state_dict()

    if metrics_val['d_accuracy'] == max_acc:
        logger.info('New high for accuracy')
        checkpoint['best_t'] = t
        checkpoint['classifier_best_state'] = classifier.state_dict()

    # Save a checkpoint with model weights and optimizer state.
    checkpoint['classifier_state'] = classifier.state_dict()
    checkpoint['classifier_optim_state'] = optimizer.state_dict()
    checkpoint_path = os.path.join(
        args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
    logger.info('Saving checkpoint to {}'.format(checkpoint_path))
    torch.save(checkpoint, checkpoint_path)
    logger.info('Done.')


def main(args):
    """Train the CNN-LSTM classifier until ``args.num_iterations`` steps.

    Builds train/val loaders, optionally restores from a checkpoint, runs
    the training loop, evaluates and checkpoints periodically, and prints
    the best accuracies seen when training finishes.

    Parameters
    ----------
    args : argparse.Namespace
        Training configuration (dataset path, batch size, learning rate,
        checkpointing cadence, timing flag, ...).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num

    # Build the training set
    logger.info("Initializing train set")
    train_path = os.path.join(args.dataset, "train")
    train_dset, train_loader = data_loader(args, train_path, "train")

    # Build the validation set
    logger.info("Initializing val set")
    val_path = os.path.join(args.dataset, "val")
    val_dset, val_loader = data_loader(args, val_path, "val")

    # set data type to cpu/gpu
    long_dtype, float_dtype = get_dtypes(args)

    # Fixed: was `train_dset / args.batch_size`, which divides the Dataset
    # object itself by an int (TypeError). The iteration count needs the
    # dataset's length, as the sibling training script does.
    iterations_per_epoch = len(train_dset) / args.batch_size
    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch))

    # initialize the CNN LSTM classifier
    classifier = CNNLSTM(embedding_dim=args.embedding_dim,
                         h_dim=args.h_dim,
                         mlp_dim=args.mlp_dim,
                         dropout=args.dropout)
    classifier.apply(init_weights)
    classifier.type(float_dtype).train()

    # set the optimizer
    optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)

    # define the loss function
    loss_fn = nn.CrossEntropyLoss()

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        classifier.load_state_dict(checkpoint['classifier_state'])
        optimizer.load_state_dict(checkpoint['classifier_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'classifier_losses': defaultdict(list),  # per-step train losses
            'losses_ts': [],  # iteration index of each recorded loss
            'metrics_val':
            defaultdict(list),  # valid metrics (loss and accuracy)
            'metrics_train':
            defaultdict(list),  # train metrics (loss and accuracy)
            'sample_ts': [],
            'restore_ts': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'classifier_state': None,
            'classifier_optim_state': None,
            'classifier_best_state': None,
            'best_t': None,
        }
    t0 = None
    print("Total no of iterations: ", args.num_iterations)
    while t < args.num_iterations:

        gc.collect()
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))

        for batch in train_loader:

            # Baseline evaluation + checkpoint before any training step.
            if t == 0:
                _run_validation_checkpoint(args, t, epoch, checkpoint,
                                           train_loader, val_loader,
                                           classifier, loss_fn, optimizer)

            # measure time between batches
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # run batch and get losses
            losses = step(args, batch, classifier, loss_fn, optimizer)
            # NOTE(review): `losses` is computed but never recorded into
            # checkpoint['classifier_losses']; confirm step()'s return shape
            # before wiring that in.

            # measure time between batches
            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                # Fixed: previously formatted an undefined `step_type`,
                # raising NameError whenever timing was enabled.
                logger.info('classifier step took {}'.format(t2 - t1))

            # measure time between batches
            if args.timing == 1:
                if t0 is not None:
                    logger.info('Iteration {} took {}'.format(
                        t - 1,
                        time.time() - t0))
                t0 = time.time()

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                _run_validation_checkpoint(args, t, epoch, checkpoint,
                                           train_loader, val_loader,
                                           classifier, loss_fn, optimizer)

            t += 1
            if t >= args.num_iterations:

                # Report the best accuracies seen over training, both at the
                # lowest-loss evaluation and at the best-accuracy evaluation.
                print(
                    "[train] best accuracy at lowest loss ",
                    checkpoint['metrics_train']['d_accuracy'][np.argmin(
                        checkpoint['metrics_train']['d_loss'])])
                print("[train] best accuracy at highest accuracy ",
                      max(checkpoint['metrics_train']['d_accuracy']))
                print(
                    "[val] best accuracy at lowest loss ",
                    checkpoint['metrics_val']['d_accuracy'][np.argmin(
                        checkpoint['metrics_val']['d_loss'])])
                print("[val] best accuracy at highest accuracy ",
                      max(checkpoint['metrics_val']['d_accuracy']))

                break
示例#27
0
def main():
    """Benchmark prediction runtime: pretrained SGAN vs. our STG model.

    For each of the five ETH/UCY pedestrian datasets, loads a pretrained
    SGAN generator and our trained model, times trajectory sampling on
    randomly chosen scenes, and appends per-run timings to a CSV file.

    NOTE(review): reads module-level globals (`args`, `hyperparams`, and
    imported helpers) rather than taking parameters — confirm they are
    initialized before calling.
    """
    # One list per CSV column; rows are appended in lock-step below.
    results_dict = {
        'data_precondition': list(),
        'dataset': list(),
        'method': list(),
        'runtime': list(),
        'num_samples': list(),
        'num_agents': list()
    }
    data_precondition = 'curr'
    for dataset_name in ['eth', 'hotel', 'univ', 'zara1', 'zara2']:
        print('At %s dataset' % dataset_name)

        ### SGAN LOADING ###
        # Model files are expected to be named '<dataset>_12_model.pt'.
        sgan_model_path = os.path.join(
            args.sgan_models_path, '_'.join([dataset_name, '12', 'model.pt']))

        checkpoint = torch.load(sgan_model_path, map_location='cpu')
        generator = eval_utils.get_generator(checkpoint)
        # Rebuild the args the SGAN model was trained with.
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.sgan_dset_type)
        print('Evaluating', sgan_model_path, 'on', _args.dataset_name,
              args.sgan_dset_type)

        _, sgan_data_loader = data_loader(_args, path)

        ### OUR METHOD LOADING ###
        data_dir = '../sgan-dataset/data'
        eval_data_dict_name = '%s_test.pkl' % dataset_name
        log_dir = '../sgan-dataset/logs/%s' % dataset_name

        trained_model_dir = os.path.join(
            log_dir, eval_utils.get_our_model_dir(dataset_name))
        eval_data_path = os.path.join(data_dir, eval_data_dict_name)
        with open(eval_data_path, 'rb') as f:
            # latin1 encoding keeps Python-2 pickles loadable under Python 3.
            eval_data_dict = pickle.load(f, encoding='latin1')
        eval_dt = eval_data_dict['dt']
        print('Loaded evaluation data from %s, eval_dt = %.2f' %
              (eval_data_path, eval_dt))

        # Loading weights from the trained model.
        specific_hyperparams = eval_utils.get_model_hyperparams(
            args, dataset_name)
        model_registrar = ModelRegistrar(trained_model_dir, args.device)
        model_registrar.load_models(specific_hyperparams['best_iter'])

        # Grab an arbitrary STGNode key so its array can be inspected below.
        # NOTE(review): if input_dict holds no STGNode keys, `random_node`
        # stays unbound and the next statement raises NameError — confirm
        # the pickle always contains at least one.
        for key in eval_data_dict['input_dict'].keys():
            if isinstance(key, STGNode):
                random_node = key
                break

        # Fill in data-dependent hyperparameters from the eval dict; the
        # last array axis of a node's trajectory data is the state dim.
        hyperparams['state_dim'] = eval_data_dict['input_dict'][
            random_node].shape[2]
        hyperparams['pred_dim'] = len(eval_data_dict['pred_indices'])
        hyperparams['pred_indices'] = eval_data_dict['pred_indices']
        hyperparams['dynamic_edges'] = args.dynamic_edges
        hyperparams['edge_state_combine_method'] = specific_hyperparams[
            'edge_state_combine_method']
        hyperparams['edge_influence_combine_method'] = specific_hyperparams[
            'edge_influence_combine_method']
        hyperparams['nodes_standardization'] = eval_data_dict[
            'nodes_standardization']
        hyperparams['labels_standardization'] = eval_data_dict[
            'labels_standardization']
        hyperparams['edge_radius'] = args.edge_radius

        # Deep-copy so the eval model's standardization can differ from the
        # shared `hyperparams` without mutating it.
        eval_hyperparams = copy.deepcopy(hyperparams)
        eval_hyperparams['nodes_standardization'] = eval_data_dict[
            "nodes_standardization"]
        eval_hyperparams['labels_standardization'] = eval_data_dict[
            "labels_standardization"]

        # Constructor kwargs for the online STG model below.
        kwargs_dict = {
            'dynamic_edges':
            hyperparams['dynamic_edges'],
            'edge_state_combine_method':
            hyperparams['edge_state_combine_method'],
            'edge_influence_combine_method':
            hyperparams['edge_influence_combine_method'],
            'edge_addition_filter':
            args.edge_addition_filter,
            'edge_removal_filter':
            args.edge_removal_filter
        }

        # Human-readable dump of the evaluation configuration.
        print('-------------------------')
        print('| EVALUATION PARAMETERS |')
        print('-------------------------')
        print('| checking: %s' % data_precondition)
        print('| device: %s' % args.device)
        print('| eval_device: %s' % args.eval_device)
        print('| edge_radius: %s' % hyperparams['edge_radius'])
        print('| EE state_combine_method: %s' %
              hyperparams['edge_state_combine_method'])
        print('| EIE scheme: %s' %
              hyperparams['edge_influence_combine_method'])
        print('| dynamic_edges: %s' % hyperparams['dynamic_edges'])
        print('| edge_addition_filter: %s' % args.edge_addition_filter)
        print('| edge_removal_filter: %s' % args.edge_removal_filter)
        print('| MHL: %s' % hyperparams['minimum_history_length'])
        print('| PH: %s' % hyperparams['prediction_horizon'])
        print('| # Samples: %s' % args.num_samples)
        print('| # Runs: %s' % args.num_runs)
        print('-------------------------')

        eval_stg = OnlineSpatioTemporalGraphCVAEModel(None, model_registrar,
                                                      eval_hyperparams,
                                                      kwargs_dict,
                                                      args.eval_device)
        print('Created evaluation STG model.')

        print('About to begin evaluation computation for %s.' % dataset_name)
        with torch.no_grad():
            eval_inputs, _ = eval_utils.sample_inputs_and_labels(
                eval_data_dict, device=args.eval_device)

        # Convert our data format into SGAN's (obs/pred trajectories plus
        # per-scene (start, end) index pairs and bookkeeping ids).
        (obs_traj, pred_traj_gt, obs_traj_rel, seq_start_end, data_ids,
         t_predicts) = eval_utils.get_sgan_data_format(
             eval_inputs, what_to_check=data_precondition)

        num_runs = args.num_runs
        print('num_runs, seq_start_end.shape[0]', args.num_runs,
              seq_start_end.shape[0])
        # Can't sample more distinct scenes than exist.
        if args.num_runs > seq_start_end.shape[0]:
            print(
                'num_runs (%d) > seq_start_end.shape[0] (%d), reducing num_runs to match.'
                % (num_runs, seq_start_end.shape[0]))
            num_runs = seq_start_end.shape[0]

        # Scenes are sampled without replacement so each run is distinct.
        random_scene_idxs = np.random.choice(seq_start_end.shape[0],
                                             size=(num_runs, ),
                                             replace=False).astype(int)

        ### SGAN TIMING ###
        for scene_idxs in random_scene_idxs:
            choice_list = seq_start_end[scene_idxs]

            # Time args.num_samples full generator passes for this scene.
            overall_tic = time.time()
            for sample_num in range(args.num_samples):
                pred_traj_fake_rel = generator(obs_traj, obs_traj_rel,
                                               seq_start_end)
                pred_traj_fake = relative_to_abs(pred_traj_fake_rel,
                                                 obs_traj[-1])

            overall_toc = time.time()
            print('SGAN overall', overall_toc - overall_tic)
            results_dict['data_precondition'].append(data_precondition)
            results_dict['dataset'].append(dataset_name)
            results_dict['method'].append('sgan')
            results_dict['runtime'].append(overall_toc - overall_tic)
            results_dict['num_samples'].append(args.num_samples)
            # Agent count = scene's (end - start) index span.
            results_dict['num_agents'].append(
                int(choice_list[1].item() - choice_list[0].item()))

        print('Done running SGAN')

        # Move standardization statistics onto the model's device as tensors.
        for node in eval_data_dict['nodes_standardization']:
            for key in eval_data_dict['nodes_standardization'][node]:
                eval_data_dict['nodes_standardization'][node][
                    key] = torch.from_numpy(
                        eval_data_dict['nodes_standardization'][node]
                        [key]).float().to(args.device)

        for node in eval_data_dict['labels_standardization']:
            for key in eval_data_dict['labels_standardization'][node]:
                eval_data_dict['labels_standardization'][node][
                    key] = torch.from_numpy(
                        eval_data_dict['labels_standardization'][node]
                        [key]).float().to(args.device)

        ### OUR METHOD TIMING ###
        for run in range(num_runs):
            random_scene_idx = random_scene_idxs[run]
            data_id = data_ids[random_scene_idx]
            t_predict = t_predicts[random_scene_idx] - 1

            # Scan forward to the first timestep with any active node
            # (a node is active if its initial position is nonzero).
            init_scene_dict = dict()
            for first_timestep in range(t_predict + 1):
                for node, traj_data in eval_data_dict['input_dict'].items():
                    if isinstance(node, STGNode):
                        init_pos = traj_data[data_id, first_timestep, :2]
                        if np.any(init_pos):
                            init_scene_dict[node] = init_pos

                if len(init_scene_dict) > 0:
                    break

            init_scene_graph = SceneGraph()
            init_scene_graph.create_from_scene_dict(init_scene_dict,
                                                    args.edge_radius)

            # NOTE(review): `first_timestep` here is the loop variable leaked
            # from the scan above — it holds the timestep where the break
            # fired, which is intended but fragile.
            curr_inputs = {
                k: v[data_id, first_timestep:t_predict + 1]
                for k, v in eval_data_dict['input_dict'].items()
                if (isinstance(k, STGNode) and (
                    k in init_scene_graph.active_nodes))
            }
            # First two state components are treated as x/y position.
            curr_pos_inputs = {k: v[..., :2] for k, v in curr_inputs.items()}

            with torch.no_grad():
                # Timed pass 1: most-likely-z prediction.
                overall_tic = time.time()
                preds_dict_most_likely = eval_stg.forward(
                    init_scene_graph,
                    curr_pos_inputs,
                    curr_inputs,
                    None,
                    hyperparams['prediction_horizon'],
                    args.num_samples,
                    most_likely=True)
                overall_toc = time.time()
                print('Our MLz overall', overall_toc - overall_tic)
                results_dict['data_precondition'].append(data_precondition)
                results_dict['dataset'].append(dataset_name)
                results_dict['method'].append('our_most_likely')
                results_dict['runtime'].append(overall_toc - overall_tic)
                results_dict['num_samples'].append(args.num_samples)
                results_dict['num_agents'].append(len(init_scene_dict))

                # Timed pass 2: full sampling from the latent distribution.
                overall_tic = time.time()
                preds_dict_full = eval_stg.forward(
                    init_scene_graph,
                    curr_pos_inputs,
                    curr_inputs,
                    None,
                    hyperparams['prediction_horizon'],
                    args.num_samples,
                    most_likely=False)
                overall_toc = time.time()
                print('Our Full overall', overall_toc - overall_tic)
                results_dict['data_precondition'].append(data_precondition)
                results_dict['dataset'].append(dataset_name)
                results_dict['method'].append('our_full')
                results_dict['runtime'].append(overall_toc - overall_tic)
                results_dict['num_samples'].append(args.num_samples)
                results_dict['num_agents'].append(len(init_scene_dict))

        # Re-written after every dataset so partial results survive a crash.
        pd.DataFrame.from_dict(results_dict).to_csv(
            '../sgan-dataset/plots/data/%s_%s_runtimes.csv' %
            (data_precondition, dataset_name),
            index=False)
示例#28
0
def main(args):
    """Train a trajectory-estimation generator (L2 regression, no GAN critic).

    Builds a TensorBoard log directory keyed on dataset / epochs / learning
    rate, constructs either a `TrajEstimatorThreshold` or `TrajEstimator`
    generator, optionally restores from a checkpoint, then runs the training
    loop: per-iteration loss logging, periodic TensorBoard scalars, periodic
    val/train accuracy checks, and checkpointing of the best-ADE weights.

    Args:
        args: argparse Namespace with dataset, model, optimizer, logging and
            checkpoint settings (see the script's argument parser).
    """
    # One log dir per (dataset, epochs, lr) combination; wipe any stale
    # event files from a previous run with the same settings.
    logdir = "tensorboard/" + args.dataset_name + "/" + str(
        args.num_epochs) + "_epoch_" + str(args.g_learning_rate) + "_lr"
    if os.path.exists(logdir):
        for file in os.listdir(logdir):
            # NOTE(review): os.unlink raises on subdirectories; assumes the
            # log dir only ever contains flat event files — confirm.
            os.unlink(os.path.join(logdir, file))

    writer = SummaryWriter(logdir)

    # Pin the visible GPU(s) before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')
    checkpoint_path = os.path.join(args.output_dir,
                                   '%s_with_model.pt' % args.checkpoint_name)
    pathlib.Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

    long_dtype, float_dtype = get_dtypes(args)

    # Two generator variants sharing the same constructor interface; the
    # threshold variant additionally tracks non-moving trajectories
    # (reported at the end of training).
    if args.moving_threshold:
        generator = TrajEstimatorThreshold(obs_len=args.obs_len,
                                           pred_len=args.pred_len,
                                           embedding_dim=args.embedding_dim,
                                           encoder_h_dim=args.encoder_h_dim_g,
                                           num_layers=args.num_layers,
                                           dropout=args.dropout)
    else:
        generator = TrajEstimator(obs_len=args.obs_len,
                                  pred_len=args.pred_len,
                                  embedding_dim=args.embedding_dim,
                                  encoder_h_dim=args.encoder_h_dim_g,
                                  num_layers=args.num_layers,
                                  dropout=args.dropout)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    # NOTE(review): redundant — the line above already put the model in
    # train mode; this second call is a harmless no-op.
    generator.train()

    logger.info('Here is the generator:')
    logger.info(generator)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    # Convert the epoch budget into an absolute iteration budget.
    iterations_per_epoch = len(train_dset) / args.batch_size
    args.num_iterations = int(iterations_per_epoch * args.num_epochs)
    # log 100 points
    log_tensorboard_every = int(args.num_iterations * 0.01) - 1
    if log_tensorboard_every <= 0:
        #there are less than 100 iterations
        # NOTE(review): true division yields a float here, so the later
        # `t % log_tensorboard_every == 0` check fires erratically, and
        # num_iterations < 4 gives 0.x → ZeroDivisionError on the modulo.
        # Should likely be integer floor division with a minimum of 1.
        log_tensorboard_every = int(args.num_iterations) / 4

    logger.info('There are {} iterations per epoch'.format(
        int(iterations_per_epoch)))

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)

    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        # Resume iteration/epoch counters and record this restore time.
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'g_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'best_t_nl': None,
        }
    # Main training loop, bounded by total iterations rather than epochs.
    while t < args.num_iterations:
        gc.collect()

        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            # One optimizer step on the generator; returned dict maps
            # loss names to scalar values for logging.
            losses_g = generator_step(args, batch, generator, optimizer_g,
                                      epoch)
            checkpoint['norm_g'].append(get_total_norm(generator.parameters()))

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save values for tensorboard
            if t % log_tensorboard_every == 0:
                for k, v in sorted(losses_g.items()):
                    writer.add_scalar(k, v, t)

                # Evaluate on val (full) and train (limited) for TensorBoard.
                metrics_val = check_accuracy(args, val_loader, generator,
                                             epoch)
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               epoch,
                                               limit=True)
                # Only these metrics are plotted; the rest stay in the
                # checkpoint dict.
                to_keep = ["g_l2_loss_rel", "ade", "fde"]
                for k, v in sorted(metrics_val.items()):
                    if k in to_keep:
                        writer.add_scalar("val_" + k, v, t)

                for k, v in sorted(metrics_train.items()):
                    if k in to_keep:
                        writer.add_scalar("train_" + k, v, t)

            # Maybe save a checkpoint
            if t % args.checkpoint_every == 0 and t > 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(args, val_loader, generator,
                                             epoch)
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(args,
                                               train_loader,
                                               generator,
                                               epoch,
                                               limit=True)

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                # Track best-so-far weights by overall ADE and by
                # non-linear-trajectory ADE separately.
                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            if t >= args.num_iterations:
                # On exit, the threshold variant reports what fraction of
                # trajectories fell under its non-moving threshold.
                if args.moving_threshold:
                    logger.info(
                        "Non-moving trajectories : {}%, threshold : {}".format(
                            round(
                                (float(generator.total_trajs_under_threshold) /
                                 generator.total_trajs) * 100),
                            generator.threshold))
                break