コード例 #1
0
from baselines.run import main

args = [
    '--num_timesteps',
    '50e6',
    '--env',
    'BreakoutNoFrameskip-v4',
    '--alg',
    'deepq',
]

main(args)
コード例 #2
0
ファイル: ppo2-buffer-breakout.py プロジェクト: llach/ma-code
        '10',
        '--num_env',
        '16',
        '--v_net',
        'atari',
        '--collect_until',
        '1e6',
        '--vae_buffer_size',
        '5e4',
        '--vae_batch_size',
        '1024',
        '--vae_batches_per_epoch',
        '20',
        '--seed',
        str(0),
        '--k',
        str(k),
        '--rl_coef',
        str(rlc),
        '--tensorboard',
        'True',
        '--log_interval',
        '1',
    ]

    main(args, build_fn=build_pend_env, vae_params=vae_params)
    s = get_session()
    s.close()
    tf.reset_default_graph()
print('done')
コード例 #3
0
    '--seed',
    str(0),
    '--k',
    str(k),
    '--load_path',
    policy_path,
    '--play',
    'True',
]

v = VAE(load_from=vae_name,
        network='atari',
        with_opt=False,
        session=tf.Session())
model, env = main(args,
                  build_fn=build_pend_env,
                  vae_params=vae_params,
                  just_return=True)
viewer = rendering.SimpleImageViewer()

obs = env.reset()
d = False

max_t = 25e3

t = 0
last_t = 0
max_t_ep = 2000
num_ep = 0

# buffers to be saved
org_buf = []
コード例 #4
0
ファイル: trpo-pend.py プロジェクト: llach/ma-code
                args.num_env or 1,
                seed,
                reward_scale=args.reward_scale,
                flatten_dict_observations=flatten_dict_observations)
            return VecVAEStack(env, k=k, load_from=vae_name.replace('/', ':'))

        args = [
            '--env',
            'PendulumVisual-v0',
            '--num_timesteps',
            '8e6',
            '--alg',
            'trpo_mpi',
            '--network',
            'mlp',
            '--num_env',
            '16',
            '--seed',
            str(seed),
            '--k',
            str(k),
            '--tensorboard',
            'True',
        ]

        main(args, build_fn=build_pend_env)
        s = get_session()
        s.close()
        tf.reset_default_graph()
        pass
コード例 #5
0
ファイル: ppo2-breakout-play.py プロジェクト: llach/ma-code
from baselines.run import main

args = [
    '--env',
    'BreakoutNoFrameskip-v4',
    '--alg',
    'ppo2',
    '--num_timesteps',
    '10e7',
    '--num_env',
    '1',
    '--log_interval',
    '1',
    '--seed',
    '0',
    '--load_path',
    '/Users/llach/breakout-nenv16-rlc1.0-seed0-modeld-2019-04-17T20:39/best/',
    '--play',
    'True',
]

model, env = main(args)
コード例 #6
0
import sys
import baselines.run as run

import highway_env

DEFAULT_ARGUMENTS = [
    "--env=highway-parking-v0", "--alg=her", "--num_timesteps=1e4",
    "--network=default", "--num_env=0", "--save_path=~/models/latest",
    "--load_path=~/models/latest", "--save_video_interval=0", "--play"
]

if __name__ == "__main__":
    args = sys.argv
    if len(args) <= 1:
        args = DEFAULT_ARGUMENTS
    run.main(args)
コード例 #7
0
ファイル: main.py プロジェクト: githubxxcc/moba_env
import sys, os

import baselines.run as run

os.environ['OPENAI_LOG_FORMAT'] = 'stdout,log,tensorboard'
os.environ['OPENAI_LOGDIR'] = '../_logdar/'
if __name__ == '__main__':

    run.main([
        'main.py', '--alg=ppo_moba', '--env=gym_moba:moba-multiplayer-v0',
        '--network=multi_unit_mlp', '--num_timesteps=1e7', '--scene_id=13',
        '--is_train'
    ])
    '''
    run.main(['main.py', '--alg=ppo2', '--env=Breakout-v0', 
    '--network=cnn', '--num_timesteps=2e7', '--scene_id=10', '--is_train'])
    '''
コード例 #8
0
# f = open("../models/test.out", 'w')
# sys.stdout = f
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # for remove TF warning

DEFAULT_ARGUMENTS = [
    "--env=highway-continuous-v0",
    "--alg=trpo_mpi",
    "--num_timesteps=1e6",  # episode * steps = num_timesteps = 1e6

    # policy net parameter
    "--network=mlp",
    "--num_layers=3",
    "--num_hidden=124",
    "--activation=tf.tanh",
    "--num_env=0",  # >1 for mpi, disabled for online learning
    "--env_json=C:/Users/szhan117/Documents/git_repo/highway-env/scripts/config/IDM.json",
    # last save name must be 'latest'
    "--save_path=C:/Users/szhan117/Documents/git_repo/highway-env/models/latest",
    # "--load_path=C:/Users/szhan117/Documents/git_repo/highway-env/models/latest",
    "--save_video_interval=0",
    "--play"
]

if __name__ == "__main__":
    args = sys.argv
    if len(args) <= 1:
        args = DEFAULT_ARGUMENTS
    run.main(args)  # for training
    # run.animation(args)  # for animation
    # f.close()
コード例 #9
0
ファイル: run.py プロジェクト: ihowell/r2048-rl
import sys
import r2048_rl.envs
import gym

from baselines import run

if __name__ == '__main__':
    run.main(sys.argv)
コード例 #10
0
import gym_balls  # pylint: Add to register gym-balls environments

from baselines.run import main

if __name__ == '__main__':
    main()
コード例 #11
0
#         logpath = '--log_path=./models/block/harsh_65/IR/fpp_demo25bad{}dim_{}_log'.format(dim,seed)
#         perturb = '--perturb=delay'
#         algdim = '--algdim={}'.format(dim)
# #        if seed >= 100 and seed < 1000: seed = 10
# #        elif seed >= 1000 and seed < 10000: seed = 100
# #        elif seed >= 10000 and seed < 100000: seed = 1000

#         finalargs = defaultargs + [savepath, demofile, logpath, perturb, algdim, '--seed={}'.format(seed)]

#         run.main(finalargs)

defaultargs = ['--alg=her', '--env=NuFingers', '--num_timesteps=1e6']
for dim in [4]:
    for seed in [10, 100, 1000]:
        savepath = '--save_path=./models/NuFingers/IR/fpp_demo25bad{}dim_{}'.format(
            dim, seed)
        loadpath = '--load_path=./models/NuFingers/Sim_NuFingers_bad{}dim_1000'.format(
            dim)
        demofile = '--demo_file=./NuFingersObjectDemo_bad{}dim.npz'.format(dim)
        logpath = '--log_path=./models/NuFingers/IR/fpp_demo25bad{}dim_{}_log'.format(
            dim, seed)
        perturb = '--perturb=none'
        algdim = '--algdim={}'.format(dim)

        finalargs = defaultargs + [
            savepath, loadpath, demofile, logpath, perturb, algdim,
            '--seed={}'.format(seed)
        ]

        run.main(finalargs)
コード例 #12
0
ファイル: multi_run.py プロジェクト: yoniosin/A2C_new
def run_process(args, output):
    res = main(args)
    output.put(res)
コード例 #13
0
ファイル: raw-frames.py プロジェクト: llach/ma-code
args = [
    '--env', 'BreakoutNoFrameskip-v4',
    '--alg', 'ppo2',
    '--num_timesteps', '10e7',
    '--num_env', '1',
    '--log_interval', '1',
    '--seed', '0',
    '--load_path', f'{home}/breakout-ppo/',
    '--play', 'True',
]

frames = [0, 588]

t = 0
model, env = main(args, just_return=True)
obs = env.reset()
d = False

log.info('generating frames')
while not d:

    fr = env.render(mode='rgb_array')

    if t in frames:
        scipy.misc.imsave(f'{figure_path}/frame{t}.png', fr)

    print(t)
    t += 1
    actions, _, _, _ = model.step(obs)
コード例 #14
0
ファイル: run.py プロジェクト: zain-nadeem/rl-od
import sys

from baselines.run import main

import rl_od

if __name__ == '__main__':
    main(sys.argv)
コード例 #15
0
import gym_turtlebot3
import rospy
import baselines.run as run
import os

env_name = 'TurtleBot3_Circuit_Simple-v0'
alg = 'deepq'
num_timesteps = '1e4'

name_ref = env_name + '_' + num_timesteps + '_' + alg

my_args = [
    '--alg=' + alg, 
    '--env=' + env_name, 
    '--save_path=./models/' + name_ref + '.pkl',
    '--num_timesteps=' + num_timesteps]

os.environ["OPENAI_LOG_FORMAT"] = "csv"
os.environ["OPENAI_LOGDIR"] = './logs/' + name_ref

rospy.init_node(env_name.replace('-', '_'))

run.main(my_args)