Example #1
def _active_collection(tasks, collects, defaults, config, params):
    sims = tools.AttrDict()
    for task in tasks:
        for user_collect in collects:
            for key in user_collect:
                if key not in defaults:
                    message = 'Invalid key {} in activation collection config.'
                    raise KeyError(message.format(key))
            collect = tools.AttrDict(defaults, _unlocked=True)
            collect.update(user_collect)
            collect.planner = _define_planner(collect.planner, collect.horizon,
                                              config, params)
            collect.objective = tools.bind(getattr(objectives_lib,
                                                   collect.objective),
                                           params=params)
            if collect.give_rewards:
                collect.task = task
            else:
                env_ctor = tools.bind(
                    lambda ctor: control.wrappers.NoRewardHint(ctor()),
                    task.env_ctor)
                collect.task = tasks_lib.Task(task.name, env_ctor)
            collect.exploration = tools.AttrDict(
                scale=collect.action_noise_scale,
                type=collect.action_noise_type,
                schedule=tools.bind(tools.schedule.linear,
                                    ramp=collect.action_noise_ramp,
                                    min=collect.action_noise_min),
                factors=collect.action_noise_factors)
            name = '{}_{}_{}'.format(collect.prefix, collect.name, task.name)
            assert name not in sims, (set(sims.keys()), name)
            sims[name] = collect
    return sims
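The loop above first validates each user-supplied override against `defaults` and then merges it over them. A minimal sketch of that validate-and-merge step with plain dicts (the real code uses `tools.AttrDict` and many more default keys; the names below are only illustrative):

defaults = {'planner': 'cem', 'horizon': 12, 'give_rewards': True}

def merge_collect(user_collect, defaults):
    # Reject keys that have no default, then merge overrides over defaults.
    for key in user_collect:
        if key not in defaults:
            raise KeyError('Invalid key {} in collection config.'.format(key))
    collect = dict(defaults)
    collect.update(user_collect)
    return collect

print(merge_collect({'horizon': 6}, defaults))  # only horizon is overridden
try:
    merge_collect({'horizn': 6}, defaults)
except KeyError as e:
    print('rejected:', e)  # a typo in the override is caught early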
Example #2
def gym_cheetah(config, params):
    # Works with `isolate_envs: process`.
    action_repeat = params.get('action_repeat', 1)
    state_components = ['state']
    env_ctor = tools.bind(_gym_env, 'HalfCheetah-v3', config, params,
                          action_repeat)
    return Task('gym_cheetah', env_ctor, state_components)
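All task factories in this module follow the same shape: read hyper-parameters from `params` with defaults, pre-bind an environment constructor, and return a `Task(name, env_ctor, state_components)` record. A self-contained sketch of exercising such a factory, assuming `Task` is a plain named tuple and `tools.bind` behaves like `functools.partial` (both are stand-ins; the real definitions live in the project's own modules):

import collections
import functools

# Hypothetical stand-ins for the project's Task and tools.bind.
Task = collections.namedtuple('Task', 'name env_ctor state_components')
bind = functools.partial

def _fake_env(domain, config, params, action_repeat):
    # Placeholder for _gym_env; a real ctor would construct the environment.
    return ('env', domain, action_repeat)

def gym_cheetah(config, params):
    action_repeat = params.get('action_repeat', 1)
    env_ctor = bind(_fake_env, 'HalfCheetah-v3', config, params, action_repeat)
    return Task('gym_cheetah', env_ctor, ['state'])

task = gym_cheetah(config=None, params={'action_repeat': 2})
print(task.name)        # gym_cheetah
print(task.env_ctor())  # ('env', 'HalfCheetah-v3', 2)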
Example #3
def _tasks(config, params):
    config.isolate_envs = params.get('isolate_envs', 'thread')
    train_tasks, test_tasks = [], []
    for name in params.get('tasks', ['cheetah_run']):
        try:
            train_tasks.append(
                getattr(tasks_lib, name)(config, params, 'train'))
            test_tasks.append(getattr(tasks_lib, name)(config, params, 'test'))
        except TypeError:
            train_tasks.append(getattr(tasks_lib, name)(config, params))
            test_tasks.append(getattr(tasks_lib, name)(config, params))

    def common_spaces_ctor(task, action_spaces):
        env = task.env_ctor()
        env = control.wrappers.SelectObservations(env, ['image'])
        env = control.wrappers.PadActions(env, action_spaces)
        return env

    if len(train_tasks) > 1:
        action_spaces = [task.env_ctor().action_space for task in train_tasks]
        for index, task in enumerate(train_tasks):
            env_ctor = tools.bind(common_spaces_ctor, task, action_spaces)
            train_tasks[index] = tasks_lib.Task(task.name, env_ctor, [])
    if len(test_tasks) > 1:
        action_spaces = [task.env_ctor().action_space for task in test_tasks]
        for index, task in enumerate(test_tasks):
            env_ctor = tools.bind(common_spaces_ctor, task, action_spaces)
            test_tasks[index] = tasks_lib.Task(task.name, env_ctor, [])
    if config.gradient_heads == 'all_but_image':
        config.gradient_heads = train_tasks[0].state_components
    diags = params.get('state_diagnostics', True)
    for name in train_tasks[0].state_components + ['reward', 'pcont']:
        if (name not in config.gradient_heads + ['reward', 'pcont']
                and not diags):
            continue
        kwargs = {}
        kwargs['stop_gradient'] = name not in config.gradient_heads
        if name == 'pcont':
            kwargs['dist'] = 'binary'
        default = dict(reward=2).get(name, config.num_layers)
        kwargs['num_layers'] = params.get(name + '_layers', default)
        config.heads[name] = tools.bind(config.head_network, **kwargs)
        config.loss_scales[name] = 1.0
    config.train_tasks = train_tasks
    config.test_tasks = test_tasks
    return config
Example #4
def humanoid_walk(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = [
        'reward', 'com_velocity', 'extremities', 'head_height', 'joint_angles',
        'torso_vertical', 'velocity'
    ]
    env_ctor = tools.bind(_dm_control_env, 'humanoid', 'walk', config, params,
                          action_repeat)
    return Task('humanoid_walk', env_ctor, state_components)
Example #5
def finger_turn_hard(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = [
        'reward', 'position', 'velocity', 'touch', 'target_position',
        'dist_to_target'
    ]
    env_ctor = tools.bind(_dm_control_env, 'finger', 'turn_hard', config,
                          params, action_repeat)
    return Task('finger_turn_hard', env_ctor, state_components)
Example #6
def _define_planner(planner, horizon, config, params):
    if planner == 'cem':
        planner_fn = tools.bind(control.planning.cross_entropy_method,
                                beams=params.get('planner_beams', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                horizon=horizon)
    elif planner == 'policy_sample':
        planner_fn = tools.bind(control.planning.action_head_policy,
                                strategy='sample',
                                config=config)
    elif planner == 'policy_mode':
        planner_fn = tools.bind(control.planning.action_head_policy,
                                strategy='mode',
                                config=config)
    else:
        raise NotImplementedError(planner)
    return planner_fn
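`_define_planner` only selects a planning function and pre-binds its hyper-parameters; the bound callable is stored on the collection config (see Example #1) and invoked later with the remaining arguments. A sketch of the same select-and-bind pattern using `functools.partial` and a hypothetical planner function (the real ones live in `control.planning`):

import functools

def cross_entropy_method(state, beams, iterations, topk, horizon):
    # Hypothetical stand-in for control.planning.cross_entropy_method.
    return 'cem({}, beams={}, horizon={})'.format(state, beams, horizon)

def define_planner(planner, horizon, params):
    if planner == 'cem':
        return functools.partial(
            cross_entropy_method,
            beams=params.get('planner_beams', 1000),
            iterations=params.get('planner_iterations', 10),
            topk=params.get('planner_topk', 100),
            horizon=horizon)
    raise NotImplementedError(planner)

plan = define_planner('cem', horizon=12, params={'planner_beams': 500})
print(plan(state='s0'))  # cem(s0, beams=500, horizon=12)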
Example #7
def quadruped_fetch(config, params):
    action_repeat = params.get('action_repeat', 2)
    env_ctor = tools.bind(_dm_control_env,
                          'quadruped',
                          'fetch',
                          config,
                          params,
                          action_repeat,
                          camera_id=2,
                          normalize_actions=True)
    return Task('quadruped_fetch', env_ctor, [])
Example #8
def gym_racecar(config, params):
    # Works with `isolate_envs: thread`.
    action_repeat = params.get('action_repeat', 1)
    env_ctor = tools.bind(_gym_env,
                          'CarRacing-v0',
                          config,
                          params,
                          action_repeat,
                          select_obs=[],
                          obs_is_image=False,
                          render_mode='state_pixels')
    return Task('gym_racing', env_ctor, [])
Example #9
def _data_processing(config, params):
    config.batch_shape = params.get('batch_shape', (50, 50))
    config.num_chunks = params.get('num_chunks', 1)
    image_bits = params.get('image_bits', 8)
    config.preprocess_fn = tools.bind(tools.preprocess.preprocess,
                                      bits=image_bits)
    config.postprocess_fn = tools.bind(tools.preprocess.postprocess,
                                       bits=image_bits)
    config.open_loop_context = 5
    config.data_reader = tools.bind(
        tools.numpy_episodes.episode_reader,
        clip_rewards=params.get('clip_rewards', False),
        pcont_scale=params.get('pcont_scale', 0.99))
    config.data_loader = {
        'cache': tools.bind(tools.numpy_episodes.cache_loader,
                            every=params.get('loader_every', 1000)),
        'recent': tools.bind(tools.numpy_episodes.recent_loader,
                             every=params.get('loader_every', 1000)),
        'window': tools.bind(tools.numpy_episodes.window_loader,
                             window=params.get('loader_window', 400),
                             every=params.get('loader_every', 1000)),
        'reload': tools.numpy_episodes.reload_loader,
        'dummy': tools.numpy_episodes.dummy_loader,
    }[params.get('loader', 'cache')]
    config.gpu_prefetch = params.get('gpu_prefetch', False)
    return config
Example #10
def quadruped_walk(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = [
        'reward', 'egocentric_state', 'force_torque', 'imu', 'torso_upright',
        'torso_velocity'
    ]
    env_ctor = tools.bind(_dm_control_env,
                          'quadruped',
                          'walk',
                          config,
                          params,
                          action_repeat,
                          camera_id=2,
                          normalize_actions=True)
    return Task('quadruped_walk', env_ctor, state_components)
Example #11
def process(logdir, args):
    with args.params.unlocked:
        args.params.logdir = logdir
    config = configs.make_config(args.params)
    logdir = pathlib.Path(logdir)
    metrics = tools.Metrics(logdir / 'metrics', workers=5)
    training.utility.collect_initial_episodes(metrics, config)
    tf.reset_default_graph()
    dataset = tools.numpy_episodes.numpy_episodes(
        config.train_dir,
        config.test_dir,
        config.batch_shape,
        reader=config.data_reader,
        loader=config.data_loader,
        num_chunks=config.num_chunks,
        preprocess_fn=config.preprocess_fn,
        gpu_prefetch=config.gpu_prefetch)
    metrics = tools.InGraphMetrics(metrics)
    build_graph = tools.bind(training.define_model, logdir, metrics)
    for score in training.utility.train(build_graph, dataset, logdir, config):
        yield score
Example #12
def _model_components(config, params):
    config.gradient_heads = params.get('gradient_heads', ['image', 'reward'])
    config.activation = ACTIVATIONS[params.get('activation', 'elu')]
    config.num_layers = params.get('num_layers', 3)
    config.num_units = params.get('num_units', 400)
    encoder = params.get('encoder', 'conv')
    if encoder == 'conv':
        config.encoder = networks.conv.encoder
    elif encoder == 'proprio':
        config.encoder = tools.bind(networks.proprio.encoder,
                                    keys=params.get('proprio_encoder_keys'),
                                    num_layers=params.get(
                                        'proprio_encoder_num_layers', 3),
                                    units=params.get('proprio_encoder_units',
                                                     300))
    else:
        raise NotImplementedError(encoder)
    config.head_network = tools.bind(networks.feed_forward,
                                     num_layers=config.num_layers,
                                     units=config.num_units,
                                     activation=config.activation)
    config.heads = tools.AttrDict()
    if params.get('value_head', True):
        config.heads.value = tools.bind(
            config.head_network,
            num_layers=params.get('value_layers', 3),
            data_shape=[],
            dist=params.get('value_dist', 'normal'))
    if params.get('value_target_head', False):
        config.heads.value_target = tools.bind(
            config.head_network,
            num_layers=params.get('value_layers', 3),
            data_shape=[],
            stop_gradient=True,
            dist=params.get('value_dist', 'normal'))
    if params.get('return_head', False):
        config.heads['return'] = tools.bind(config.head_network,
                                            activation=config.activation)
    if params.get('action_head', True):
        config.heads.action = tools.bind(
            config.head_network,
            num_layers=params.get('action_layers', 4),
            mean_activation=ACTIVATIONS[params.get('action_mean_activation',
                                                   'none')],
            dist=params.get('action_head_dist', 'tanh_normal_tanh'),
            std=params.get('action_head_std', 'learned'),
            min_std=params.get('action_head_min_std', 1e-4),
            init_std=params.get('action_head_init_std', 5.0))
    if params.get('action_target_head', False):
        config.heads.action_target = tools.bind(
            config.head_network,
            num_layers=params.get('action_layers', 4),
            stop_gradient=True,
            mean_activation=ACTIVATIONS[params.get('action_mean_activation',
                                                   'none')],
            dist=params.get('action_head_dist', 'tanh_normal_tanh'),
            std=params.get('action_head_std', 'learned'),
            min_std=params.get('action_head_min_std', 1e-4),
            init_std=params.get('action_head_init_std', 5.0))
    if params.get('cpc_head', False):
        config.heads.cpc = config.head_network.copy(
            dist=params.get('cpc_head_dist', 'normal'),
            std=params.get('cpc_head_std', 'learned'),
            num_layers=params.get('cpc_head_layers', 3))
    image_head = params.get('image_head', 'conv')
    if image_head == 'conv':
        config.heads.image = tools.bind(networks.conv.decoder,
                                        std=params.get('image_head_std', 1.0))
    else:
        raise NotImplementedError(image_head)
    hidden_size = params.get('model_size', 200)
    state_size = params.get('state_size', 30)
    model = params.get('model', 'rssm')
    if model == 'rssm':
        config.cell = tools.bind(models.RSSM, state_size, hidden_size,
                                 hidden_size, params.get('future_rnn', True),
                                 params.get('mean_only', False),
                                 params.get('min_stddev', 1e-1),
                                 config.activation,
                                 params.get('model_layers', 1),
                                 params.get('rssm_model', 'gru'),
                                 params.get('trxl_layer', 2),
                                 params.get('trxl_n_head', 10),
                                 params.get('trxl_mem_len', 8),
                                 params.get('trxl_pre_lnorm', False),
                                 params.get('trxl_gate', 'plus'))
    else:
        raise NotImplementedError(model)
    return config
Example #13
def walker_stand(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['height', 'orientations', 'velocity']
    env_ctor = tools.bind(_dm_control_env, 'walker', 'stand', config, params,
                          action_repeat)
    return Task('walker_stand', env_ctor, state_components)
Example #14
def cup_catch(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, 'ball_in_cup', 'catch', config,
                          params, action_repeat)
    return Task('cup_catch', env_ctor, state_components)
Example #15
def cheetah_run(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, 'cheetah', 'run', config, params,
                          action_repeat)
    return Task('cheetah_run', env_ctor, state_components)
Example #16
def pointmass_easy(config, params):
    action_repeat = params.get('action_repeat', 2)
    env_ctor = tools.bind(_dm_control_env, 'point_mass', 'easy', config,
                          params, action_repeat)
    return Task('pointmass_easy', env_ctor, [])
Example #17
def reacher_hard(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['position', 'velocity', 'to_target']
    env_ctor = tools.bind(_dm_control_env, 'reacher', 'hard', config, params,
                          action_repeat)
    return Task('reacher_hard', env_ctor, state_components)
Example #18
DMLAB_TASKS = {
    'dmlab_keys': 'rooms_keys_doors_puzzle',
    'dmlab_lasertag_1': 'lasertag_one_opponent_small',
    'dmlab_lasertag_3': 'lasertag_three_opponents_small',
    'dmlab_recognize': 'psychlab_visual_search',
    'dmlab_watermaze': 'rooms_watermaze',
}

for name, level in DMLAB_TASKS.items():

    def task_fn(name, level, config, params):
        action_repeat = params.get('action_repeat', 4)
        env_ctor = tools.bind(_dm_lab_env, level, config, params,
                              action_repeat)
        return Task(name, env_ctor, [])

    locals()[name] = tools.bind(task_fn, name, level)

ATARI_TASKS = [
    'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids', 'Atlantis',
    'BankHeist', 'BattleZone', 'BeamRider', 'Berzerk', 'Bowling', 'Boxing',
    'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber', 'Defender',
    'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway',
    'Frostbite', 'Gopher', 'Gravitar', 'Hero', 'IceHockey', 'Jamesbond',
    'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman',
    'NameThisGame', 'Phoenix', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert',
    'Riverraid', 'RoadRunner', 'Robotank', 'Seaquest', 'Skiing', 'Solaris',
    'SpaceInvaders', 'StarGunner', 'Surround', 'Tennis', 'TimePilot',
    'Tutankham', 'UpNDown', 'Venture', 'VideoPinball', 'WizardOfWor',
    'YarsRevenge', 'Zaxxon'
]
ATARI_TASKS = {'atari_{}'.format(game.lower()): game for game in ATARI_TASKS}
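The registration loop above passes `name` and `level` into `task_fn` explicitly via `tools.bind` instead of capturing them in a closure; with a plain closure, every registered task would see the values from the final loop iteration. A small sketch of the difference, assuming `tools.bind` behaves like `functools.partial`:

import functools

levels = {'dmlab_keys': 'rooms_keys_doors_puzzle',
          'dmlab_watermaze': 'rooms_watermaze'}

closures, bound = {}, {}
for name, level in levels.items():
    closures[name] = lambda: (name, level)  # late binding: reads loop vars at call time
    bound[name] = functools.partial(lambda n, l: (n, l), name, level)  # values frozen now

print(closures['dmlab_keys']())  # ('dmlab_watermaze', 'rooms_watermaze') -- wrong task
print(bound['dmlab_keys']())     # ('dmlab_keys', 'rooms_keys_doors_puzzle')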
Example #19
def task_fn(name, game, config, params, mode):
    action_repeat = params.get('action_repeat', 1)
    env_ctor = tools.bind(_procgen_env, game, mode, config, params,
                          action_repeat)
    return Task(name, env_ctor, [])
Example #20
def task_fn(name, level, config, params):
    action_repeat = params.get('action_repeat', 4)
    env_ctor = tools.bind(_dm_lab_env, level, config, params,
                          action_repeat)
    return Task(name, env_ctor, [])
Example #21
def acrobot_swingup(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['orientations', 'velocity']
    env_ctor = tools.bind(_dm_control_env, 'acrobot', 'swingup', config,
                          params, action_repeat)
    return Task('acrobot_swingup', env_ctor, state_components)
Example #22
def fish_upright(config, params):
    action_repeat = params.get('action_repeat', 2)
    env_ctor = tools.bind(_dm_control_env, 'fish', 'upright', config, params,
                          action_repeat)
    return Task('fish_upright', env_ctor, [])
Example #23
def _loss_functions(config, params):
    for head in config.gradient_heads:
        assert head in config.heads, head
    config.imagination_horizon = params.get('imagination_horizon', 15)
    config.imagination_skip_last = params.get('imagination_skip_last', None)
    config.imagination_include_initial = params.get(
        'imagination_include_initial', True)

    config.action_source = params.get('action_source', 'model')
    config.action_model_horizon = params.get('action_model_horizon', None)
    config.action_bootstrap = params.get('action_bootstrap', True)
    config.action_discount = params.get('action_discount', 0.99)
    config.action_lambda = params.get('action_lambda', 0.95)
    config.action_target_update = params.get('action_target_update', 1)
    config.action_target_period = params.get('action_target_period', 50000)
    config.action_loss_pcont = params.get('action_loss_pcont', False)
    config.action_pcont_stop_grad = params.get('action_pcont_stop_grad', False)
    config.action_pcont_weight = params.get('action_pcont_weight', True)

    config.value_source = params.get('value_source', 'model')
    config.value_model_horizon = params.get('value_model_horizon', None)
    config.value_discount = params.get('value_discount', 0.99)
    config.value_lambda = params.get('value_lambda', 0.95)
    config.value_bootstrap = params.get('value_bootstrap', True)
    config.value_target_update = params.get('value_target_update', 1)
    config.value_target_period = params.get('value_target_period', 50000)
    config.value_loss_pcont = params.get('value_loss_pcont', False)
    config.value_pcont_weight = params.get('value_pcont_weight', True)
    config.value_maxent = params.get('value_maxent', False)

    config.action_beta = params.get('action_beta', 0.0)
    config.action_beta_dims_value = params.get('action_beta_dims_value', None)
    config.state_beta = params.get('state_beta', 0.0)
    config.stop_grad_pre_action = params.get('stop_grad_pre_action', True)
    config.pcont_label_weight = params.get('pcont_label_weight', None)

    config.loss_scales.divergence = params.get('divergence_scale', 1.0)
    config.loss_scales.global_divergence = params.get('global_div_scale', 0.0)
    config.loss_scales.overshooting = params.get('overshooting_scale', 0.0)
    for head in config.heads:
        if head in ('value_target', 'action_target'):  # Untrained.
            continue
        config.loss_scales[head] = params.get(head + '_loss_scale', 1.0)

    config.free_nats = params.get('free_nats', 3.0)
    config.overshooting_distance = params.get('overshooting_distance', 0)
    config.os_stop_posterior_grad = params.get('os_stop_posterior_grad', True)
    config.cpc_contrast = params.get('cpc_contrast', 'window')
    config.cpc_batch_amount = params.get('cpc_batch_amount', 10)
    config.cpc_time_amount = params.get('cpc_time_amount', 30)

    optimizer_cls = tools.bind(tf.train.AdamOptimizer,
                               epsilon=params.get('optimizer_epsilon', 1e-4))
    config.optimizers = tools.AttrDict()
    config.optimizers.default = tools.bind(
        tools.CustomOptimizer,
        optimizer_cls=optimizer_cls,
        # schedule=tools.bind(tools.schedule.linear, ramp=0),
        learning_rate=params.get('default_lr', 1e-3),
        clipping=params.get('default_gradient_clipping', 1000.0))
    config.optimizers.model = config.optimizers.default.copy(
        learning_rate=params.get('model_lr', 6e-4),
        clipping=params.get('model_gradient_clipping', 100.0))
    config.optimizers.value = config.optimizers.default.copy(
        learning_rate=params.get('value_lr', 8e-5),
        clipping=params.get('value_gradient_clipping', 100.0))
    config.optimizers.action = config.optimizers.default.copy(
        learning_rate=params.get('action_lr', 8e-5),
        clipping=params.get('action_gradient_clipping', 100.0))
    return config
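The per-module optimizers are derived from `config.optimizers.default` via `.copy(...)`, which appears to re-bind the same constructor with a few keyword overrides. Stacked `functools.partial` calls behave the same way, since keywords supplied later override those bound earlier; a sketch with a hypothetical optimizer factory (not the project's `tools.CustomOptimizer`):

import functools

def make_optimizer(learning_rate, clipping, epsilon=1e-4):
    # Hypothetical factory standing in for the project's optimizer wrapper.
    return dict(lr=learning_rate, clip=clipping, eps=epsilon)

default = functools.partial(make_optimizer, learning_rate=1e-3, clipping=1000.0)
model = functools.partial(default, learning_rate=6e-4, clipping=100.0)

print(default())  # {'lr': 0.001, 'clip': 1000.0, 'eps': 0.0001}
print(model())    # {'lr': 0.0006, 'clip': 100.0, 'eps': 0.0001}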
Example #24
def cartpole_swingup(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, 'cartpole', 'swingup', config,
                          params, action_repeat)
    return Task('cartpole_swingup', env_ctor, state_components)
Example #25
def finger_spin(config, params):
    action_repeat = params.get('action_repeat', 2)
    state_components = ['position', 'velocity', 'touch']
    env_ctor = tools.bind(_dm_control_env, 'finger', 'spin', config, params,
                          action_repeat)
    return Task('finger_spin', env_ctor, state_components)
Example #26
def manipulator_bring(config, params):
    action_repeat = params.get('action_repeat', 2)
    env_ctor = tools.bind(_dm_control_env, 'manipulator', 'bring_ball', config,
                          params, action_repeat)
    return Task('manipulator_bring_ball', env_ctor, [])