コード例 #1
0
ファイル: adaption.py プロジェクト: qianjin5/Autoenv
def collect(egoids,
            args,
            exp_dir,
            use_hgail,
            params_filename,
            n_proc,
            collect_fn=parallel_collect_trajectories,
            random_seed=None,
            lbd=0.99,
            adapt_steps=1):
    '''
    Description:
        - load the saved policy parameters for an experiment and run
            ``collect_fn`` on them inside a ``Timer`` context
        - multiagent note: egoids are not currently used when running
            this with args.env_multiagent == True

    Args:
        egoids: ego vehicle ids, forwarded to ``collect_fn``.
        args: experiment arguments, forwarded to ``collect_fn``.
        exp_dir: experiment directory; parameters are read from
            ``<exp_dir>/imitate/<params_filename>`` and the directory
            ``<exp_dir>/imitate/test`` is created if needed.
        use_hgail: forwarded to ``collect_fn``.
        params_filename: name of the ``np.load``-able archive that holds a
            ``'params'`` entry.
        n_proc: forwarded to ``collect_fn`` (presumably the process count).
        collect_fn: callable that performs the collection.
        random_seed, lbd, adapt_steps: forwarded to ``collect_fn``.

    Returns:
        Whatever ``collect_fn`` returns.
    '''
    # load information relevant to the experiment
    params_filepath = os.path.join(exp_dir,
                                   'imitate/{}'.format(params_filename))
    # `.item()` unwraps a size-1 array; when that array has object dtype
    # (e.g. a pickled dict of parameters), numpy >= 1.16.3 refuses to load it
    # unless allow_pickle=True. Unpickling can execute arbitrary code, so only
    # point this at trusted experiment directories.
    # The `with` block closes the archive's underlying file handle.
    with np.load(params_filepath, allow_pickle=True) as params_archive:
        params = params_archive['params'].item()

    # validation setup (the directory is created but not otherwise used here)
    validation_dir = os.path.join(exp_dir, 'imitate', 'test')
    utils.maybe_mkdir(validation_dir)

    with Timer():
        error = collect_fn(args,
                           params,
                           egoids,
                           n_proc,
                           use_hgail=use_hgail,
                           random_seed=random_seed,
                           lbd=lbd,
                           adapt_steps=adapt_steps)

    return error
コード例 #2
0
def collect(egoids,
            starts,
            args,
            exp_dir,
            use_hgail,
            params_filename,
            n_proc,
            max_steps=200,
            collect_fn=parallel_collect_trajectories,
            random_seed=None,
            use_bc=True):
    '''
    Description:
        - load policy parameters, run ``collect_fn`` inside a ``Timer``
            context, and write the resulting trajectories to
            ``<exp_dir>/imitate/validation/<stem>_trajectories.npz``, where
            ``<stem>`` is ``args.ngsim_filename`` up to its first '.'
        - multiagent note: egoids and starts are not currently used when running
            this with args.env_multiagent == True

    Args:
        egoids, starts: forwarded to ``collect_fn``.
        args: experiment arguments; ``args.ngsim_filename`` names the output.
        exp_dir: experiment directory that receives the validation output.
        use_hgail, max_steps, random_seed, use_bc: forwarded to
            ``collect_fn``.
        params_filename: parameter file name under the hard-coded
            ``../../data/experiments/single_mlp/imitate/log/`` directory.
        n_proc: forwarded to ``collect_fn`` (presumably the process count).
        collect_fn: callable that performs the collection.

    Returns:
        None; the trajectories are written via ``utils.write_trajectories``.
    '''
    # load information relevant to the experiment
    # NOTE(review): this params location is hard-coded, relative to the
    # current working directory, and ignores exp_dir, while the output below
    # is written under exp_dir. Confirm this is intentional.
    params_filepath = os.path.join('../../data/experiments/single_mlp/',
                                   'imitate/log/{}'.format(params_filename))
    params = hgail.misc.utils.load_params(params_filepath)

    # validation setup
    validation_dir = os.path.join(exp_dir, 'imitate', 'validation')
    utils.maybe_mkdir(validation_dir)
    # The output file name is the same regardless of use_bc.
    output_filepath = os.path.join(
        validation_dir,
        '{}_trajectories.npz'.format(args.ngsim_filename.split('.')[0]))

    with Timer():
        trajs = collect_fn(args,
                           params,
                           egoids,
                           starts,
                           n_proc,
                           max_steps=max_steps,
                           use_hgail=use_hgail,
                           random_seed=random_seed,
                           use_bc=use_bc)

    utils.write_trajectories(output_filepath, trajs)
コード例 #3
0
def collect(egoids,
            starts,
            args,
            exp_dir,
            use_hgail,
            params_filename,
            n_proc,
            max_steps=200,
            collect_fn=parallel_collect_trajectories,
            random_seed=None,
            lbd=0.99,
            adapt_steps=1):
    '''
    Description:
        - load the policy parameters stored under ``<exp_dir>/imitate/``,
            make sure ``<exp_dir>/imitate/test`` exists, run ``collect_fn``
            inside a ``Timer`` context, print the vehicle counter, and return
            the collector's result
        - multiagent note: egoids and starts are not currently used when running
            this with args.env_multiagent == True

    Args:
        egoids, starts: forwarded to ``collect_fn``.
        args: experiment arguments; ``ngsim_filename`` and ``env_multiagent``
            are read to build a result file name.
        exp_dir: experiment directory.
        use_hgail, max_steps, random_seed, lbd, adapt_steps: forwarded to
            ``collect_fn`` as keyword arguments.
        params_filename: parameter file name under ``<exp_dir>/imitate/``.
        n_proc: forwarded to ``collect_fn`` (presumably the process count).
        collect_fn: callable that performs the collection.

    Returns:
        Whatever ``collect_fn`` returns.
    '''
    # Policy parameters saved by the imitation run.
    params = hgail.misc.utils.load_params(
        os.path.join(exp_dir, 'imitate/{}'.format(params_filename)))

    # Make sure the test output directory exists.
    test_dir = os.path.join(exp_dir, 'imitate', 'test')
    utils.maybe_mkdir(test_dir)

    # NOTE(review): this result path is built but never written to here.
    stem = args.ngsim_filename.split('.')[0]
    result_name = '{}_AGen_{}_{}.npz'.format(stem, adapt_steps,
                                             args.env_multiagent)
    output_filepath = os.path.join(test_dir, result_name)

    collector_options = dict(max_steps=max_steps,
                             use_hgail=use_hgail,
                             random_seed=random_seed,
                             lbd=lbd,
                             adapt_steps=adapt_steps)
    with Timer():
        result = collect_fn(args, params, egoids, starts, n_proc,
                            **collector_options)

    # Veh_counter is not defined in this function; presumably a module-level
    # counter defined elsewhere in the file. Verify.
    print("Vehicle Counter: {}".format(Veh_counter))
    return result