예제 #1
0
def main(env_args, agent_args, buffer_args, render=False):
    """Build the training agent (plus an optional test agent) and train.

    Args:
        env_args: Environment settings. Read via ``.get('n_workers', 0)``
            and, when ``render`` is true, modified in place: ``n_envs`` is
            set to 1 and ``log_video`` to True.
        agent_args: Agent settings, forwarded to ``Agent`` and ``train``.
        buffer_args: Accepted but not used by this function.
        render: When true, a second agent that reuses the first agent's
            graph is created and handed to ``train`` as the test agent;
            otherwise ``train`` receives ``None`` in its place.
    """
    utils.set_global_seed()

    # ray is only initialised when more than one worker is requested
    n_workers = env_args.get('n_workers', 0)
    if n_workers > 1:
        ray.init()

    name = 'Agent'
    device = '/gpu:0'
    agent = Agent(name, agent_args, env_args,
                  save=False,
                  log=True,
                  log_tensorboard=True,
                  log_params=False,
                  log_stats=True,
                  device=device)

    if render:
        # The test agent runs in a single environment with video logging on.
        # NOTE: these assignments change the dict object passed by the caller.
        env_args['n_envs'] = 1
        env_args['log_video'] = True
        test_agent = Agent(name, agent_args, env_args,
                           save=False,
                           log_tensorboard=False,
                           log_params=False,
                           log_stats=False,
                           device=device,
                           reuse=True,
                           graph=agent.graph)
    else:
        test_agent = None

    train(agent, agent_args, test_agent)
예제 #2
0
파일: replay_test.py 프로젝트: xlnwel/d2rl
    def test_buffer_op(self):
        """Check that two replay implementations yield identical samples.

        Random-action transitions from BipedalWalkerHardcore-v3 are added to
        both the replay built by ``create_replay(config)`` and a plain
        ``ReplayBuffer(config)``. From step 1001 onwards, both are sampled
        after re-seeding with the same step index and every key of the first
        sample is compared against the second with ``assert_allclose``.
        """
        replay = create_replay(config)
        simp_replay = ReplayBuffer(config)

        env = gym.make('BipedalWalkerHardcore-v3')

        obs = env.reset()
        for step in range(10000):
            action = env.action_space.sample()
            next_obs, reward, done, _ = env.step(action)
            if done:
                # on episode end the stored next observation is the reset one
                next_obs = env.reset()

            # the first replay receives float32-cast data, the second raw data
            replay.add(obs=obs.astype(np.float32),
                       action=action.astype(np.float32),
                       reward=np.float32(reward),
                       next_obs=next_obs.astype(np.float32),
                       done=done)
            simp_replay.add(obs=obs,
                            action=action,
                            reward=reward,
                            next_obs=next_obs,
                            done=done)
            obs = next_obs

            if not i_past_warmup(step) if False else step <= 1000:
                continue

            # seed identically before each draw so both samples can be compared
            set_global_seed(step)
            sample1 = replay.sample()
            set_global_seed(step)
            sample2 = simp_replay.sample()

            for key in sample1.keys():
                np.testing.assert_allclose(sample1[key],
                                           sample2[key],
                                           err_msg=f'{key}')
예제 #3
0
def main():
    """Restore a SAGAN checkpoint and run evaluation.

    Settings are loaded from ``args.yaml``. When ``cmd_args.batch_size`` is
    truthy it overrides ``eval_batch_size``. The model is built with
    ``training=False``, restored from ``cmd_args.checkpoint`` and evaluated
    for ``cmd_args.iterations`` iterations.
    """
    set_global_seed()
    args_file = 'args.yaml'
    args = load_args(args_file)
    cmd_args = parse_cmd_args()

    if cmd_args.batch_size:
        # store the batch size value itself, not the whole parsed-args object
        args['eval_batch_size'] = cmd_args.batch_size

    model = SAGAN('model', args, training=False)
    model.restore(cmd_args.checkpoint)
    model.evaluate(n_iterations=cmd_args.iterations)
예제 #4
0
def main(env_args, agent_args, buffer_args, render=False):
    """Create an agent backed by a per-environment data file and train it.

    Args:
        env_args: Environment settings; ``env_args['name']`` selects the
            data file and is passed on to ``train``.
        agent_args: Agent settings; ``agent_args['algorithm']`` is used as
            the agent's name and passed on to ``train``.
        buffer_args: Buffer settings; its ``'filename'`` entry is set in
            place to ``experts/data/<env name>.pkl``.
        render: Accepted but not used by this function.
    """
    set_global_seed()

    algorithm = agent_args['algorithm']
    env_name = env_args['name']
    buffer_args['filename'] = f'experts/data/{env_name}.pkl'

    agent = Agent(algorithm, agent_args, env_args, buffer_args,
                  log_tensorboard=True,
                  log_stats=True,
                  save=True)

    train(agent, algorithm, env_args['name'])
예제 #5
0
def main():
    """Restore a style-transfer checkpoint and evaluate it on one image.

    Settings come from ``args.yaml``. The image at ``cmd_args.image`` is read
    to obtain its shape; height and width are rounded up to the next multiple
    of 4 and stored, with the channel count, as ``args['image_shape']``.
    """
    set_global_seed()
    args = load_args('args.yaml')
    cmd_args = parse_cmd_args()

    image_path = cmd_args.image
    args['eval_image_path'] = image_path

    height, width, channels = imread(image_path).shape
    # round both spatial dimensions up to a multiple of 4
    height, width = (ceil(side / 4) * 4 for side in (height, width))
    args['image_shape'] = (height, width, channels)

    model = StyleTransferModel('model', args, device='/gpu:0')
    model.restore(cmd_args.checkpoint)
    model.evaluate(eval_image=True)
예제 #6
0
파일: train.py 프로젝트: xlnwel/cv
def main(args, checkpoint=None):
    """Train a SAGAN model, optionally resuming from ``checkpoint``.

    Args:
        args: Model settings forwarded to ``SAGAN``.
        checkpoint: When truthy, it is passed to ``model.restore`` before
            training and the model is built with ``save=False``; otherwise
            the model is built with ``save=True`` and trained from scratch.
    """
    # To train several instances on a single GPU, a session config along
    # these lines may be needed (remember to pass it to the model):
    #   sess_config = tf.ConfigProto(allow_soft_placement=True)
    #   sess_config.gpu_options.allow_growth = True
    #   sess_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    resuming = bool(checkpoint)

    model = SAGAN('model', args,
                  log_tensorboard=True,
                  save=not resuming,
                  device='/gpu:0')
    if resuming:
        model.restore(checkpoint)
    model.train()


if __name__ == '__main__':
    # Script entry point: either resume a single run from a checkpoint, or
    # launch a grid search over the settings in args.yaml.
    cmd_args = parse_cmd_args()

    set_global_seed()
    args_file = 'args.yaml'

    if cmd_args.checkpoint:
        args = load_args(args_file)
        main(args, cmd_args.checkpoint)
    else:
        # The grid search is skipped when resuming, so a restored run is not
        # followed by additional fresh training runs.
        gs = GridSearch(args_file, main, dir_prefix=cmd_args.prefix)
        gs()