Exemple #1
0
    def __init__(self, name, env, ob_env_name, primitives, config):
        """Store hyper-parameters and the observation layout for the policy.

        Args:
            name: identifier stored on the instance as ``self.name``.
            env: environment object, kept as ``self._env``.
            ob_env_name: name handed to ``make_env`` to create a temporary
                environment whose ``ob_shape`` / ``ob_type`` are read and
                which is closed again before this constructor returns.
            primitives: collection of primitive names; kept as
                ``self.primitive_names`` and its length as
                ``self.num_primitives``.
            config: object providing the ``meta_*`` settings read below.
        """
        # args
        self.name = name
        self._config = config

        # network hyper-parameters
        self._hid_size = config.meta_hid_size
        self._num_hid_layers = config.meta_num_hid_layers
        self._activation = ops.activation(config.meta_activation)

        # observation layout, read from a throwaway environment
        ob_env = make_env(ob_env_name, config)
        ob_shape = ob_env.ob_shape
        ob_type = sorted(ob_env.ob_type)
        if 'acc' in ob_shape:
            # the 'acc' observation is never used here; note that this
            # mutates the mapping object returned by the environment
            ob_shape.pop('acc')
            ob_type.remove('acc')
        ob_env.close()
        self._ob_shape = ob_shape
        self.ob_type = ob_type

        self._env = env
        # total observation size: sum of the element counts of every entry
        flat_sizes = [np.prod(shape) for shape in ob_shape.values()]
        self._ob_space = np.sum(flat_sizes)
        self.primitive_names = primitives
        self.num_primitives = len(primitives)

        # the network is only built when the oracle flag is off
        if not config.meta_oracle:
            self._build()
Exemple #2
0
    def __init__(self, name, env, ob_env_name, config):
        """Store hyper-parameters and spaces, then build the network.

        Args:
            name: identifier stored as ``self.name`` and used as the
                TensorFlow variable scope for ``self._build()``.
            env: environment object, kept as ``self._env``.
            ob_env_name: name handed to ``make_env`` to create a temporary
                environment from which the observation layout and the
                action space are read; it is closed afterwards.
            config: object providing the ``rl_*`` settings and
                ``primitive_include_acc`` read below.
        """
        # args
        self.name = name

        # network hyper-parameters
        self._hid_size = config.rl_hid_size
        self._num_hid_layers = config.rl_num_hid_layers
        self._gaussian_fixed_var = config.rl_fixed_var
        self._activation = ops.activation(config.rl_activation)
        self._include_acc = config.primitive_include_acc

        # observation layout, read from a throwaway environment
        # NOTE(review): unlike the other constructors in this file, `config`
        # is not forwarded to make_env here -- confirm this is intended.
        ob_env = make_env(ob_env_name)
        ob_shape = ob_env.ob_shape
        ob_type = sorted(ob_env.ob_type)
        drop_acc = not self._include_acc and 'acc' in ob_type
        if drop_acc:
            # this mutates the mapping object returned by the environment
            ob_shape.pop('acc')
            ob_type.remove('acc')
        self._ob_shape = ob_shape
        self.ob_type = ob_type

        self._env = env
        # total observation size: sum of the element counts of every entry
        flat_sizes = [np.prod(shape) for shape in ob_shape.values()]
        self._ob_space = np.sum(flat_sizes)
        self._ac_space = ob_env.action_space
        ob_env.close()

        with tf.variable_scope(self.name):
            self._scope = tf.get_variable_scope().name
            self._build()
Exemple #3
0
    def __init__(self,
                 name,
                 env,
                 ob_env_name,
                 num_primitives,
                 trans_term_activation='softmax',
                 config=None):
        """Store hyper-parameters and spaces, then build the network.

        Args:
            name: identifier stored as ``self.name`` and used as the
                TensorFlow variable scope; the text before the first ``'.'``
                is kept as ``self.env_name``.
            env: environment object, kept as ``self._env``; its
                ``action_space`` becomes ``self._ac_space``.
            ob_env_name: name handed to ``make_env`` to create the
                environment the observation layout is read from. That
                environment is kept as ``self.primitive_env`` (not closed).
            num_primitives: stored as ``self._num_primitives``.
            trans_term_activation: stored as ``self.term_activation``.
            config: object providing the ``trans_*`` settings read below.
                Although it defaults to ``None``, its attributes are read
                unconditionally.
        """
        # configs
        self._config = config
        self.term_activation = trans_term_activation

        # args
        self.name = name
        self.env_name = name.split('.')[0]

        # network hyper-parameters
        self._hid_size = config.trans_hid_size
        self._num_hid_layers = config.trans_num_hid_layers
        self._gaussian_fixed_var = config.trans_fixed_var
        self._activation = ops.activation(config.trans_activation)
        self._include_acc = config.trans_include_acc

        # observation layout, read from an environment kept on the instance
        ob_env = make_env(ob_env_name, config)
        ob_shape = ob_env.ob_shape
        ob_type = sorted(ob_env.ob_type)
        self._ob_shape = ob_shape
        self.ob_type = ob_type
        self.primitive_env = ob_env

        drop_acc = not self._include_acc and 'acc' in ob_type
        if drop_acc:
            # this mutates the mapping object returned by the environment
            ob_shape.pop('acc')
            ob_type.remove('acc')

        self._env = env
        # total observation size: sum of the element counts of every entry
        flat_sizes = [np.prod(shape) for shape in ob_shape.values()]
        self._ob_space = np.sum(flat_sizes)
        self._ac_space = env.action_space
        self._num_primitives = num_primitives

        with tf.variable_scope(self.name):
            self._scope = tf.get_variable_scope().name
            self._build()
Exemple #4
0
    def __init__(self, name, env, ob_env_name, config=None):
        """Store hyper-parameters and spaces; build the network if needed.

        Args:
            name: identifier stored as ``self.name`` and used as the
                TensorFlow variable scope; the text before the first ``'-'``
                is kept as ``self.env_name``.
            env: environment object, kept as ``self._env``.
            ob_env_name: name handed to ``make_env`` to create the
                environment from which ``hard_coded``, the observation
                layout and the action space are read.
            config: object providing the ``primitive_*`` settings read
                below. Although it defaults to ``None``, its attributes are
                read unconditionally.
        """
        # configs
        self._config = config

        # args
        self.name = name
        self.env_name = name.split('-')[0]

        # network hyper-parameters
        self._hid_size = config.primitive_hid_size
        self._num_hid_layers = config.primitive_num_hid_layers
        self._gaussian_fixed_var = config.primitive_fixed_var
        self._activation = ops.activation(config.primitive_activation)
        self._include_acc = config.primitive_include_acc

        # properties read from the primitive's own environment
        self.ob_env_name = ob_env_name
        ob_env = make_env(ob_env_name, config)
        self.hard_coded = ob_env.hard_coded
        ob_shape = ob_env.ob_shape
        ob_type = sorted(ob_env.ob_type)
        self._ob_shape = ob_shape
        self.ob_type = ob_type

        drop_acc = not self._include_acc and 'acc' in ob_type
        if drop_acc:
            # this mutates the mapping object returned by the environment
            ob_shape.pop('acc')
            ob_type.remove('acc')

        self._env = env
        # total observation size: sum of the element counts of every entry
        flat_sizes = [np.prod(shape) for shape in ob_shape.values()]
        self._ob_space = np.sum(flat_sizes)
        self._ac_space = ob_env.action_space

        # the environment is retained only when the termination flag is on
        if not config.primitive_use_term:
            ob_env.close()
        else:
            self.primitive_env = ob_env

        # hard-coded primitives have no network to build
        if self.hard_coded:
            return

        with tf.variable_scope(self.name):
            self._scope = tf.get_variable_scope().name
            self._build()
Exemple #5
0
def run(config):
    """Build the environment, networks, trainer and rollout; then train or evaluate.

    The function enters a single-threaded TensorFlow session, seeds the
    worker from ``config.seed`` and the MPI rank, and creates the environment
    named by ``config.env``. Depending on ``config.hrl`` it builds either the
    hierarchical set of policies (meta, primitive, optional transition
    policies and proximity predictors) or a single ``MlpPolicy`` pair,
    restores variables from a checkpoint according to the ``load_*`` /
    ``log_dir`` settings, and finally calls ``trainer.train`` or one of the
    trainer's evaluation methods.

    Args:
        config: settings object; the attributes read are visible in the body.

    Raises:
        ValueError: when policy evaluation is requested but no checkpoint
            could be located.
    """
    sess = U.single_threaded_session(gpu=False)
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    is_chef = (rank == 0)

    # every MPI worker gets its own seed
    workerseed = config.seed + 10000 * rank
    set_global_seeds(workerseed)

    if is_chef:
        logger.configure()
    else:
        # only rank 0 logs, renders and records
        logger.set_level(logger.DISABLED)
        config.render = False
        config.record = False

    env_name = config.env
    env = make_env(env_name, config)

    if is_chef and config.is_train:
        with open(osp.join(config.log_dir, "args.txt"), "a") as f:
            f.write("\nEnvironment argument:\n")
            for k in sorted(env.unwrapped._config.keys()):
                f.write("{}: {}\n".format(k, env.unwrapped._config[k]))

    networks = []

    # build models
    if config.hrl:
        assert config.primitive_envs is not None and config.primitive_paths is not None

        logger.info('====== Module list ======')
        num_primitives = len(config.primitive_envs)
        for primitive_env_name, primitive_path in zip(config.primitive_envs,
                                                      config.primitive_paths):
            logger.info('Env: {}, Dir: {}'.format(primitive_env_name,
                                                  primitive_path))

        meta_pi = MetaPolicy(name="%s/meta_pi" % env_name,
                             env=env,
                             ob_env_name=env_name,
                             primitives=config.primitive_envs,
                             config=config)

        meta_oldpi = MetaPolicy(name="%s/meta_oldpi" % env_name,
                                env=env,
                                ob_env_name=env_name,
                                primitives=config.primitive_envs,
                                config=config)

        primitive_pis = [
            PrimitivePolicy(name="%s/pi" % primitive_env_name,
                            env=env,
                            ob_env_name=primitive_env_name,
                            config=config)
            for primitive_env_name in config.primitive_envs
        ]

        trans_pis, trans_oldpis = None, None
        if config.use_trans:
            trans_pis = [
                TransitionPolicy(
                    name="%s/transition_pi" % primitive_env_name,
                    env=env,
                    ob_env_name=env_name
                    if config.trans_include_task_obs else primitive_env_name,
                    num_primitives=num_primitives,
                    trans_term_activation=config.trans_term_activation,
                    config=config)
                for primitive_env_name in config.primitive_envs
            ]
            trans_oldpis = [
                TransitionPolicy(
                    name="%s/transition_oldpi" % primitive_env_name,
                    env=env,
                    ob_env_name=env_name
                    if config.trans_include_task_obs else primitive_env_name,
                    num_primitives=num_primitives,
                    trans_term_activation=config.trans_term_activation,
                    config=config)
                for primitive_env_name in config.primitive_envs
            ]
            networks.extend(trans_pis)
            networks.extend(trans_oldpis)
        networks.append(meta_pi)
        networks.append(meta_oldpi)
        networks.extend(primitive_pis)

        # build proximity_predictor
        proximity_predictors = None
        if config.use_proximity_predictor:
            portion_start = [
                float(v) for v in config.proximity_use_traj_portion_start
            ]
            portion_end = [
                float(v) for v in config.proximity_use_traj_portion_end
            ]
            # a single value is broadcast to every primitive
            if len(portion_start) == 1:
                portion_start = portion_start * num_primitives
            if len(portion_end) == 1:
                portion_end = portion_end * num_primitives

            # one predictor (and one observation env) per primitive; the loop
            # variables are named distinctly from the lists being zipped
            proximity_predictors = [
                ProximityPredictor(
                    name="%s/proximity_predictor" % primitive_env_name,
                    path=path,
                    env=env,
                    ob_env_name=primitive_env_name,
                    use_traj_portion_end=end,
                    use_traj_portion_start=start,
                    is_train=config.is_train,
                    config=config)
                for primitive_env_name, path, start, end in zip(
                    config.primitive_envs, config.primitive_paths,
                    portion_start, portion_end)
            ]
            networks.extend(proximity_predictors)

        # build trainer
        from rl.trainer import Trainer
        trainer = Trainer(env, meta_pi, meta_oldpi, proximity_predictors,
                          num_primitives, trans_pis, trans_oldpis, config)

        # build rollout
        rollout = rollouts.traj_segment_generator(
            env,
            meta_pi,
            primitive_pis,
            trans_pis,
            stochastic=True,
            config=config,
            proximity_predictors=proximity_predictors,
        )
    else:
        # build vanilla TRPO
        policy = MlpPolicy(env=env,
                           name="%s/pi" % env_name,
                           ob_env_name=env_name,
                           config=config)

        old_policy = MlpPolicy(env=env,
                               name="%s/oldpi" % env_name,
                               ob_env_name=env_name,
                               config=config)
        networks.append(policy)
        networks.append(old_policy)

        # build trainer
        from rl.trainer_rl import RLTrainer
        trainer = RLTrainer(env, policy, old_policy, config)
        # build rollout
        rollout = rollouts.traj_segment_generator_rl(
            env,
            policy,
            stochastic=not config.is_collect_state,
            config=config)

    # initialize models
    def load_model(load_model_path, var_list=None):
        """Restore variables from a checkpoint file or directory.

        A directory is resolved through ``tf.train.latest_checkpoint``.
        Returns the checkpoint path that was used, which is falsy when the
        directory contained no checkpoint (nothing is loaded in that case).
        """
        if os.path.isdir(load_model_path):
            ckpt_path = tf.train.latest_checkpoint(load_model_path)
        else:
            ckpt_path = load_model_path
        if ckpt_path:
            U.load_state(ckpt_path, var_list)
        return ckpt_path

    # NOTE(review): meta_pi / meta_oldpi are only defined in the hrl branch
    # above, so setting load_meta_path without hrl fails with a NameError.
    if config.load_meta_path is not None:
        var_list = meta_pi.get_variables() + meta_oldpi.get_variables()
        ckpt_path = load_model(config.load_meta_path, var_list)
        logger.info(
            '* Load the meta policy from checkpoint: {}'.format(ckpt_path))

    def tensor_description(var):
        """Format a variable's dtype and shape as ``(dtype [AxBx...])``."""
        description = '({} [{}])'.format(
            var.dtype.name, 'x'.join([str(size) for size in var.get_shape()]))
        return description

    var_list = []
    for network in networks:
        var_list += network.get_variables()
    if is_chef:
        for var in var_list:
            logger.info('{} {}'.format(var.name, tensor_description(var)))

    if config.load_model_path is not None:
        # Load all the network
        if config.is_train:
            ckpt_path = load_model(config.load_model_path)
            if config.hrl:
                load_buffers(proximity_predictors, ckpt_path)
        else:
            ckpt_path = load_model(config.load_model_path, var_list)
        logger.info(
            '* Load all policies from checkpoint: {}'.format(ckpt_path))
    elif config.is_train:
        # resume from the log directory when it already holds a checkpoint
        ckpt_path = tf.train.latest_checkpoint(config.log_dir)
        if config.hrl:
            if ckpt_path:
                ckpt_path = load_model(ckpt_path)
                load_buffers(proximity_predictors, ckpt_path)
            else:
                # Only load the primitives
                for (primitive_name,
                     primitive_pi) in zip(config.primitive_paths,
                                          primitive_pis):
                    var_list = primitive_pi.get_variables()
                    if var_list:
                        primitive_path = osp.expanduser(
                            osp.join(config.primitive_dir, primitive_name))
                        ckpt_path = load_model(primitive_path, var_list)
                        logger.info("* Load module ({}) from {}".format(
                            primitive_name, ckpt_path))
                    else:
                        logger.info(
                            "* Hard-coded module ({})".format(primitive_name))
            logger.info("Loading modules is done.")
        else:
            if ckpt_path:
                ckpt_path = load_model(ckpt_path)
    else:
        logger.info('[!] Checkpoint for evaluation is not provided.')
        ckpt_path = load_model(config.log_dir, var_list)
        logger.info(
            "* Load all policies from checkpoint: {}".format(ckpt_path))

    if config.is_train:
        trainer.train(rollout)
    else:
        if config.evaluate_proximity_predictor:
            trainer.evaluate_proximity_predictor(var_list)
        else:
            # load_model returns None when no checkpoint exists; fail with a
            # clear message instead of an AttributeError on None.split
            if ckpt_path is None:
                raise ValueError(
                    'No checkpoint found for evaluation; set load_model_path '
                    'or make sure log_dir ({}) contains a checkpoint.'.format(
                        config.log_dir))
            trainer.evaluate(rollout, ckpt_num=ckpt_path.split('/')[-1])

    env.close()
Exemple #6
0
    def __init__(self,
                 name,
                 path,
                 env,
                 ob_env_name,
                 is_train=True,
                 use_traj_portion_start=0.0,
                 use_traj_portion_end=1.0,
                 config=None):
        """Create the predictor: buffers, collected states, graph and losses.

        Args:
            name: used to form the scope ``'proximity_predictor/' + name``;
                the text before the first ``'.'`` is kept as
                ``self.env_name``.
            path: primitive path; its first ``'/'``-separated component is
                the sub-directory of ``config.primitive_dir`` whose
                ``state/*.hdf5`` files are searched.
            env: not used by this constructor; the predictor creates its own
                environment from ``ob_env_name``.
            ob_env_name: name handed to ``make_env`` for the observation
                environment kept as ``self._env``.
            is_train: collected states are loaded when this is true or when
                ``config.evaluate_proximity_predictor`` is set.
            use_traj_portion_start: fraction of each trajectory's length at
                which the used slice starts.
            use_traj_portion_end: fraction of each trajectory's length at
                which the used slice ends.
            config: object providing the ``proximity_*`` settings read
                below. Although it defaults to ``None``, its attributes are
                read unconditionally.

        Raises:
            ValueError: if ``config.proximity_loss_type`` is neither
                ``'lsgan'`` nor ``'wgan'``.
        """
        self._scope = 'proximity_predictor/' + name
        self.env_name = name.split('.')[0]
        self._config = config

        # make primitive env for observation
        self._env = make_env(ob_env_name, config)
        self._include_acc = config.proximity_include_acc
        self._ob_shape = self._env.unwrapped.ob_shape
        self.ob_type = sorted(self._env.unwrapped.ob_type)
        if not self._include_acc and 'acc' in self.ob_type:
            # this mutates the mapping object returned by the environment
            self._ob_shape.pop('acc')
            self.ob_type.remove('acc')

        self.obs_norm = config.proximity_obs_norm
        # total observation size: sum of the element counts of every entry
        self.observation_shape = np.sum(
            [np.prod(ob) for ob in self._ob_shape.values()])

        # replay buffers
        self.fail_buffer = Replay(max_size=config.proximity_replay_size,
                                  name='fail_buffer')
        self.success_buffer = Replay(max_size=config.proximity_replay_size,
                                     name='success_buffer')

        # build the architecture
        self._num_hidden_layer = config.proximity_num_hid_layers
        self._hidden_size = config.proximity_hid_size
        self._activation_fn = activation(config.proximity_activation_fn)
        self._build_ph()

        logger.info('===== Proximity_predictor for {} ====='.format(
            self._scope))
        # load collected states
        if is_train or config.evaluate_proximity_predictor:
            state_file_path = osp.join(config.primitive_dir,
                                       path.split('/')[0], 'state')
            logger.info('Search state files from: {}'.format(
                config.primitive_dir))
            state_file_list = glob.glob(osp.join(state_file_path, '*.hdf5'))
            logger.info('Candidate state files: {}'.format(' '.join(
                [f.split('/')[-1] for f in state_file_list])))

            # best effort: a missing or unreadable file only produces a
            # warning and leaves the success buffer empty
            state_file = None
            try:
                logger.info('Use state files: {}'.format(
                    state_file_list[0].split('/')[-1]))
                state_file = h5py.File(state_file_list[0], 'r')
            except (IndexError, OSError):
                logger.warn(
                    "No collected state hdf5 file is located at {}".format(
                        state_file_path))
            logger.info('Use traj portion: {} to {}'.format(
                use_traj_portion_start, use_traj_portion_end))

            if self._config.proximity_keep_collected_obs:
                add_obs = self.success_buffer.add_collected_obs
            else:
                add_obs = self.success_buffer.add

            if state_file is not None:
                try:
                    for k in list(state_file.keys()):
                        traj = state_file[k]
                        # dataset[()] reads the whole dataset; the older
                        # `.value` accessor no longer exists in h5py 3
                        traj_state = traj['obs'][()]
                        start_idx = int(
                            traj_state.shape[0] * use_traj_portion_start)
                        end_idx = int(
                            traj_state.shape[0] * use_traj_portion_end)
                        try:
                            succeeded = bool(traj['success'][()] == 1)
                        except (KeyError, ValueError):
                            # no usable success flag: keep the trajectory
                            succeeded = True
                        if not succeeded:
                            continue
                        for step_ob in traj_state[start_idx:end_idx]:
                            ob = step_ob[:self.observation_shape]
                            # [ob, label]
                            add_obs(np.concatenate((ob, [1.0]), axis=0))
                finally:
                    # everything needed was copied into the buffer
                    state_file.close()

            # shape [num_state, dim_state]
            logger.info('Size of collected state: {}'.format(
                self.success_buffer.size()))
            logger.info('Average of collected state: {}'.format(
                np.mean(self.success_buffer.list(), axis=0)))

        # build graph
        fail_logits, fail_target_value, success_logits, success_target_value = \
            self._build_graph(self.fail_obs_ph, self.success_obs_ph, reuse=False)

        # compute prob
        fake_prob = tf.reduce_mean(fail_logits)  # should go to 0
        real_prob = tf.reduce_mean(success_logits)  # should go to 1

        # compute loss
        if config.proximity_loss_type == 'lsgan':
            self.fake_loss = tf.reduce_mean(
                (fail_logits - fail_target_value)**2)
            self.real_loss = tf.reduce_mean(
                (success_logits - success_target_value)**2)
        elif config.proximity_loss_type == 'wgan':
            self.fake_loss = tf.reduce_mean(
                tf.abs(fail_logits - fail_target_value))
            self.real_loss = tf.reduce_mean(
                tf.abs(success_logits - success_target_value))
        else:
            # without this the losses would be undefined a few lines below
            raise ValueError(
                "Unknown proximity_loss_type: {!r} (expected 'lsgan' or "
                "'wgan')".format(config.proximity_loss_type))

        # loss + accuracy terms
        self.total_loss = self.fake_loss + self.real_loss
        self.losses = {
            "fake_loss": self.fake_loss,
            "real_loss": self.real_loss,
            "fake_prob": fake_prob,
            "real_prob": real_prob,
            "total_loss": self.total_loss
        }

        # predict proximity
        self._proximity_op = tf.clip_by_value(success_logits, 0, 1)[:, 0]