Code example #1
File: test_dict_space.py Project: Mee321/HAPG_exp
    def test_dict_space(self):
        ext.set_seed(0)

        # A dummy dict env
        dummy_env = DummyDictEnv()
        dummy_act = dummy_env.action_space
        dummy_act_sample = dummy_act.sample()

        # A dummy dict env wrapped by garage.tf
        tf_env = TfEnv(dummy_env)
        tf_act = tf_env.action_space
        tf_obs = tf_env.observation_space

        # flat_dim
        assert tf_act.flat_dim == tf_act.flatten(dummy_act_sample).shape[-1]

        # flat_dim_with_keys
        assert tf_obs.flat_dim == tf_obs.flat_dim_with_keys(
            iter(["achieved_goal", "desired_goal", "observation"]))

        # un/flatten
        assert tf_act.unflatten(
            tf_act.flatten(dummy_act_sample)) == dummy_act_sample

        # un/flatten_n
        samples = [dummy_act.sample() for _ in range(10)]
        assert tf_act.unflatten_n(tf_act.flatten_n(samples)) == samples

        # un/flatten_with_keys
        assert tf_act.unflatten_with_keys(
            tf_act.flatten_with_keys(dummy_act_sample, iter(["action"])),
            iter(["action"]))
Code example #2
File: test_benchmark_ddpg.py Project: gntoni/garage
def run_garage(env, seed, log_dir):
    """
    Create garage model and training.

    Replace the ddpg with the algorithm you want to run.

    :param env: Environment of the task.
    :param seed: Random seed for the trail.
    :param log_dir: Log dir path.
    :return:
    """
    ext.set_seed(seed)

    with tf.Graph().as_default():
        # Set up params for ddpg
        action_noise = OUStrategy(env, sigma=params["sigma"])

        actor_net = ContinuousMLPPolicy(
            env_spec=env,
            name="Actor",
            hidden_sizes=params["actor_hidden_sizes"],
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.tanh)

        critic_net = ContinuousMLPQFunction(
            env_spec=env,
            name="Critic",
            hidden_sizes=params["critic_hidden_sizes"],
            hidden_nonlinearity=tf.nn.relu)

        ddpg = DDPG(env,
                    actor=actor_net,
                    critic=critic_net,
                    actor_lr=params["actor_lr"],
                    critic_lr=params["critic_lr"],
                    plot=False,
                    target_update_tau=params["tau"],
                    n_epochs=params["n_epochs"],
                    n_epoch_cycles=params["n_epoch_cycles"],
                    n_rollout_steps=params["n_rollout_steps"],
                    n_train_steps=params["n_train_steps"],
                    discount=params["discount"],
                    replay_buffer_size=params["replay_buffer_size"],
                    min_buffer_size=int(1e4),
                    exploration_strategy=action_noise,
                    actor_optimizer=tf.train.AdamOptimizer,
                    critic_optimizer=tf.train.AdamOptimizer)

        # Set up logger since we are not using run_experiment
        tabular_log_file = osp.join(log_dir, "progress.csv")
        tensorboard_log_dir = osp.join(log_dir, "progress")
        garage_logger.add_tabular_output(tabular_log_file)
        garage_logger.set_tensorboard_dir(tensorboard_log_dir)

        ddpg.train()

        garage_logger.remove_tabular_output(tabular_log_file)

        return tabular_log_file
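A benchmark function like run_garage above is usually called from a small driver that fixes the environment and sweeps over seeds. A minimal sketch, assuming a Gym task and throwaway log directories; the environment id and the number of trials are assumptions, not part of the source:

import tempfile

import gym

env = gym.make("Pendulum-v0")
for seed in range(3):
    # One log directory per trial; run_garage returns the CSV it wrote there.
    log_dir = tempfile.mkdtemp(prefix="garage_ddpg_seed%d_" % seed)
    csv_path = run_garage(env, seed, log_dir)
    print("progress written to", csv_path)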
Code example #3
def run_garage(env, seed, log_dir):
    """
    Create garage model and training.

    Replace the trpo with the algorithm you want to run.

    :param env: Environment of the task.
    :param seed: Random seed for the trail.
    :param log_dir: Log dir path.
    :return:import baselines.common.tf_util as U
    """
    ext.set_seed(seed)

    with tf.Graph().as_default():
        env = TfEnv(normalize(env))

        policy = GaussianMLPPolicy(
            name="policy",
            env_spec=env.spec,
            hidden_sizes=(32, 32),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None,
        )

        baseline = GaussianMLPBaseline(
            env_spec=env.spec,
            regressor_args=dict(
                hidden_sizes=(32, 32),
                use_trust_region=True,
            ),
        )

        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=1024,
            max_path_length=100,
            n_itr=976,
            discount=0.99,
            gae_lambda=0.98,
            clip_range=0.1,
            policy_ent_coeff=0.0,
            plot=False,
        )

        # Set up logger since we are not using run_experiment
        tabular_log_file = osp.join(log_dir, "progress.csv")
        garage_logger.add_tabular_output(tabular_log_file)
        garage_logger.set_tensorboard_dir(log_dir)

        algo.train()

        garage_logger.remove_tabular_output(tabular_log_file)

        return tabular_log_file
Code example #4
def _worker_set_seed(_, seed):
    logger.log("Setting seed to %d" % seed)
    ext.set_seed(seed)
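Worker-seeding helpers like _worker_set_seed are typically dispatched once per sampler process so that each worker draws a distinct but reproducible seed. A minimal sketch of that pattern with plain multiprocessing; the helper name, the offset-by-worker-id scheme, and the pool size are assumptions, not garage's actual implementation:

import multiprocessing as mp
import random


def _seed_worker(worker_id, base_seed):
    # Hypothetical helper: each worker derives its own seed so rollouts
    # differ across workers but stay reproducible for a given base_seed.
    random.seed(base_seed + worker_id)
    print("worker %d seeded with %d" % (worker_id, base_seed + worker_id))


if __name__ == "__main__":
    with mp.Pool(processes=4) as pool:
        pool.starmap(_seed_worker, [(i, 1) for i in range(4)])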
Code example #5
def run_garage(env, seed, log_dir):
    """
    Create garage model and training.

    Replace the ddpg with the algorithm you want to run.

    :param env: Environment of the task.
    :param seed: Random seed for the trail.
    :param log_dir: Log dir path.
    :return:
    """
    ext.set_seed(seed)

    with tf.Graph().as_default():
        env = TfEnv(env)
        # Set up params for ddpg
        action_noise = OUStrategy(env.spec, sigma=params["sigma"])

        policy = ContinuousMLPPolicy(
            env_spec=env.spec,
            name="Policy",
            hidden_sizes=params["policy_hidden_sizes"],
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.tanh)

        qf = ContinuousMLPQFunction(
            env_spec=env.spec,
            name="QFunction",
            hidden_sizes=params["qf_hidden_sizes"],
            hidden_nonlinearity=tf.nn.relu)

        replay_buffer = SimpleReplayBuffer(
            env_spec=env.spec,
            size_in_transitions=params["replay_buffer_size"],
            time_horizon=params["n_rollout_steps"])

        ddpg = DDPG(
            env,
            policy=policy,
            qf=qf,
            replay_buffer=replay_buffer,
            policy_lr=params["policy_lr"],
            qf_lr=params["qf_lr"],
            plot=False,
            target_update_tau=params["tau"],
            n_epochs=params["n_epochs"],
            n_epoch_cycles=params["n_epoch_cycles"],
            max_path_length=params["n_rollout_steps"],
            n_train_steps=params["n_train_steps"],
            discount=params["discount"],
            min_buffer_size=int(1e4),
            exploration_strategy=action_noise,
            policy_optimizer=tf.train.AdamOptimizer,
            qf_optimizer=tf.train.AdamOptimizer)

        # Set up logger since we are not using run_experiment
        tabular_log_file = osp.join(log_dir, "progress.csv")
        tensorboard_log_dir = osp.join(log_dir, "progress")
        garage_logger.add_tabular_output(tabular_log_file)
        garage_logger.set_tensorboard_dir(tensorboard_log_dir)

        ddpg.train()

        garage_logger.remove_tabular_output(tabular_log_file)

        return tabular_log_file
Code example #6
File: run_experiment.py Project: vincentzhang/garage
def run_experiment(argv):
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=("Number of parallel workers to perform rollouts. "
              "0 => don't start any workers"))
    parser.add_argument(
        '--exp_name',
        type=str,
        default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument(
        '--log_dir',
        type=str,
        default=None,
        help='Path to save the log and iteration snapshot.')
    parser.add_argument(
        '--snapshot_mode',
        type=str,
        default='all',
        help='Mode to save the snapshot. Can be either "all" '
        '(all iterations will be saved), "last" (only '
        'the last iteration will be saved), "gap" (every '
        '`snapshot_gap` iterations are saved), or "none" '
        '(do not save snapshots)')
    parser.add_argument(
        '--snapshot_gap',
        type=int,
        default=1,
        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--tabular_log_file',
        type=str,
        default='progress.csv',
        help='Name of the tabular log file (in csv).')
    parser.add_argument(
        '--text_log_file',
        type=str,
        default='debug.log',
        help='Name of the text log file (in pure text).')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=("Name of the step key in tensorboard_summary."))
    parser.add_argument(
        '--params_log_file',
        type=str,
        default='params.json',
        help='Name of the parameter log file (in json).')
    parser.add_argument(
        '--variant_log_file',
        type=str,
        default='variant.json',
        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument(
        '--plot',
        type=ast.literal_eval,
        default=False,
        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Print only the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument(
        '--args_data', type=str, help='Pickled data for objects')
    parser.add_argument(
        '--variant_data',
        type=str,
        help='Pickled data for variant configuration')
    parser.add_argument(
        '--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    # SIGINT is blocked for all processes created in parallel_sampler to avoid
    # the creation of sleeping and zombie processes.
    #
    # If the user interrupts run_experiment, there's a chance some processes
    # won't die due to a deadlock in which one of the children in the
    # parallel sampler exits without releasing a lock after it catches SIGINT.
    #
    # The parent later tries to acquire the same lock to proceed with its
    # cleanup, but it sleeps forever waiting for the lock to be released.
    # Meanwhile, all the processes in the parallel sampler remain zombies
    # because the parent cannot proceed with their cleanup.
    with mask_signals([signal.SIGINT]):
        if args.n_parallel > 0:
            parallel_sampler.initialize(n_parallel=args.n_parallel)
            if args.seed is not None:
                parallel_sampler.set_seed(args.seed)

    if not args.plot:
        garage.plotter.Plotter.disable()
        garage.tf.plotter.Plotter.disable()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(log_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            try:
                method_call(variant_data)
            except BaseException:
                children = garage.plotter.Plotter.get_plotters()
                children += garage.tf.plotter.Plotter.get_plotters()
                if args.n_parallel > 0:
                    children += [parallel_sampler]
                child_proc_shutdown(children)
                raise
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Code example #7
def run_experiment(argv):
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=("Number of parallel workers to perform rollouts. "
              "0 => don't start any workers"))
    parser.add_argument(
        '--exp_name',
        type=str,
        default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument(
        '--log_dir',
        type=str,
        default=None,
        help='Path to save the log and iteration snapshot.')
    parser.add_argument(
        '--snapshot_mode',
        type=str,
        default='all',
        help='Mode to save the snapshot. Can be either "all" '
        '(all iterations will be saved), "last" (only '
        'the last iteration will be saved), "gap" (every '
        '`snapshot_gap` iterations are saved), or "none" '
        '(do not save snapshots)')
    parser.add_argument(
        '--snapshot_gap',
        type=int,
        default=1,
        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--tabular_log_file',
        type=str,
        default='progress.csv',
        help='Name of the tabular log file (in csv).')
    parser.add_argument(
        '--text_log_file',
        type=str,
        default='debug.log',
        help='Name of the text log file (in pure text).')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=("Name of the step key in tensorboard_summary."))
    parser.add_argument(
        '--params_log_file',
        type=str,
        default='params.json',
        help='Name of the parameter log file (in json).')
    parser.add_argument(
        '--variant_log_file',
        type=str,
        default='variant.json',
        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument(
        '--plot',
        type=ast.literal_eval,
        default=False,
        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Print only the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument(
        '--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument(
        '--variant_data',
        type=str,
        help='Pickled data for variant configuration')
    parser.add_argument(
        '--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    assert (os.environ.get("JOBLIB_START_METHOD", None) == "forkserver")
    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from garage.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if not args.plot:
        garage.plotter.Plotter.disable()
        garage.tf.plotter.Plotter.disable()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(log_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            try:
                method_call(variant_data)
            except BaseException:
                if args.n_parallel > 0:
                    parallel_sampler.terminate()
                raise
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Code example #8
    def setUp(self):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        self.sess.__enter__()
        logger.reset()
        ext.set_seed(1)
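A setUp like this one is normally paired with a tearDown that leaves and closes the session it entered. A minimal sketch, assuming the same fixture attributes (the original excerpt does not show the project's tearDown):

    def tearDown(self):
        # Exit the session context entered in setUp and release its resources.
        self.sess.__exit__(None, None, None)
        self.sess.close()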
Code example #9
File: example_swimmer.py Project: stjordanis/HAPG
from garage.baselines import LinearFeatureBaseline
from garage.theano.baselines import GaussianMLPBaseline
from garage.baselines import ZeroBaseline
from garage.envs import normalize
from garage.envs.box2d import CartpoleEnv
from garage.envs.mujoco import SwimmerEnv
from garage.theano.algos.capg import CAPG
from garage.theano.envs import TheanoEnv
from garage.theano.policies import GaussianMLPPolicy
from garage.misc.instrument import run_experiment
from garage.misc.ext import set_seed
import numpy as np

seed = np.random.randint(1, 10000)
set_seed(seed)
env_name = "Swimmer"
hidden_sizes = (32, 32)
env = TheanoEnv(normalize(SwimmerEnv()))
policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=hidden_sizes)
backup_policy = GaussianMLPPolicy(env.spec, hidden_sizes=hidden_sizes)
mix_policy = GaussianMLPPolicy(env.spec, hidden_sizes=hidden_sizes)
pos_eps_policy = GaussianMLPPolicy(env.spec, hidden_sizes=hidden_sizes)
neg_eps_policy = GaussianMLPPolicy(env.spec, hidden_sizes=hidden_sizes)

baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = CAPG(
    env=env,
    policy=policy,
    backup_policy=backup_policy,
Code example #10
def run_task(plot=True, *_):
    set_seed(0)

    # Environment (train on light point mass)
    from sandbox.embed2learn.envs import point_mass_env
    from dm_control import suite
    suite._DOMAINS["embed_point_mass"] = point_mass_env
    env = TfEnv(
        normalize(
            MultiTaskEnv(
                task_env_cls=DmControlEnv,
                task_args=[["embed_point_mass", "light"]] * len(TASK_NAMES),
                task_kwargs=TASK_KWARGS)))

    # Latent space and embedding specs
    # TODO(gh/10): this should probably be done in Embedding or Algo
    latent_lb = np.zeros(LATENT_LENGTH, )
    latent_ub = np.ones(LATENT_LENGTH, )
    latent_space = Box(latent_lb, latent_ub)

    # trajectory space is (TRAJ_ENC_WINDOW, act_obs) where act_obs is a stacked
    # vector of flattened actions and observations
    act_lb, act_ub = env.action_space.bounds
    act_lb_flat = env.action_space.flatten(act_lb)
    act_ub_flat = env.action_space.flatten(act_ub)
    obs_lb, obs_ub = env.observation_space.bounds
    obs_lb_flat = env.observation_space.flatten(obs_lb)
    obs_ub_flat = env.observation_space.flatten(obs_ub)
    #act_obs_lb = np.concatenate([act_lb_flat, obs_lb_flat])
    #act_obs_ub = np.concatenate([act_ub_flat, obs_ub_flat])
    act_obs_lb = obs_lb_flat
    act_obs_ub = obs_ub_flat
    traj_lb = np.stack([act_obs_lb] * TRAJ_ENC_WINDOW)
    traj_ub = np.stack([act_obs_ub] * TRAJ_ENC_WINDOW)
    traj_space = Box(traj_lb, traj_ub)

    task_embed_spec = EmbeddingSpec(env.task_space, latent_space)
    traj_embed_spec = EmbeddingSpec(traj_space, latent_space)
    task_obs_space = concat_spaces(env.task_space, env.observation_space)
    env_spec_embed = EnvSpec(task_obs_space, env.action_space)

    # Embeddings
    task_embedding = GaussianMLPEmbedding(
        name="task_embedding",
        embedding_spec=task_embed_spec,
        hidden_sizes=(20, 20),
        adaptive_std=True,
        init_std=0.5,  # TODO was 100
        max_std=0.6,  # TODO find appropriate value
    )

    # TODO(): rename to inference_network
    traj_embedding = GaussianMLPEmbedding(
        name="traj_embedding",
        embedding_spec=traj_embed_spec,
        hidden_sizes=(20, 10),  # was the same size as policy in Karol's paper
        # adaptive_std=True,  # Must be True for embedding learning
        std_share_network=True,
        init_std=0.001,
    )

    # Multitask policy
    policy = GaussianMLPMultitaskPolicy(
        name="policy",
        env_spec=env.spec,
        task_space=env.task_space,
        embedding=task_embedding,
        hidden_sizes=(20, 10),
        adaptive_std=True,  # Must be True for embedding learning
        init_std=0.5,  # TODO was 100
    )

    baseline = LinearFeatureBaseline(env_spec=env_spec_embed)

    algo = TRPOTaskEmbedding(
        env=env,
        policy=policy,
        baseline=baseline,
        inference=traj_embedding,
        batch_size=4000,
        max_path_length=100,
        n_itr=500,
        discount=0.99,
        step_size=0.01,
        plot=plot,
        plot_warmup_itrs=20,
        policy_ent_coeff=0.0,  # 0.001,  #0.1,
        embedding_ent_coeff=0.0,  #0.1,
        inference_ce_ent_coeff=0.,  # 0.03,  #0.1,  # 0.1,
    )
    algo.train()
Code example #11
def run_task(*_):
    set_seed(1)

    # Environment
    env = TfEnv(
        MultiTaskEnv(task_env_cls=PointEnv,
                     task_args=TASK_ARGS,
                     task_kwargs=TASK_KWARGS))

    # Latent space and embedding specs
    # TODO(gh/10): this should probably be done in Embedding or Algo
    latent_lb = np.zeros(LATENT_LENGTH, )
    latent_ub = np.ones(LATENT_LENGTH, )
    latent_space = Box(latent_lb, latent_ub)

    # trajectory space is (TRAJ_ENC_WINDOW, act_obs) where act_obs is a stacked
    # vector of flattened actions and observations
    act_lb, act_ub = env.action_space.bounds
    act_lb_flat = env.action_space.flatten(act_lb)
    act_ub_flat = env.action_space.flatten(act_ub)
    obs_lb, obs_ub = env.observation_space.bounds
    obs_lb_flat = env.observation_space.flatten(obs_lb)
    obs_ub_flat = env.observation_space.flatten(obs_ub)
    # act_obs_lb = np.concatenate([act_lb_flat, obs_lb_flat])
    # act_obs_ub = np.concatenate([act_ub_flat, obs_ub_flat])
    act_obs_lb = obs_lb_flat
    act_obs_ub = obs_ub_flat
    # act_obs_lb = act_lb_flat
    # act_obs_ub = act_ub_flat
    traj_lb = np.stack([act_obs_lb] * TRAJ_ENC_WINDOW)
    traj_ub = np.stack([act_obs_ub] * TRAJ_ENC_WINDOW)
    traj_space = Box(traj_lb, traj_ub)

    task_embed_spec = EmbeddingSpec(env.task_space, latent_space)
    traj_embed_spec = EmbeddingSpec(traj_space, latent_space)
    task_obs_space = concat_spaces(env.task_space, env.observation_space)
    env_spec_embed = EnvSpec(task_obs_space, env.action_space)

    # Embeddings
    task_embedding = GaussianMLPEmbedding(
        name="embedding",
        embedding_spec=task_embed_spec,
        hidden_sizes=(20, 20),
        std_share_network=True,
        init_std=3.0,  # 2.0
    )

    # TODO(): rename to inference_network
    traj_embedding = GaussianMLPEmbedding(
        name="inference",
        embedding_spec=traj_embed_spec,
        hidden_sizes=(20, 10),  # was the same size as policy in Karol's paper
        std_share_network=True,
    )

    # Multitask policy
    policy = GaussianMLPMultitaskPolicy(
        name="policy",
        env_spec=env.spec,
        task_space=env.task_space,
        embedding=task_embedding,
        hidden_sizes=(20, 10),
        std_share_network=True,  # Must be True for embedding learning
        init_std=6.0,  # 4.5 6.0
    )

    baseline = MultiTaskLinearFeatureBaseline(env_spec=env_spec_embed)

    algo = TRPOTaskEmbedding(
        env=env,
        policy=policy,
        baseline=baseline,
        inference=traj_embedding,
        batch_size=20000,
        max_path_length=50,
        n_itr=1000,
        discount=0.99,
        step_size=0.2,
        plot=False,
        policy_ent_coeff=1e-7,  # 1e-7
        embedding_ent_coeff=1e-3,  # 1e-3
        inference_ce_coeff=1e-7,  # 1e-7
        # kl_constraint=KLConstraint.SOFT,
        # optimizer_args=dict(max_penalty=1e9),
    )
    algo.train()
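run_task functions like the two above are normally handed to garage's launcher, which pickles the callable, re-executes it in a fresh process, and applies the seed and logging options parsed by the run_experiment entry points shown earlier. A minimal sketch of such a call; the keyword names and values below are assumptions based on the import in the Swimmer example, not a verbatim invocation from the source:

from garage.misc.instrument import run_experiment

run_experiment(
    run_task,
    exp_prefix="point_embed",  # hypothetical experiment name
    n_parallel=1,
    seed=1,
    snapshot_mode="last",
    plot=False,
)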