Ejemplo n.º 1
0
def load_dataset(directory, config):
    """Build a batched, preprocessed ``tf.data`` pipeline over stored episodes.

    The dataset's output signature is taken from the first episode yielded by
    ``tools.load_episodes(directory, 1)``. Every key keeps that episode's
    dtype, and its leading dimension is left unspecified (``None``).

    Elements are produced lazily by ``tools.load_episodes`` called with
    ``config.train_steps``, ``config.batch_length`` and
    ``config.dataset_balance``. These fields are read each time the
    generator is invoked. Elements are then:

    - grouped into batches of ``config.batch_size``, dropping any incomplete
      final batch;
    - mapped through ``preprocess(..., config=config)``;
    - prefetched 10 batches ahead.

    Args:
        directory: Episode directory handed to ``tools.load_episodes``.
        config: Config object providing the fields listed above.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    # Assumes each episode is a mapping of key -> array-like (has .dtype and
    # .shape), as used below.
    first_episode = next(tools.load_episodes(directory, 1))
    types, shapes = {}, {}
    for key, array in first_episode.items():
        types[key] = array.dtype
        shapes[key] = (None,) + array.shape[1:]

    def generator():
        return tools.load_episodes(
            directory, config.train_steps, config.batch_length,
            config.dataset_balance)

    dataset = tf.data.Dataset.from_generator(generator, types, shapes)
    return (dataset
            .batch(config.batch_size, drop_remainder=True)
            .map(functools.partial(preprocess, config=config))
            .prefetch(10))
Ejemplo n.º 2
0
def main(logdir, config):
    """Train a ``Dreamer`` agent, evaluating it before every training chunk.

    ``config`` is modified in place:

    - ``traindir``/``evaldir`` default to subdirectories of ``logdir``;
    - ``steps``, ``eval_every``, ``log_every`` and ``time_limit`` are
      floor-divided by ``action_repeat``;
    - ``act`` is replaced by the ``tf.nn`` attribute of that name;
    - ``num_actions`` is set from the first training environment's action
      space.

    Args:
        logdir: Output directory (``~`` is expanded), created if missing.
            Agent variables are saved to ``logdir / 'variables.pkl'`` after
            each training chunk, and loaded from there at startup if the
            file already exists.
        config: Mutable config object. It must support ``vars(config)``
            because the offline directory templates are formatted with its
            fields.

    Raises:
        AssertionError: if ``config.gpu_growth`` is set but no GPU is listed,
            or if ``config.precision`` is not 16 or 32. These checks are
            skipped under ``python -O``, which strips asserts.

    All environments that were created are closed on exit, including when
    setup or training raises.
    """
    logdir = pathlib.Path(logdir).expanduser()
    config.traindir = config.traindir or logdir / 'train_eps'
    config.evaldir = config.evaldir or logdir / 'eval_eps'
    # The training loop compares agent steps against these settings, while
    # the logger is fed `action_repeat * steps` (see below), so express them
    # in agent steps.
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat
    config.act = getattr(tf.nn, config.act)

    if config.debug:
        tf.config.experimental_run_functions_eagerly(True)
    if config.gpu_growth:
        message = 'No GPU found. To actually train on CPU remove this assert.'
        assert tf.config.experimental.list_physical_devices('GPU'), message
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
    assert config.precision in (16, 32), config.precision
    if config.precision == 16:
        prec.set_policy(prec.Policy('mixed_float16'))
    print('Logdir', logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    logger = tools.Logger(logdir, config.action_repeat * step)
    checkpoint = logdir / 'variables.pkl'

    print('Create envs.')
    # Offline directories are format templates filled from the config fields.
    if config.offline_traindir:
        directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    train_eps = tools.load_episodes(directory, limit=config.dataset_size)
    if config.offline_evaldir:
        directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    eval_eps = tools.load_episodes(directory, limit=1)

    def make(mode):
        return make_env(config, logger, mode, train_eps, eval_eps)

    train_envs, eval_envs = [], []
    try:
        # Build the lists one env at a time so that envs created before a
        # failure remain reachable by the cleanup in `finally`.
        for _ in range(config.envs):
            train_envs.append(make('train'))
        for _ in range(config.envs):
            eval_envs.append(make('eval'))
        acts = train_envs[0].action_space
        config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

        prefill = max(0, config.prefill - count_steps(config.traindir))
        print(f'Prefill dataset ({prefill} steps).')

        def random_agent(obs, done, state):
            # One `acts.sample()` per entry of `done`; state passed through.
            return [acts.sample() for _ in done], state

        tools.simulate(random_agent, train_envs, prefill)
        tools.simulate(random_agent, eval_envs, episodes=1)
        logger.step = config.action_repeat * count_steps(config.traindir)

        print('Simulate agent.')
        train_dataset = make_dataset(train_eps, config)
        eval_dataset = iter(make_dataset(eval_eps, config))
        agent = Dreamer(config, logger, train_dataset)
        if checkpoint.exists():
            agent.load(checkpoint)
            # NOTE(review): presumably disables the one-off pretraining phase
            # when resuming from a checkpoint -- confirm in Dreamer.
            agent._should_pretrain._once = False

        state = None
        while agent._step.numpy().item() < config.steps:
            logger.write()
            print('Start evaluation.')
            video_pred = agent._wm.video_pred(next(eval_dataset))
            logger.video('eval_openl', video_pred)
            eval_policy = functools.partial(agent, training=False)
            tools.simulate(eval_policy, eval_envs, episodes=1)
            print('Start training.')
            state = tools.simulate(agent,
                                   train_envs,
                                   config.eval_every,
                                   state=state)
            agent.save(checkpoint)
    finally:
        for env in train_envs + eval_envs:
            try:
                env.close()
            except Exception as exc:
                # Best-effort cleanup: report the problem, keep closing the
                # remaining envs, and never mask an in-flight exception.
                print(f'Failed to close env: {exc!r}')
Ejemplo n.º 3
0
 def generator():
     """Return what ``tools.load_episodes`` yields for the configured sampling.

     ``directory`` and ``config`` are free variables from the enclosing scope
     (not visible here). The config fields are read on every call.
     """
     loader_args = (config.train_steps, config.batch_length,
                    config.dataset_balance)
     return tools.load_episodes(directory, *loader_args)