Example #1
def _data_processing(config, params):
  config.batch_shape = params.get('batch_shape', '128 2')
  config.batch_shape = tuple(map(int, config.batch_shape.split(' ')))
  config.num_chunks = int(params.get('num_chunks', 1))
  config.stack_obs = params.get('stack_obs', False)
  config.n_stack_history = params.get('n_stack_history', 2)
  image_bits = params.get('image_bits', 5)
  config.preprocess_fn = tools.bind(
      tools.preprocess.preprocess, bits=image_bits)
  config.postprocess_fn = tools.bind(
      tools.preprocess.postprocess, bits=image_bits)
  config.open_loop_context = 5
  config.data_reader = tools.numpy_episodes.episode_reader
  config.data_loader = {
      'cache': tools.bind(
          tools.numpy_episodes.cache_loader,
          every=params.get('loader_every', 1000)),
      'recent': tools.bind(
          tools.numpy_episodes.recent_loader,
          every=params.get('loader_every', 1000)),
      'reload': tools.numpy_episodes.reload_loader,
      'dummy': tools.numpy_episodes.dummy_loader,
  }[params.get('loader', 'recent')]
  config.bound_action = tools.bind(
      tools.bound_action,
      strategy=params.get('bound_action', 'clip'))
  return config
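Note: tools.bind appears throughout these examples but is never defined here. A minimal stand-in, assuming it behaves like functools.partial (pre-applying arguments and returning a callable for later use), is enough to follow the configuration code above; the project's real helper may differ.

import functools

def bind(fn, *args, **kwargs):
    # Hypothetical stand-in for tools.bind: pre-apply some arguments now and
    # return a callable that accepts the remaining ones later.
    return functools.partial(fn, *args, **kwargs)

# Usage mirroring the data-loader configuration above: construction is delayed
# until the loader is actually needed.
def recent_loader(directory, every):
    return 'load episodes from {} every {} steps'.format(directory, every)

loader = bind(recent_loader, every=1000)
print(loader('/tmp/episodes'))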
Example #2
def _define_simulation(
    task, config, params, horizon, batch_size, prefix, objective='reward',
    rewards=False):
  planner = params.get('planner', 'cem')
  # Use the intrinsic reward objective during training and the plain reward at test time.
  objective = 'reward_int' if prefix == 'train' else 'reward'
  # Temporary fix for the random-collections bug: a planner_iterations of 0
  # disables planning during training but falls back to 10 iterations at test time.
  planner_iterations = params.get('planner_iterations', 10)
  if planner_iterations == 0 and prefix != 'train':
    planner_iterations = 10
  if planner == 'cem':
    planner_fn = tools.bind(
        control.planning.cross_entropy_method,
        amount=params.get('planner_amount', 1000),
        iterations=planner_iterations,
        topk=params.get('planner_topk', 100),
        horizon=horizon)
  else:
    raise NotImplementedError(planner)
  return tools.AttrDict(
      task=task,
      num_agents=batch_size,
      planner=planner_fn,
      objective=tools.bind(getattr(objectives_lib, objective), params=params))
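Note: control.planning.cross_entropy_method is only referenced here, not shown. As a rough orientation, the cross-entropy method samples candidate action sequences from a Gaussian, scores them with the objective, and refits the Gaussian to the top-k candidates each iteration; the self-contained NumPy sketch below assumes that reading of the planner_amount, planner_iterations, planner_topk and horizon knobs (the real planner runs batched latent rollouts in TensorFlow).

import numpy as np

def cross_entropy_method(score_fn, action_dim, horizon,
                         amount=1000, iterations=10, topk=100):
    # score_fn maps an array of shape (amount, horizon, action_dim) to a
    # vector of (amount,) returns.
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(iterations):
        samples = mean + std * np.random.randn(amount, horizon, action_dim)
        samples = np.clip(samples, -1.0, 1.0)
        returns = score_fn(samples)
        elite = samples[np.argsort(returns)[-topk:]]  # best-scoring sequences
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mean  # the first action of this plan would be executed

# Toy objective: prefer action sequences close to 0.5 in every dimension.
plan = cross_entropy_method(
    lambda a: -np.abs(a - 0.5).sum(axis=(1, 2)), action_dim=2, horizon=12)
print(plan.shape)  # (12, 2)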
Example #3
def _loss_functions(config, params):
    for head in config.gradient_heads:
        assert head in config.heads, head
    config.loss_scales.divergence = params.get('divergence_scale', 1.0)
    config.loss_scales.global_divergence = params.get('global_div_scale', 0.0)
    config.loss_scales.overshooting = params.get('overshooting_scale', 0.0)
    config.r_loss = params.get('r_loss', 'nll')
    config.contra_unit = params.get('contra_unit', 'step')
    config.contra_horizon = int(params.get('contra_h', 12))
    config.resample = int(params.get('resample', 1))
    config.hard_ratio = float(params.get('hr', 1.0))
    config.temp = float(params.get('temp', 1.0))
    config.margin = float(params.get('margin', 1.0))
    for head in config.heads:
        defaults = {'reward': float(params.get('reward_loss_scale', 10.0))}
        scale = defaults[head] if head in defaults else 1.0
        config.loss_scales[head] = params.get(head + '_loss_scale', scale)
    config.free_nats = params.get('free_nats', 3.0)
    config.overshooting_distance = params.get('overshooting_distance', 0)
    config.os_stop_posterior_grad = params.get('os_stop_posterior_grad', True)
    config.optimizers = tools.AttrDict(_unlocked=True)
    config.optimizers.main = tools.bind(
        tools.CustomOptimizer,
        optimizer_cls=tools.bind(tf.train.AdamOptimizer, epsilon=1e-4),
        # schedule=tools.bind(tools.schedule.linear, ramp=0),
        learning_rate=params.get('main_learning_rate', 1e-3),
        clipping=params.get('main_gradient_clipping', 1000.0))
    return config
Example #4
def _loss_functions(config, params):
  for head in config.gradient_heads:
    assert head in config.heads, head
  config.loss_scales.divergence = params.get('divergence_scale', 1.0)
  config.loss_scales.global_divergence = params.get('global_div_scale', 0.0)
  config.loss_scales.overshooting = params.get('overshooting_scale', 0.0)
  config.free_nats = params.get('free_nats', 3.0)
  config.overshooting_distance = params.get('overshooting_distance', 0)
  config.os_stop_posterior_grad = params.get('os_stop_posterior_grad', True)
  config.optimizers = tools.AttrDict(_unlocked=True)
  config.optimizers.main = tools.bind(
      tools.CustomOptimizer,
      optimizer_cls=tools.bind(tf.train.AdamOptimizer, epsilon=1e-4),
      # schedule=tools.bind(tools.schedule.linear, ramp=0),
      learning_rate=params.get('main_learning_rate', 1e-3),
      clipping=params.get('main_gradient_clipping', 1000.0))
  for head in config.heads:
    defaults = {'reward': 10.0}
    scale = defaults[head] if head in defaults else 1.0
    config.loss_scales[head] = params.get(head + '_loss_scale', scale)
    if head in config.gradient_heads:
      continue
    config.optimizers[head] = tools.bind(
        tools.CustomOptimizer,
        optimizer_cls=tools.bind(tf.train.AdamOptimizer, epsilon=1e-4),
        # schedule=tools.bind(tools.schedule.linear, ramp=0),
        learning_rate=params.get('learning_rate', 1e-3),
        clipping=params.get('gradient_clipping', 1000.0))
  return config
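Note: the loss_scales entries set in Examples #3 and #4 are consumed elsewhere in the training code, which is not shown here. The sketch below only illustrates the assumed weighting scheme: each head or regularizer loss is multiplied by its configured scale and summed, so a scale of zero switches a term off entirely.

def total_loss(losses, loss_scales):
    # losses: dict of scalar losses keyed like the config ('image', 'reward',
    # 'divergence', ...); loss_scales: matching dict of floats from the config.
    return sum(loss_scales.get(name, 1.0) * value
               for name, value in losses.items()
               if loss_scales.get(name, 1.0) != 0.0)

losses = {'image': 2.3, 'reward': 0.04, 'divergence': 1.7}
scales = {'image': 1.0, 'reward': 10.0, 'divergence': 1.0}
print(total_loss(losses, scales))  # 2.3 + 10 * 0.04 + 1.7, roughly 4.4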
Example #5
def _tasks(config, params):
    tasks = params.get('tasks', ['cheetah_run'])
    tasks = [getattr(tasks_lib, name)(config, params) for name in tasks]
    config.isolate_envs = params.get('isolate_envs', 'thread')

    def common_spaces_ctor(task, action_spaces):
        env = task.env_ctor()
        env = control.wrappers.SelectObservations(env, ['image'])
        env = control.wrappers.PadActions(env, action_spaces)
        return env

    if len(tasks) > 1:
        action_spaces = [task.env_ctor().action_space for task in tasks]
        for index, task in enumerate(tasks):
            env_ctor = tools.bind(common_spaces_ctor, task, action_spaces)
            tasks[index] = tasks_lib.Task(task.name, env_ctor, task.max_length,
                                          ['reward'])
    for name in tasks[0].state_components:
        if name == 'reward' or params.get('state_diagnostics', False):
            config.heads[name] = tools.bind(config.head_network,
                                            stop_gradient=name
                                            not in config.gradient_heads)
            config.loss_scales[name] = 1.0
    config.tasks = tasks
    return config
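Note: in the multi-task branch above, control.wrappers.PadActions receives the action spaces of all tasks so that one shared action space drives every environment. The wrapper itself is not shown; the sketch below is a hypothetical illustration of the padding idea, with a dummy environment standing in for a real task.

import numpy as np

class _DummyEnv(object):
    action_dim = 2
    def step(self, action):
        return 'stepped with {}'.format(action)

class PadActions(object):
    # Hypothetical sketch: agents act in a space as wide as the widest task,
    # and the extra trailing dimensions are dropped before stepping this task.
    def __init__(self, env, action_dims):
        self._env = env
        self._own_dim = env.action_dim
        self.action_dim = max(action_dims)

    def step(self, action):
        return self._env.step(np.asarray(action)[:self._own_dim])

env = PadActions(_DummyEnv(), action_dims=[2, 6])
print(env.step(np.zeros(6)))  # only the first two dimensions reach the task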
Example #6
def _loss_functions(config, params, cpc=False):
    for head in config.gradient_heads:
        assert head in config.heads, head
    config.loss_scales.divergence = params.get('divergence_scale', 1.0)
    config.loss_scales.global_divergence = params.get('global_div_scale', 0.0)
    config.loss_scales.overshooting = params.get('overshooting_scale', 0.0)
    config.loss_scales.inverse_model = params.get('inverse_model_scale', 0.)
    config.action_contrastive = params.get('action_contrastive', True)
    if cpc:
        config.loss_scales.cpc = params.get('cpc_scale', 100.)
        config.loss_scales.latent_prior = params.get('latent_prior_scale', 0.)
        config.cpc_reward_scale = params.get('cpc_reward_scale', 0.)
        config.cpc_gpenalty_scale = params.get('cpc_gpenalty_scale', 0.)
        config.loss_scales.embedding_l2 = params.get('embedding_l2_scale', 0.)
    for head in config.heads:
        defaults = {'reward': 10.0}
        scale = defaults[head] if head in defaults else 1.0
        config.loss_scales[head] = params.get(head + '_loss_scale', scale)

    if cpc:
        config.loss_scales['image'] = 0.

    config.free_nats = params.get('free_nats', 3.0)
    config.overshooting_distance = params.get('overshooting_distance', 0)
    config.os_stop_posterior_grad = params.get('os_stop_posterior_grad', True)
    config.optimizers = tools.AttrDict(_unlocked=True)
    config.optimizers.main = tools.bind(
        tools.CustomOptimizer,
        optimizer_cls=tools.bind(tf.train.AdamOptimizer, epsilon=1e-4),
        # schedule=tools.bind(tools.schedule.linear, ramp=0),
        learning_rate=params.get('main_learning_rate', 1e-3),
        clipping=params.get('main_gradient_clipping', 1000.0))
    return config
Example #7
def _data_processing(config, params):
    config.batch_shape = params.get('batch_shape', (50, 50))
    config.num_chunks = params.get('num_chunks', 1)
    image_bits = params.get('image_bits', 5)
    config.preprocess_fn = tools.bind(tools.preprocess.preprocess,
                                      bits=image_bits)
    config.postprocess_fn = tools.bind(tools.preprocess.postprocess,
                                       bits=image_bits)
    config.open_loop_context = 5
    config.data_reader = tools.numpy_episodes.episode_reader
    config.data_loader = {
        'cache':
        tools.bind(tools.numpy_episodes.cache_loader,
                   every=params.get('loader_every', 1000)),
        'recent':
        tools.bind(tools.numpy_episodes.recent_loader,
                   every=params.get('loader_every', 1000)),
        'reload':
        tools.numpy_episodes.reload_loader,
        'dummy':
        tools.numpy_episodes.dummy_loader,
    }[params.get('loader', 'recent')]
    config.bound_action = tools.bind(tools.bound_action,
                                     strategy=params.get(
                                         'bound_action', 'clip'))
    return config
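Note: both data-processing variants bind tools.preprocess.preprocess and postprocess with an image_bits argument. In pipelines of this kind that usually means reducing images to the given bit depth, rescaling them to a small symmetric range and adding uniform dequantization noise; the NumPy sketch below shows that transformation as an assumption about what the bound functions do, not as the project's actual code.

import numpy as np

def preprocess(image, bits=5):
    # image: uint8 array in [0, 255]. Quantize to 2**bits levels, rescale to
    # roughly [-0.5, 0.5], and add uniform noise to hide the discrete grid.
    bins = 2 ** bits
    image = np.floor(image.astype(np.float32) / 2 ** (8 - bits))
    image = image / bins - 0.5
    return image + np.random.uniform(0, 1.0 / bins, image.shape)

def postprocess(image, bits=5):
    # Invert the rescaling back to displayable uint8 pixels.
    bins = 2 ** bits
    image = np.floor((image + 0.5) * bins) * (256.0 / bins)
    return np.clip(image, 0, 255).astype(np.uint8)

frame = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
print(preprocess(frame).shape, postprocess(preprocess(frame)).dtype)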
Example #8
def _model_components(config, params):
    if not config.cpc:
        config.gradient_heads = params.get('gradient_heads',
                                           ['image', 'reward'])
    else:
        config.gradient_heads = params.get('gradient_heads', ['reward'])
    network = getattr(networks, params.get('network', 'conv_ha'))
    config.activation = ACTIVATIONS[params.get('activation', 'relu')]
    config.num_layers = params.get('num_layers', 3)
    config.num_units = params.get('num_units', 300)
    config.head_network = tools.bind(networks.feed_forward,
                                     num_layers=config.num_layers,
                                     units=config.num_units,
                                     activation=config.activation)
    config.encoder = tools.bind(network.encoder,
                                embedding_size=params.get(
                                    'embedding_size', 1024))
    if not config.cpc:
        config.decoder = network.decoder
    config.heads = tools.AttrDict(_unlocked=True)
    if not config.cpc:
        config.heads.image = config.decoder
    size = params.get('model_size', 200)
    state_size = params.get('state_size', 30)
    model = params.get('model', 'rssm')
    if model == 'ssm':
        config.cell = tools.bind(models.SSM, state_size, size,
                                 params.get('mean_only', False),
                                 config.activation,
                                 params.get('min_stddev', 1e-1))
    elif model == 'rssm':
        config.cell = tools.bind(models.RSSM, state_size, size, size,
                                 params.get('future_rnn', True),
                                 params.get('mean_only', False),
                                 params.get('min_stddev', 1e-1),
                                 config.activation,
                                 params.get('model_layers', 1))
    elif model == 'drnn':
        config.cell = tools.bind(models.DRNN, state_size, size, size,
                                 params.get('mean_only', False),
                                 params.get('min_stddev', 1e-1),
                                 config.activation,
                                 params.get('drnn_encoder_to_decoder', False),
                                 params.get('drnn_sample_to_sample', True),
                                 params.get('drnn_sample_to_encoder', True),
                                 params.get('drnn_decoder_to_encoder', False),
                                 params.get('drnn_decoder_to_sample', True),
                                 params.get('drnn_action_to_decoder', False))
    else:
        raise NotImplementedError("Unknown model '{}'.".format(model))
    return config
Example #9
def finger_spin(config, params):
    action_repeat = params.get('action_repeat', 2)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'position', 'velocity', 'touch']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length, 'finger',
                          'spin', params)
    return Task('finger_spin', env_ctor, max_length, state_components)
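Note: the dm_control tasks in these examples all follow the same pattern: an action_repeat is chosen and max_length = 1000 // action_repeat, so an episode always spans 1000 underlying simulator steps no matter how often the agent acts. The _dm_control_env constructor is not shown; the sketch below only illustrates the assumed action-repeat mechanism with a dummy environment.

class ActionRepeat(object):
    # Hypothetical sketch: apply each agent action `amount` times, summing the
    # rewards, so 1000 // action_repeat agent steps cover 1000 simulator steps.
    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def step(self, action):
        total_reward = 0.0
        for _ in range(self._amount):
            obs, reward, done = self._env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done

class _CountingEnv(object):
    def __init__(self):
        self.steps = 0
    def step(self, action):
        self.steps += 1
        return self.steps, 1.0, self.steps >= 1000

env, done, agent_steps = ActionRepeat(_CountingEnv(), amount=2), False, 0
while not done:
    _, _, done = env.step(None)
    agent_steps += 1
print(agent_steps)  # 500 == 1000 // action_repeat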
Example #10
def cartpole_swingup(config, params):
    action_repeat = params.get('action_repeat', 8)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length,
                          'cartpole', 'swingup', params)
    return Task('cartpole_swingup', env_ctor, max_length, state_components)
Example #11
def reacher_easy(config, params):
    action_repeat = params.get('action_repeat', 4)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'position', 'velocity', 'to_target']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length,
                          'reacher', 'easy', params)
    return Task('reacher_easy', env_ctor, max_length, state_components)
Example #12
def walker_walk(config, params):
    action_repeat = params.get('action_repeat', 2)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'height', 'orientations', 'velocity']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length, 'walker',
                          'walk', params)
    return Task('walker_walk', env_ctor, max_length, state_components)
Example #13
def cheetah_flip_forward(config, params):
    action_repeat = params.get('action_repeat', 4)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length,
                          'cheetah', 'flip_forward', params)
    return Task('cheetah_flip_forward', env_ctor, max_length, state_components)
Example #14
def cup_catch(config, params):
    action_repeat = params.get('action_repeat', 4)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'position', 'velocity']
    env_ctor = tools.bind(_dm_control_env, action_repeat, max_length,
                          'ball_in_cup', 'catch', params)
    return Task('cup_catch', env_ctor, max_length, state_components)
Example #15
def gym_cheetah(config, params):
    # Works with `isolate_envs: process`.
    action_repeat = params.get('action_repeat', 1)
    max_length = 1000 // action_repeat
    state_components = ['reward', 'state']
    env_ctor = tools.bind(_gym_env, action_repeat, config.batch_shape[1],
                          max_length, 'HalfCheetah-v3')
    return Task('gym_cheetah', env_ctor, max_length, state_components)
Example #16
def pr2_reach(config, params):
  action_repeat = params.get('action_repeat', 1)
  max_length = 200 // action_repeat
  state_components = ['reward', 'position', 'velocity', 'ee_goal']
  env_ctor = tools.bind(
    _dm_control_env, action_repeat, max_length, 'pr2_dm', 'reach',
    params, normalize=True)
  return Task('pr2_reach', env_ctor, max_length, state_components)
Example #17
def gym_racecar(config, params):
  # Works with `isolate_envs: thread`.
  action_repeat = params.get('action_repeat', 1)
  max_length = 1000 // action_repeat
  state_components = ['reward']
  env_ctor = tools.bind(
      _gym_env, action_repeat, config.batch_shape[1], max_length,
      'CarRacing-v0', obs_is_image=True)
  return Task('gym_racing', env_ctor, max_length, state_components)
Example #18
def _define_simulation(
    task, config, params, horizon, batch_size, objective='reward',
    rewards=False):
  planner = params.get('planner', 'cem')
  if planner == 'cem':
    planner_fn = tools.bind(
        control.planning.cross_entropy_method,
        amount=params.get('planner_amount', 1000),
        iterations=params.get('planner_iterations', 10),
        topk=params.get('planner_topk', 100),
        horizon=horizon)
  else:
    raise NotImplementedError(planner)
  return tools.AttrDict(
      task=task,
      num_agents=batch_size,
      planner=planner_fn,
      objective=tools.bind(getattr(objectives_lib, objective), params=params))
Example #19
def _active_collection(collects, defaults, config, params, bs=1):
    defs = dict(
        name='main',
        batch_size=bs,
        horizon=params.get('planner_horizon', 12),
        objective=params.get('collect_objective', 'reward'),
        # after=params.get('collect_every', 5000),
        # every=params.get('collect_every', 5000),
        after=params.get('collect_every',
                         config.train_steps // config.num_clt_epoch),
        every=params.get('collect_every',
                         config.train_steps // config.num_clt_epoch),
        # until=-1,
        until=config.epoch * (config.train_steps + config.test_steps),
        action_noise=0.0,
        action_noise_ramp=params.get('action_noise_ramp', 0),
        action_noise_min=params.get('action_noise_min', 0.0),
    )
    defs.update(defaults)
    sims = tools.AttrDict(_unlocked=True)
    for task in config.tasks:
        for collect in collects:
            collect = tools.AttrDict(collect, _defaults=defs)
            sim = _define_simulation(task, config, params, collect.horizon,
                                     collect.batch_size, collect.objective)
            sim.unlock()
            sim.save_episode_dir = collect.save_episode_dir
            sim.steps_after = int(collect.after)
            sim.steps_every = int(collect.every)
            sim.steps_until = int(collect.until)
            sim.exploration = tools.AttrDict(
                scale=collect.action_noise,
                schedule=tools.bind(
                    tools.schedule.linear,
                    ramp=collect.action_noise_ramp,
                    min=collect.action_noise_min,
                ))
            name = '{}_{}_{}'.format(collect.prefix, collect.name, task.name)
            assert name not in sims, (set(sims.keys()), name)
            sims[name] = sim
            assert not collect.untouched, collect.untouched
    return sims
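Note: _active_collection ties each simulation's exploration noise to tools.schedule.linear with a ramp and a minimum. The exact semantics live in the project's schedule module; the stand-in below simply assumes a factor that grows linearly with the step count over `ramp` steps and never drops below `min`, which is then multiplied by collect.action_noise.

import numpy as np

def linear(step, ramp, min=0.0):
    # Assumed behavior: ramp the noise scale linearly up to 1.0 over `ramp`
    # steps, clipped below by `min`; a ramp of 0 keeps the scale constant.
    if ramp == 0:
        return max(1.0, min)
    return float(np.clip(step / float(ramp), min, 1.0))

for step in (0, 5000, 20000):
    print(step, linear(step, ramp=10000, min=0.1))  # 0.1, 0.5, 1.0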
Example #20
def _data_processing(config, params):
    config.logdir = params.logdir
    config.bs = int(params.get('bs', 50))
    config.batch_shape = tuple(params.get('batch_shape', (config.bs, 50)))
    print(config.batch_shape, type(config.batch_shape))

    config.num_chunks = params.get('num_chunks', 1)
    config.aug = params.get('aug', None)
    config.contra_unit = params.get('contra_unit', 'step')
    print('aug ', config.aug)
    config.aug_same = bool(params.get('aug_same', False))
    image_bits = params.get('image_bits', 5)
    config.preprocess_fn = tools.bind(tools.preprocess.preprocess,
                                      bits=image_bits)
    config.postprocess_fn = tools.bind(tools.preprocess.postprocess,
                                       bits=image_bits)
    config.aug_fn = tools.bind(
        tools.preprocess.augment,
        aug=config.aug,
        same=config.aug_same,
        simclr=(config.contra_unit == 'simclr')) if config.aug else None
    config.open_loop_context = 5
    config.data_reader = tools.numpy_episodes.episode_reader
    config.data_loader = {
        'cache':
        tools.bind(tools.numpy_episodes.cache_loader,
                   every=params.get('loader_every', 1000)),
        'recent':
        tools.bind(tools.numpy_episodes.recent_loader,
                   every=params.get('loader_every', 1000)),
        'reload':
        tools.numpy_episodes.reload_loader,
        'dummy':
        tools.numpy_episodes.dummy_loader,
        'hard':
        tools.bind(tools.numpy_episodes.hard_negative_loader,
                   every=params.get('loader_every', 1000)),
    }[params.get('loader', 'recent')]
    print('loader ', params.get('loader', 'recent'))
    config.bound_action = tools.bind(tools.bound_action,
                                     strategy=params.get(
                                         'bound_action', 'clip'))
    return config
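Note: Example #20 optionally binds tools.preprocess.augment with aug, same, and a simclr flag derived from contra_unit. The augmentation itself is not shown; a common choice for this kind of contrastive setup is a random shift (pad-and-crop) per image, optionally sharing one offset across the batch when same is set. The NumPy sketch below illustrates that assumption only.

import numpy as np

def random_shift(images, pad=4, same=False):
    # images: (batch, height, width, channels). Pad each frame, then crop back
    # to the original size at a random offset; `same` reuses one offset for
    # the whole batch (roughly what an aug_same flag would control).
    b, h, w, _ = images.shape
    padded = np.pad(images, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'edge')
    if same:
        offsets = [np.random.randint(0, 2 * pad + 1, size=2)] * b
    else:
        offsets = [np.random.randint(0, 2 * pad + 1, size=2) for _ in range(b)]
    out = np.empty_like(images)
    for i, (dy, dx) in enumerate(offsets):
        out[i] = padded[i, dy:dy + h, dx:dx + w]
    return out

batch = np.random.rand(8, 64, 64, 3).astype(np.float32)
print(random_shift(batch, same=True).shape)  # (8, 64, 64, 3)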
Example #21
def pt_dis(config, params):
  action_repeat = params.get('action_repeat', 1)
  max_length = 100
  state_components = ['reward', 'state']
  env_ctor = tools.bind(_gym_env, action_repeat, config.batch_shape[1],
                        max_length, 'pt_dis', gym=False)
  return Task('pt_dis', env_ctor, max_length, state_components)
Example #22
def _define_simulation(task,
                       config,
                       params,
                       horizon,
                       batch_size,
                       objective='reward',
                       rewards=False):
    config.rival = params.get('rival', '')
    config.planner = params.get('planner', 'cem')
    if config.planner == 'cem':
        print('normal cem')
        planner_fn = tools.bind(control.planning.cross_entropy_method,
                                amount=params.get('planner_amount', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                horizon=horizon)
    elif config.planner == 'cem_eval':
        planner_fn = tools.bind(control.planning.cross_entropy_method_eval,
                                amount=params.get('planner_amount', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                eval_ratio=params.get('eval_ratio', 0.1),
                                logdir=params.logdir,
                                horizon=horizon,
                                task=config.tasks[0])
        print('Cem_eval !!!')
    elif config.planner == 'sim':
        planner_fn = tools.bind(control.planning.simulator_planner,
                                amount=params.get('planner_amount', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                eval_ratio=params.get('eval_ratio', 0.1),
                                logdir=params.logdir,
                                horizon=horizon,
                                task=config.tasks[0])
        print('Sim eval')
    elif config.planner == 'dual1':
        planner_fn = tools.bind(control.planning.cross_entropy_method_dual1,
                                amount=params.get('planner_amount', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                eval_ratio=params.get('eval_ratio', 0.1),
                                logdir=params.logdir,
                                horizon=horizon,
                                task=config.tasks[0])
        print('dual1')
    elif config.planner == 'dual2':
        planner_fn = tools.bind(control.planning.cross_entropy_method_dual2,
                                amount=params.get('planner_amount', 1000),
                                iterations=params.get('planner_iterations',
                                                      10),
                                topk=params.get('planner_topk', 100),
                                eval_ratio=params.get('eval_ratio', 0.1),
                                logdir=params.logdir,
                                horizon=horizon,
                                task=config.tasks[0])
        print('dual2')
    else:
        raise NotImplementedError(config.planner)
    return tools.AttrDict(task=task,
                          num_agents=batch_size,
                          planner=planner_fn,
                          objective=tools.bind(getattr(objectives_lib,
                                                       objective),
                                               params=params))