Example #1
def get_variant_spec_base(env, randomized, use_predictive_model,
                          observation_mode, reward_type, single_obj_reward,
                          all_random, trimodal_positions_choice, num_objects,
                          model_dir, num_execution_per_step, policy,
                          algorithm):
    algorithm_params = deep_update(
        ALGORITHM_PARAMS_BASE,
        ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {}),
        get_algorithm_params_roboverse(env, use_predictive_model),
    )
    variant_spec = {
        'git_sha': get_git_rev(__file__),
        'environment_params': {
            'training': {
                'env': env,
                'randomize_env': randomized,
                'use_predictive_model': use_predictive_model,
                'obs': observation_mode,
                'reward_type': reward_type,
                'single_obj_reward': single_obj_reward,
                'all_random': all_random,
                'trimodal_positions_choice': trimodal_positions_choice,
                'num_objects': num_objects,
                'model_dir': model_dir,
                'num_execution_per_step': num_execution_per_step,
                'kwargs': {
                    'image_shape': (48, 48, 3)
                }
            },
            'evaluation': tune.sample_from(lambda spec: (
                spec.get('config', spec)
                ['environment_params']
                ['training']
            )),
        },

        # 'policy_params': tune.sample_from(get_policy_params),
        'policy_params': {
            'class_name': 'FeedforwardGaussianPolicy',
            'config': {
                'hidden_layer_sizes': (M, M),
                'squash': False,  #True,
                'observation_keys': None,
                'preprocessors': None,
            },
        },
        'exploration_policy_params': {
            'class_name': 'ContinuousUniformPolicy',
            'config': {
                'observation_keys': tune.sample_from(lambda spec: (
                    spec.get('config', spec)
                    ['policy_params']
                    ['config']
                    .get('observation_keys')
                )),
            },
        },
        'Q_params': {
            'class_name': 'double_feedforward_Q_function',
            'config': {
                'hidden_layer_sizes': (M, M),
                'observation_keys': None,
                'preprocessors': None,
            },
        },
        'algorithm_params': algorithm_params,
        'replay_pool_params': {
            'class_name': 'SimpleReplayPool',
            'config': {
                'max_size': int(1e6),
            },
        },
        'sampler_params': {
            'class_name': 'SimpleSampler',
            'config': {
                'max_path_length':
                get_max_path_length_roboverse(env, use_predictive_model),
            }
        },
        'run_params': {
            'host_name': get_host_name(),
            'seed': tune.sample_from(lambda spec: np.random.randint(0, 10000)),
            'checkpoint_at_end': True,
            'checkpoint_frequency': tune.sample_from(get_checkpoint_frequency),
            'checkpoint_replay_pool': False,
        },
    }
    return variant_spec
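
The `spec.get('config', spec)` pattern in the `tune.sample_from` lambdas above appears to let the same lambda work whether Ray Tune hands it the full trial spec (which nests the variant under a 'config' key) or the resolved config dict itself. Below is a minimal, self-contained sketch of how the 'evaluation' entry resolves; the function name and the example environment values are illustrative assumptions, not taken from the original project.

def resolve_evaluation_params(spec):
    # Ray Tune may pass either the full trial spec (with a 'config' key)
    # or the config dict itself; spec.get('config', spec) handles both.
    return spec.get('config', spec)['environment_params']['training']

example_spec = {
    'config': {
        'environment_params': {
            'training': {'env': 'ExampleEnv-v0', 'obs': 'pixels'},
        },
    },
}
assert resolve_evaluation_params(example_spec) == {
    'env': 'ExampleEnv-v0', 'obs': 'pixels'}

The same indirection is reused for the exploration policy's 'observation_keys', so evaluation environments and exploration policies always mirror whatever the training entries resolve to.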
Example #2
def get_variant_spec_base(universe, domain, task, policy, algorithm):
    algorithm_params = deep_update(
        ALGORITHM_PARAMS_BASE,
        ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {}),
        get_algorithm_params(universe, domain, task),
    )
    variant_spec = {
        'git_sha': get_git_rev(__file__),

        'environment_params': {
            'training': {
                'domain': domain,
                'task': task,
                'universe': universe,
                'kwargs': get_environment_params(universe, domain, task),
            },
            'evaluation': tune.sample_from(lambda spec: (
                spec.get('config', spec)
                ['environment_params']
                ['training']
            )),
        },
        'policy_params': tune.sample_from(get_policy_params),
        'exploration_policy_params': {
            'type': 'UniformPolicy',
            'kwargs': {
                'observation_keys': tune.sample_from(lambda spec: (
                    spec.get('config', spec)
                    ['policy_params']
                    ['kwargs']
                    .get('observation_keys')
                ))
            },
        },
        'Q_params': {
            'type': 'double_feedforward_Q_function',
            'kwargs': {
                'hidden_layer_sizes': (M, M),
                'observation_keys': None,
                'observation_preprocessors_params': {}
            },
        },
        'algorithm_params': algorithm_params,
        'replay_pool_params': {
            'type': 'SimpleReplayPool',
            'kwargs': {
                'max_size': int(1e6),
            },
        },
        'sampler_params': {
            'type': 'SimpleSampler',
            'kwargs': {
                'max_path_length': get_max_path_length(universe, domain, task),
            }
        },
        'run_params': {
            'host_name': get_host_name(),
            'seed': tune.sample_from(
                lambda spec: np.random.randint(0, 10000)),
            'checkpoint_at_end': True,
            'checkpoint_frequency': tune.sample_from(get_checkpoint_frequency),
            'checkpoint_replay_pool': False,
        },
    }

    return variant_spec
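
deep_update is used at the top of each example to merge the base algorithm parameters with per-algorithm and per-task overrides. The sketch below illustrates the recursive merge it is assumed to perform (later dictionaries win, nested dictionaries are merged rather than replaced); it is an illustration of the assumed semantics, not the project's actual implementation.

def deep_update(*dicts):
    # Merge dictionaries left to right; nested dicts are merged
    # recursively, scalar values from later dicts override earlier ones.
    result = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = deep_update(result[key], value)
            else:
                result[key] = value
    return result

base = {'config': {'n_epochs': 200, 'lr': 3e-4}}
override = {'config': {'n_epochs': 500}}
assert deep_update(base, override) == {'config': {'n_epochs': 500, 'lr': 3e-4}}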
Example #3
def get_variant_spec_base(universe, domain, task, policy, algorithm):
    algorithm_params = deep_update(
        ALGORITHM_PARAMS_BASE,
        ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {}),
        get_algorithm_params(universe, domain, task),
    )
    variant_spec = {
        # doodad is complaining about this so we're just gonna hardcode a SHA
        #'git_sha': 'bd1dc29e166aca501af2e58a5057418126a3e435 master', #get_git_rev(__file__),
        'environment_params': {
            'training': {
                'domain': domain,
                'task': task,
                'universe': universe,
                'kwargs': get_environment_params(universe, domain, task),
            },
            'evaluation': tune.sample_from(lambda spec: (
                spec.get('config', spec)
                ['environment_params']
                ['training']
            )),
        },
        # 'policy_params': tune.sample_from(get_policy_params),
        'policy_params': {
            'class_name': 'FeedforwardGaussianPolicy',
            'config': {
                'hidden_layer_sizes': (M, M),
                'squash': True,
                'observation_keys': None,
                'preprocessors': None,
            },
        },
        'exploration_policy_params': {
            'class_name': 'ContinuousUniformPolicy',
            'config': {
                'observation_keys': tune.sample_from(lambda spec: (
                    spec.get('config', spec)
                    ['policy_params']
                    ['config']
                    .get('observation_keys')
                )),
            },
        },
        'Q_params': {
            'class_name': 'double_feedforward_Q_function',
            'config': {
                'hidden_layer_sizes': (M, M),
                'observation_keys': None,
                'preprocessors': None,
            },
        },
        'algorithm_params': algorithm_params,
        'replay_pool_params': {
            'class_name': 'SimpleReplayPool',
            'config': {
                'max_size': int(1e6),
            },
        },
        'sampler_params': {
            'class_name': 'SimpleSampler',
            'config': {
                'max_path_length': get_max_path_length(universe, domain, task),
            }
        },
        'run_params': {
            'host_name': get_host_name(),
            'seed': tune.sample_from(lambda spec: np.random.randint(0, 10000)),
            'checkpoint_at_end': True,
            'checkpoint_frequency': tune.sample_from(get_checkpoint_frequency),
            'checkpoint_replay_pool': False,
        },
    }

    return variant_spec
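
Examples #1 and #3 describe each component as a 'class_name'/'config' pair (Example #2 uses 'type'/'kwargs' for the same idea). The sketch below shows how such pairs could be turned into objects with a simple registry lookup; the registry, the SimpleReplayPool stub, and create_from_params are illustrative assumptions, not the project's actual factory code.

class SimpleReplayPool:
    # Stub standing in for the real replay pool class.
    def __init__(self, max_size):
        self.max_size = max_size

REGISTRY = {'SimpleReplayPool': SimpleReplayPool}

def create_from_params(params):
    # Look the class up by name and instantiate it with its config dict.
    cls = REGISTRY[params['class_name']]
    return cls(**params['config'])

pool = create_from_params({
    'class_name': 'SimpleReplayPool',
    'config': {'max_size': int(1e6)},
})
assert pool.max_size == 1_000_000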