Example #1
 def _prepare_network(self):
     # Resolve the network class from its 'module:attr' spec string,
     # then pass the remaining params as constructor kwargs.
     net_type = import_function(self._network_params['net_type'])
     self._logger.debug(f'Using {self._network_params["net_type"]}')
     params = deepcopy(self._network_params)
     del params['net_type']
     net = net_type(self._batch_tf, self._normalizer_tf, **params)
     return net
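All of these examples revolve around import_function, which the listing does not show. Judging from the spec strings used below (e.g. 'tensorflow.nn:tanh' in Example #3), the single-argument form resolves a 'module:attr' string to the named object. A minimal sketch under that assumption (the two-argument form used in Examples #4 through #7 instead looks a function up in a named build script and is not covered here):

import importlib

def import_function(spec):
    # Illustrative stand-in, not the project's implementation:
    # split 'package.module:attr' and fetch attr from the module.
    module_name, attr_name = spec.split(':')
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)

# import_function('tensorflow.nn:tanh') would return tf.nn.tanh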
Example #2
 def _prepare_network(self):
     net_type = import_function(self._network_params['net_type'])
     self._logger.debug(f'Using {self._network_params["net_type"]}')
     params = deepcopy(self._network_params)
     del params['net_type']
     params['task_from_id'] = self._task_from_id
     params['task_to_id'] = self._task_to_id
     params['tasks_specs'] = self._tasks_specs
     net = net_type(self, self._batch_tf, self._normalizer_tf, **params)
     return net
Example #3
    def __init__(self,
                 inputs,
                 normalizer,
                 nL=(255, 255),
                 activation='tensorflow.nn:tanh',
                 layer_norm=False,
                 scope='mlp'):
        self._sess = tf.get_default_session() or tf.InteractiveSession()

        activation = import_function(activation)

        self._o = inputs['o']
        self._a = inputs['a']
        self._o_next = inputs['o_next'][:, -1]
        self._layer_norm = layer_norm

        o_norm = normalizer['o'].normalize(self._o)
        a_norm = normalizer['a'].normalize(self._a)
        o_next_norm = normalizer['o'].normalize(self._o_next)
        state_diff = o_next_norm - o_norm[:, -1]

        o_norm_flat = tf.reshape(o_norm, (-1, np.prod(o_norm.shape[1:])))
        a_norm_flat = tf.reshape(a_norm, (-1, np.prod(a_norm.shape[1:])))

        l = tf.concat([o_norm_flat, a_norm_flat], axis=1)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            for i, n in enumerate(nL):
                l = tf.layers.dense(l, n, name=f'layer{i}')
                if self._layer_norm:
                    l = tf.contrib.layers.layer_norm(l)
                l = activation(l)

            self._predicted_state_diff = tf.layers.dense(l,
                                                         self._o_next.shape[1],
                                                         name='final_layer')

        self._predicted_state = normalizer['o'].denormalize(
            o_norm[:, -1] + self._predicted_state_diff)
        self._prediction_error = self._o_next - self._predicted_state

        self._loss = tf.losses.mean_squared_error(state_diff,
                                                  self._predicted_state_diff)
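Example #9 below configures this class through a network_params dict. Wiring the two together follows the pattern of Example #1; in this sketch, batch_tf and normalizer_tf are stand-ins for the inputs and normalizer dicts the constructor expects:

params = dict(nL=[100] * 9,
              net_type='forward_model.models:ForwardModelMLPStateDiff',
              activation='tensorflow.nn:tanh',
              layer_norm=False,
              scope='mlp')
net_type = import_function(params.pop('net_type'))
net = net_type(batch_tf, normalizer_tf, **params)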
Example #4

import os
import sys
import tempfile
from utils import sh, sh_str, e, setup_env, objdir, info, import_function, fail

create_aux_files = import_function('create-release-distribution',
                                   'create_aux_files')


def main():
    changelog = e('${CHANGELOG}')
    ssh = e('${UPDATE_USER}@${UPDATE_HOST}')
    sshopts = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
    temp_dest = sh_str(
        "ssh ${ssh} ${sshopts} mktemp -d /tmp/update-${PRODUCT}-XXXXXXXXX")
    temp_changelog = sh_str(
        "ssh ${ssh} ${sshopts} mktemp /tmp/changelog-XXXXXXXXX")

    if not temp_dest or not temp_changelog:
        fail('Failed to create temporary directories on {0}', ssh)

    sh('scp ${sshopts} -r ${BE_ROOT}/release/LATEST/. ${ssh}:${temp_dest}')
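These build scripts use a small shell-flavored DSL from utils: e() expands ${VAR} references, sh() runs a command built the same way, and sh_str() additionally returns the command's output. Note that ${ssh} and ${sshopts} above refer to local variables of main(), so the real helpers apparently consult the caller's locals as well as the environment. A rough stand-in that takes explicit overrides instead:

import os
import re

def e(template, **extra):
    # Illustrative stand-in for utils.e: expand ${VAR} from keyword
    # overrides first, then from the process environment.
    def repl(match):
        name = match.group(1)
        return str(extra.get(name, os.environ.get(name, '')))
    return re.sub(r'\$\{(\w+)\}', repl, template)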
Example #5


import os
from dsl import load_file, load_profile_config
from utils import sh, info, objdir, e, chroot, glob, readfile, setfile, template, sha256, import_function, get_port_names, on_abort


ports = load_file('${BUILD_CONFIG}/ports-installer.pyd', os.environ)
installworld = import_function('build-os', 'installworld')
installkernel = import_function('build-os', 'installkernel')
installworldlog = objdir('logs/iso-installworld')
installkernellog = objdir('logs/iso-installkernel')
distributionlog = objdir('logs/iso-distribution')
sysinstalllog = objdir('logs/iso-sysinstall')
imgfile = objdir('base.ufs')
output = objdir('${NAME}.iso')


purge_dirs = [
    '/bin',
    '/sbin',
    '/usr/bin',
    '/usr/sbin'
]
Example #6

import os
from dsl import load_file, load_profile_config
from utils import sh, info, objdir, e, chroot, glob, readfile, setfile, template, sha256, import_function, get_port_names, on_abort, is_elf

ports = load_file('${BUILD_CONFIG}/ports-installer.pyd', os.environ)
installworld = import_function('build-os', 'installworld')
installkernel = import_function('build-os', 'installkernel')
installworldlog = objdir('logs/iso-installworld')
installkernellog = objdir('logs/iso-installkernel')
distributionlog = objdir('logs/iso-distribution')
sysinstalllog = objdir('logs/iso-sysinstall')
imgfile = objdir('base.ufs')
output = objdir('${NAME}.iso')

purge_dirs = ['/bin', '/sbin', '/usr/bin', '/usr/sbin']

files_to_preserve = [
    '/bin/sleep', '/usr/bin/dialog', '/usr/bin/dirname', '/usr/bin/awk',
    '/usr/bin/cut', '/usr/bin/cmp', '/usr/bin/find', '/usr/bin/grep',
    '/usr/bin/logger', '/usr/bin/mkfifo', '/usr/bin/mktemp', '/usr/bin/sed',
    '/usr/bin/sort', '/usr/bin/scp', '/usr/bin/sftp', '/usr/bin/ssh',
]

Example #7


import os
from dsl import load_file
from utils import e, sh, setup_env, import_function, env


dsl = load_file('${BUILD_CONFIG}/upgrade.pyd', os.environ)
create_aux_files = import_function('create-release-distribution', 'create_aux_files')


def stage_upgrade():
    sh('rm -rf ${UPGRADE_STAGEDIR}')
    sh('mkdir -p ${UPGRADE_STAGEDIR}')
    sh('cp -R ${OBJDIR}/packages/Packages ${UPGRADE_STAGEDIR}/')
    # If RESTART is given, save that
    if env('RESTART'):
        sh('echo ${RESTART} > ${UPGRADE_STAGEDIR}/RESTART')

    # And if REBOOT is given, put that in FORCEREBOOT
    if env('REBOOT'):
        sh('echo ${REBOOT} > ${UPGRADE_STAGEDIR}/FORCEREBOOT')
    sh('rm -f ${BE_ROOT}/release/LATEST')
    sh('ln -sf ${UPGRADE_STAGEDIR} ${BE_ROOT}/release/LATEST')
Example #8
File: ddpg.py Project: s-bl/cwyc
    def __init__(self,
                 env_spec,
                 task_spec,
                 buffer_size,
                 network_params,
                 normalizer_params,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 random_eps,
                 noise_eps,
                 train_steps,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 replay_strategy,
                 replay_k,
                 noise_type,
                 share_experience,
                 noise_adaptation,
                 reuse=False):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to Overcome exploration problem.

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxilliary loss
            q_filter: whether or not a filter on the q value update should be used when training with demonstartions
            num_demo: Number of episodes in to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstrations buffer, per mpi thread
            prm_loss_weight: Weight corresponding to the primary loss
            aux_loss_weight: Weight corresponding to the auxilliary loss also called the cloning loss
        """
        super().__init__(scope)
        self.replay_k = replay_k
        self.replay_strategy = replay_strategy
        self.clip_pos_returns = clip_pos_returns
        self.relative_goals = relative_goals
        self.train_steps = train_steps
        self.noise_eps = noise_eps
        self.random_eps = random_eps
        self.clip_obs = clip_obs
        self.action_l2 = action_l2
        self.max_u = max_u
        self.pi_lr = pi_lr
        self.Q_lr = Q_lr
        self.batch_size = batch_size
        self.normalizer_params = normalizer_params
        self.polyak = polyak
        self.buffer_size = buffer_size
        self._env_spec = env_spec
        self._T = self._env_spec['T']
        self._task_spec = task_spec
        self.network_params = network_params
        self._share_experience = share_experience
        self._noise_adaptation = noise_adaptation

        self._task_spec = deepcopy(task_spec)
        self._task_spec['buffer_size'] = 0
        self._task = Task(**self._task_spec)

        self._gamma = 1. - 1. / self._T
        self.clip_return = (1. / (1. - self._gamma)) if clip_return else np.inf

        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(network_params['net_type'])

        self.input_dims = dict(
            o=self._env_spec['o_dim'],
            a=self._env_spec['a_dim'],
            g=self._task_spec['g_dim'],
        )

        input_shapes = dims_to_shapes(self.input_dims)

        self.dimo = self._env_spec['o_dim']
        self.dimg = self._task_spec['g_dim']
        self.dima = self._env_spec['a_dim']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_next'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        self._action_noise, self._parameter_noise = get_noise_from_string(
            self._env_spec, noise_type)

        # Create network.
        with tf.variable_scope(self._scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        buffer_shapes = dict()
        buffer_shapes['o'] = (self.dimo, )
        buffer_shapes['o_next'] = buffer_shapes['o']
        buffer_shapes['g'] = (self.dimg, )
        buffer_shapes['ag'] = (self.dimg, )
        buffer_shapes['ag_next'] = (self.dimg, )
        buffer_shapes['a'] = (self.dima, )

        self.sample_transitions = make_sample_her_transitions(
            self.replay_strategy, self.replay_k,
            self._task.reward_done_success)

        self._buffer = ReplayBuffer(buffer_shapes, self.buffer_size, self._T,
                                    self.sample_transitions)
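dims_to_shapes is not shown in this listing. The docstring above is adapted from OpenAI baselines' HER implementation, where the helper simply wraps each positive dimension in a shape tuple; a presumably equivalent sketch:

def dims_to_shapes(input_dims):
    # {'o': 10, 'info_x': 0} -> {'o': (10,), 'info_x': ()}
    return {key: (val,) if val > 0 else ()
            for key, val in input_dims.items()}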
Example #9
def basic_configure(**kwargs):
    ##########################
    #     load experiment    #
    ##########################

    experiments_path = []
    if 'experiment' in kwargs and kwargs['experiment'] is not None:
        experiments_kwargs = []
        experiments_path = [os.path.splitext(os.path.basename(kwargs['experiment']))[0]]
        experiment_basedir = os.path.dirname(kwargs['experiment'])
        while True:
            with open(os.path.join(experiment_basedir, experiments_path[-1] + '.json'), 'r') as f:
                experiments_kwargs.append(json.load(f))
            if experiments_kwargs[-1].get('inherit_from') is not None:
                experiments_path.append(experiments_kwargs[-1]['inherit_from'])
                continue
            break
        for experiment_kwargs in reversed(experiments_kwargs):
            update_default_params(kwargs, experiment_kwargs)
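The loop above walks an inheritance chain: each experiment JSON may name a parent via 'inherit_from', and since update_default_params is applied to the parents first (and, as its use later in this function suggests, overwrites existing keys), children override their parents. A hypothetical two-file chain:

# experiments/base.json:  {"inherit_from": null, "seed": 1, "render": false}
# experiments/child.json: {"inherit_from": "base", "render": true}
#
# basic_configure(experiment='experiments/child.json') then ends up
# with seed=1 (from base) and render=True (child overrides base).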

    ##########################
    #     load json string   #
    ##########################

    if 'json_string' in kwargs and kwargs['json_string'] is not None:
        update_default_params(kwargs, json.loads(kwargs['json_string']))

    ##########################
    #     Prepare logging    #
    ##########################

    clean = get_parameter('clean', params=kwargs, default=False)
    jobdir = get_parameter('jobdir', params=kwargs, default=mkdtemp())

    if clean and os.path.exists(jobdir) and not os.path.exists(os.path.join(jobdir, 'restart')):
        rmtree(jobdir)

    os.makedirs(jobdir, exist_ok=True)

    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] <%(levelname)s> %(name)s: %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S %p',
                        handlers=([
                                   logging.FileHandler(os.path.join(jobdir, 'events.log'))] +
                                  [logging.StreamHandler(sys.stdout)])
                        )

    summary_writer = SummaryWriter(os.path.join(jobdir, 'tf_board'))

    if clean: logger.info(f'Cleaned jobdir {jobdir}')

    for experiment_path in reversed(experiments_path):
        logger.info(f'Loaded params from experiment {experiment_path}')

    project_path = os.path.dirname(os.path.realpath(__file__))
    try:
        repo = Repo(project_path, search_parent_directories=True)
        active_branch = repo.active_branch
        latest_commit = repo.commit(active_branch)
        latest_commit_sha = latest_commit.hexsha
        latest_commit_sha_short = repo.git.rev_parse(latest_commit_sha, short=6)
        logger.info(f'We are on branch {active_branch} using commit {latest_commit_sha_short}')
    except InvalidGitRepositoryError:
        logger.warning(f'{project_path} is not a git repo')

    ##########################
    #    Continue training   #
    ##########################

    restart_after = get_parameter('restart_after', params=kwargs, default=None)
    continued_params = {}
    if restart_after is not None and os.path.exists(os.path.join(jobdir, 'restart')):
        with open(os.path.join(jobdir, 'basic_params.json'), 'r') as f:
            continued_params = json.load(f)

    ##########################
    #  Load external config  #
    ##########################

    basic_params_path = get_parameter('basic_params_path', params=continued_params, default=None)
    basic_params_path = get_parameter('basic_params_path', params=kwargs, default=basic_params_path)
    external_params = {}
    if basic_params_path is not None:
        with open(basic_params_path, 'r') as f:
            external_params = json.load(f)

    ##########################
    #        Seeding         #
    ##########################

    seed = get_parameter('seed', params=external_params, default=int(np.random.randint(0, 2**23)))
    seed = get_parameter('seed', params=continued_params, default=seed)
    seed = get_parameter('seed', params=kwargs, default=seed)

    logger.info(f'Using seed {seed}')

    ####################
    #    Prepare env   #
    ####################

    env_spec = get_parameter('env', params=external_params, default=None)
    env_spec = get_parameter('env', params=continued_params, default=env_spec)
    env_spec = get_parameter('env', params=kwargs, default=env_spec)

    env_params = dict()
    update_default_params(env_params, external_params.get('env_params', {}))
    update_default_params(env_params, continued_params.get('env_params', {}))
    update_default_params(env_params, kwargs.get('env_params', {}))

    env_proto = import_function(env_spec)
    tmp_env = env_proto(**env_params)
    obs = tmp_env.reset()
    env_spec = dict(
        o_dim=obs['observation'].shape[0],
        a_dim=tmp_env.action_space.shape[0],
        g_dim=obs['desired_goal'].shape[0],
    )
    if hasattr(tmp_env, 'goal_min'): env_spec['goal_min'] = tmp_env.goal_min
    if hasattr(tmp_env, 'goal_max'): env_spec['goal_max'] = tmp_env.goal_max
    update_default_params(env_spec, external_params.get('env_spec', {}))
    update_default_params(env_spec, continued_params.get('env_spec', {}))
    update_default_params(env_spec, kwargs.get('env_spec', {}))

    T = get_parameter('T', params=env_spec, default=800)


    env_fn = (env_proto, env_params)

    ####################
    #   Prepare tasks  #
    ####################

    tasks_specs = []
    update_default_params(tasks_specs, external_params.get('tasks_specs', {}))
    update_default_params(tasks_specs, continued_params.get('tasks_specs', {}))
    update_default_params(tasks_specs, kwargs.get('tasks_specs', {}))

    tasks_specs = [task_spec for task_spec in tasks_specs if task_spec.get('active', True)]

    tasks_fn = []
    for task_spec in tasks_specs:
        if 'active' in task_spec: del task_spec['active']
        task_spec['id'] = len(tasks_fn)
        task_spec['scope'] = f'Task{task_spec["id"]}'
        tasks_fn.append((Task, task_spec))

    ####################
    # Prepare policies #
    ####################

    policy_params = dict()
    update_default_params(policy_params, external_params.get('policy_params', {}))
    update_default_params(policy_params, continued_params.get('policy_params', {}))
    update_default_params(policy_params, kwargs.get('policy_params', {}))

    assert 'policy_type' in policy_params

    policy_proto = import_function(policy_params['policy_type'])

    policies_fn = []
    policies_params = []
    for task_spec in tasks_specs:
        params = deepcopy(policy_params)
        params['env_spec'] = env_spec
        params['task_spec'] = task_spec
        params['scope'] = f'policy_{task_spec["id"]}'
        del params['policy_type']
        policies_params.append(params)
        policies_fn.append((policy_proto, params))

    #########################
    # Prepare task selector #
    #########################

    task_selector_params = dict(
        tasks_specs=tasks_specs,
        surprise_weighting=0.1,
        buffer_size=100,
        lr=0.1,
        reg=1e-3,
        precision=1e-3,
        eps_greedy_prob=0.05,
        surprise_hist_weighting=.99,
        scope='taskSelector',
        fixed_Q=None,
        epsilon=0.1,
    )
    update_default_params(task_selector_params, external_params.get('task_selector_params', {}))
    update_default_params(task_selector_params, continued_params.get('task_selector_params', {}))
    update_default_params(task_selector_params, kwargs.get('task_selector_params', {}))

    task_selector_fn = (TaskSelector, task_selector_params)

    #########################
    # Prepare task planner  #
    #########################

    task_planner_params = dict(
        env_specs=env_spec,
        tasks_specs=tasks_specs,
        surprise_weighting=0.001,
        surprise_hist_weighting=.99,
        buffer_size=100,
        eps_greedy_prob=0.05,
        max_seq_length=10,
        scope='taskPlanner',
        fixed_Q=None,
        epsilon=0.0001,
    )
    update_default_params(task_planner_params, external_params.get('task_planner_params', {}))
    update_default_params(task_planner_params, continued_params.get('task_planner_params', {}))
    update_default_params(task_planner_params, kwargs.get('task_planner_params', {}))

    task_planner_fn = (TaskPlanner, task_planner_params)

    #########################
    #     Prepare gnets     #
    #########################

    gnet_params = dict(
        env_spec=env_spec,
        tasks_specs=tasks_specs,
        pos_buffer_size=int(1e3),
        neg_buffer_size=int(1e5),
        batch_size=64,
        learning_rate=1e-4,
        train_steps=100,
        only_fst_surprising_singal=True,
        only_pos_rollouts=False,
        normalize=False,
        normalizer_params=dict(
            eps=0.01,
            default_clip_range=5
        ),
        coords_gen_params=dict(
            buffer_size=int(1e5),
        ),
        reset_model_below_n_pos_samples=20,
        use_switching_reward=True,
    )
    update_default_params(gnet_params, external_params.get('gnet_params', {}))
    update_default_params(gnet_params, continued_params.get('gnet_params', {}))
    update_default_params(gnet_params, kwargs.get('gnet_params', {}))

    assert 'network_params' in gnet_params

    gnets_fn = []
    gnets_params = []
    for i in range(len(tasks_specs)):
        gnets_fn.append([])
        gnets_params.append([])
        for j in range(len(tasks_specs)):
            params = deepcopy(gnet_params)
            params['task_from_id'] = i
            params['task_to_id'] = j
            params['scope'] = f'gnet_{i}_to_{j}'
            gnets_params[-1].append(params)
            gnets_fn[-1].append((Gnet, params))

    #########################
    # Prepare forward model #
    #########################

    forward_model_params = dict(
        env_spec=env_spec,
        buffer_size=int(1e6),
        lr=1e-4,
        hist_length=1,
        batch_size=64,
        network_params=dict(
            nL=[100]*9,
            net_type='forward_model.models:ForwardModelMLPStateDiff',
            activation='tensorflow.nn:tanh',
            layer_norm=False,
            scope='mlp'
        ),
        normalizer_params=None,
        train_steps=100,
        scope='forwardModel'
    )
    update_default_params(forward_model_params, external_params.get('forward_model_params', {}))
    update_default_params(forward_model_params, continued_params.get('forward_model_params', {}))
    update_default_params(forward_model_params, kwargs.get('forward_model_params', {}))

    forward_model_fn = (ForwardModel, forward_model_params)

    #########################
    # Prepare RolloutWorker #
    #########################

    rollout_worker_params = dict(
        surprise_std_scaling=3,
        discard_modules_buffer=True,
        seed=seed,
        forward_model_burnin_eps=50,
        resample_goal_every=5,
    )
    update_default_params(rollout_worker_params, external_params.get('rollout_worker_params', {}))
    update_default_params(rollout_worker_params, continued_params.get('rollout_worker_params', {}))
    update_default_params(rollout_worker_params, kwargs.get('rollout_worker_params', {}))

    rollout_worker_fn = (RolloutWorker, rollout_worker_params)

    #########################
    # Write params to file  #
    #########################

    inherit_from = get_parameter('inherit_from', params=external_params, default=None)
    inherit_from = get_parameter('inherit_from', params=continued_params, default=inherit_from)
    inherit_from = get_parameter('inherit_from', params=kwargs, default=inherit_from)

    params_path = get_parameter('params_path', params=external_params, default=None)
    params_path = get_parameter('params_path', params=continued_params, default=params_path)
    params_path = get_parameter('params_path', params=kwargs, default=params_path)

    params_prefix = get_parameter('params_prefix', params=external_params, default=None)
    params_prefix = get_parameter('params_prefix', params=continued_params, default=params_prefix)
    params_prefix = get_parameter('params_prefix', params=kwargs, default=params_prefix)

    max_env_steps = get_parameter('max_env_steps', params=external_params, default=None)
    max_env_steps = get_parameter('max_env_steps', params=continued_params, default=max_env_steps)
    max_env_steps = get_parameter('max_env_steps', params=kwargs, default=max_env_steps)

    render = get_parameter('render', params=external_params, default=None)
    render = get_parameter('render', params=continued_params, default=render)
    render = get_parameter('render', params=kwargs, default=render)

    num_worker = get_parameter('num_worker', params=external_params, default=None)
    num_worker = get_parameter('num_worker', params=continued_params, default=num_worker)
    num_worker = get_parameter('num_worker', params=kwargs, default=num_worker)

    eval_runs = get_parameter('eval_runs', params=external_params, default=None)
    eval_runs = get_parameter('eval_runs', params=continued_params, default=eval_runs)
    eval_runs = get_parameter('eval_runs', params=kwargs, default=eval_runs)

    env = get_parameter('env', params=external_params, default=None)
    env = get_parameter('env', params=continued_params, default=env)
    env = get_parameter('env', params=kwargs, default=env)

    restart_after = get_parameter('restart_after', params=kwargs, default=restart_after)

    json_string = get_parameter('json_string', params=external_params, default=None)
    json_string = get_parameter('json_string', params=continued_params, default=json_string)
    json_string = get_parameter('json_string', params=kwargs, default=json_string)

    experiment = get_parameter('experiment', params=external_params, default=None)
    experiment = get_parameter('experiment', params=continued_params, default=experiment)
    experiment = get_parameter('experiment', params=kwargs, default=experiment)

    store_params_every = get_parameter('store_params_every', params=external_params, default=None)
    store_params_every = get_parameter('store_params_every', params=continued_params, default=store_params_every)
    store_params_every = get_parameter('store_params_every', params=kwargs, default=store_params_every)

    params_cache_size = get_parameter('params_cache_size', params=external_params, default=None)
    params_cache_size = get_parameter('params_cache_size', params=continued_params, default=params_cache_size)
    params_cache_size = get_parameter('params_cache_size', params=kwargs, default=params_cache_size)

    use_surprise = get_parameter('use_surprise', params=external_params, default=True)
    use_surprise = get_parameter('use_surprise', params=continued_params, default=use_surprise)
    use_surprise = get_parameter('use_surprise', params=kwargs, default=use_surprise)
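Every setting above repeats the same three-level lookup, with later sources taking priority: external_params, then continued_params, then the call's kwargs. A small helper would state the pattern once; a sketch using only names already in scope:

def cascade(name, default=None):
    # kwargs override continued_params, which override external_params.
    value = get_parameter(name, params=external_params, default=default)
    value = get_parameter(name, params=continued_params, default=value)
    return get_parameter(name, params=kwargs, default=value)

# e.g. render = cascade('render')
#      use_surprise = cascade('use_surprise', default=True)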

    params = dict(
        inherit_from=inherit_from,
        basic_params_path=basic_params_path,
        params_path=params_path,
        params_prefix=params_prefix,
        store_params_every=store_params_every,
        params_cache_size=params_cache_size,
        use_surprise=use_surprise,
        seed=seed,
        clean=clean,
        jobdir=jobdir,
        max_env_steps=max_env_steps,
        render=render,
        num_worker=num_worker,
        eval_runs=eval_runs,
        env=env,
        restart_after=restart_after,
        json_string=json_string,
        experiment=experiment,
        env_spec=env_spec,
        env_params=env_params,
        tasks_specs=tasks_specs,
        policy_params=policy_params,
        policies_params=policies_params,
        task_selector_params=task_selector_params,
        task_planner_params=task_planner_params,
        gnet_params=gnet_params,
        gnets_params=gnets_params,
        forward_model_params=forward_model_params,
        rollout_worker_params=rollout_worker_params,
    )

    assert np.all([k in params for k in kwargs.keys()]), [k for k in kwargs.keys() if not k in params]

    with open(os.path.join(jobdir, 'basic_params.json'), 'w') as f:
        json.dump(params, f)

    return params, {'env_fn': env_fn, 'tasks_fn': tasks_fn, 'policies_fn': policies_fn,
            'gnets_fn': gnets_fn, 'task_selector_fn': task_selector_fn, 'task_planner_fn': task_planner_fn,
            'forward_model_fn': forward_model_fn, 'rollout_worker_fn': rollout_worker_fn}, summary_writer
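A hypothetical invocation, with the experiment path and job directory as stand-ins; the returned env_fn tuple is the (class, kwargs) pair built above:

params, fns, summary_writer = basic_configure(
    experiment='experiments/child.json',
    jobdir='/tmp/cwyc-job',
    seed=123,
)
env_proto, env_params = fns['env_fn']
env = env_proto(**env_params)  # construct the training environment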