Code example #1
File: gru.py Project: xlnwel/d2rl
 def call(self, x, state, mask, additional_input=[]):
     xs = [x] + additional_input
     mask = tf.expand_dims(mask, axis=-1)
     assert_rank(xs + [mask], 3)
     if not self._state_mask:
         # mask out inputs
         for i, v in enumerate(xs):
             xs[i] *= tf.cast(mask, v.dtype)
     x = tf.concat(xs, axis=-1) if len(xs) > 1 else xs[0]
     if not mask.dtype.is_compatible_with(global_policy().compute_dtype):
         mask = tf.cast(mask, global_policy().compute_dtype)
     x = self._rnn((x, mask), initial_state=state)
     x, state = x[0], GRUState(x[1])
     return x, state
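
Every example on this page queries tf.keras.mixed_precision.global_policy() for its compute dtype. A minimal, self-contained sketch of what compute_dtype evaluates to under different policies (TF 2.4+ API):

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
policy = mixed_precision.global_policy()
print(policy.compute_dtype)   # 'float16': dtype used for layer computations
print(policy.variable_dtype)  # 'float32': dtype used for layer weights

mixed_precision.set_global_policy('float32')  # restore the default policy
print(mixed_precision.global_policy().compute_dtype)  # 'float32'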
Code example #2
 def construct(x, default_shape):
     """
     By default, default_shape is prepended to the shape s.
     There are two ways to omit or change default_shape:
         1. set s = None to omit default_shape, resulting in shape ()
         2. pass a fourth element in x to override default_shape.
            Note that if s = None, the overriding default_shape is omitted as well
     """
     if isinstance(x, tf.TensorSpec):
         return x
     elif isinstance(x, (list, tuple)):
         if hasattr(x, '_fields') or (len(x) > 1
                                      and isinstance(x[1], tuple)):
             # x is a list/tuple of TensorSpecs; recursively construct them.
             # get_TensorSpecs, sequential, and batch_size are presumably
             # defined in the enclosing scope: construct is a nested helper.
             return get_TensorSpecs(x,
                                    sequential=sequential,
                                    batch_size=batch_size)
         if len(x) == 1:
             s = x
             d = mixed_precision.global_policy().compute_dtype
             n = None
         elif len(x) == 2:
             s, d = x
             n = None
         elif len(x) == 3:
             s, d, n = x
         elif len(x) == 4:
             s, d, n, default_shape = x
         else:
             raise ValueError(f'Unknown form x: {x}')
         s = () if s is None else default_shape + list(s)
         return tf.TensorSpec(shape=s, dtype=d, name=n)
     else:
         raise ValueError(f'Unknown form x: {x}')
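
For reference, a sketch of how the tuple forms accepted by construct map to TensorSpecs. These calls are hypothetical and assume default_shape = [None], i.e. a leading batch dimension:

import tensorflow as tf

construct(((64, 64, 3), tf.uint8, 'obs'), [None])
# -> tf.TensorSpec([None, 64, 64, 3], tf.uint8, 'obs')
construct((None, tf.float32, 'reward'), [None])
# -> tf.TensorSpec((), tf.float32, 'reward'); s = None omits default_shape
construct(((1,), tf.float32, 'mask', []), [None])
# -> tf.TensorSpec([1], tf.float32, 'mask'); a fourth element overrides default_shape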
Code example #3
File: nn.py Project: xlnwel/d2rl
    def action(self,
               x,
               state,
               mask,
               prev_action,
               prev_reward=None,
               evaluation=False,
               epsilon=0,
               temp=1.,
               return_stats=False,
               return_eval_stats=False):
        assert x.shape.ndims in (2, 4), x.shape

        mask = tf.cast(tf.reshape(mask, (-1, 1)),
                       global_policy().compute_dtype)
        state = tf.nest.map_structure(lambda x: x * mask, state)
        prev_action = prev_action * mask

        embed = self.encoder(x)
        embed = tf.squeeze(embed, 1)
        state = self.rssm.post_step(state, prev_action, embed)
        feature = self.rssm.get_feat(state)
        if evaluation:
            action = self.actor(feature)[0].mode()
        else:
            action = self.actor(feature)[0].sample()
            action = tf.clip_by_value(
                tfd.Normal(action, epsilon).sample(), -1, 1)

        return (action, {}), state
Code example #4
File: optimizer.py Project: xlnwel/d2rl
 def __init__(self,
              name,
              models,
              lr,
              clip_norm=None,
              weight_decay=None,
              l2_reg=None,
              wdpattern=r'.*',
              scales=None,
              return_grads=False,
              **kwargs):
     self._models = models if isinstance(models,
                                         (list, tuple)) else [models]
     self._clip_norm = clip_norm
     self._weight_decay = weight_decay
     self._l2_reg = l2_reg
     self._wdpattern = wdpattern
     if scales is not None:
         assert isinstance(scales, (list, tuple)), scales
         assert len(scales) == len(self._models), (len(scales),
                                                   len(self._models))
     self._scales = scales
     self._opt = select_optimizer(name)(lr, **kwargs)
     self._return_grads = return_grads
     # useful for mixed precision training on GPUs to
     # avoid numerical underflow caused by using float16 gradients
     prec_policy = prec.global_policy()
     self._mpt = prec_policy.compute_dtype != prec_policy.variable_dtype
     if self._mpt:
         logger.info('Mixed precision training will be performed')
         self._opt = prec.LossScaleOptimizer(self._opt)
     # we do not initialize variables here, as models may not be initialized at this point
     self._variables = None
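
The LossScaleOptimizer wrapped above pairs with loss scaling and gradient unscaling at train time. A minimal sketch of that standard pattern, not this project's actual train step (compute_loss and variables are placeholders):

import tensorflow as tf
from tensorflow.keras import mixed_precision as prec

opt = prec.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-3))
with tf.GradientTape() as tape:
    loss = compute_loss()                    # placeholder loss computation
    scaled_loss = opt.get_scaled_loss(loss)  # scale up to avoid float16 underflow
scaled_grads = tape.gradient(scaled_loss, variables)
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, variables))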
Code example #5
File: base.py Project: xlnwel/d2rl
def get_data_format(*, env, replay_config, agent_config, model, **kwargs):
    is_per = replay_config['replay_type'].endswith('per')
    store_state = agent_config['store_state']
    sample_size = agent_config['sample_size']
    obs_dtype = tf.uint8 if len(env.obs_shape) == 3 else tf.float32
    data_format = dict(
        obs=((None, sample_size + 1, *env.obs_shape), obs_dtype),
        action=((None, sample_size + 1, *env.action_shape), tf.int32),
        reward=((None, sample_size), tf.float32),
        mu=((None, sample_size + 1), tf.float32),
        discount=((None, sample_size), tf.float32),
        mask=((None, sample_size + 1), tf.float32),
    )
    if is_per:
        data_format['idxes'] = ((None,), tf.int32)  # (None) is not a tuple; (None,) is
        if replay_config.get('use_is_ratio'):
            data_format['IS_ratio'] = ((None, ), tf.float32)
    if store_state:
        state_size = model.state_size
        from tensorflow.keras.mixed_precision import global_policy
        state_dtype = global_policy().compute_dtype
        data_format.update({
            k: ((None, v), state_dtype)
            for k, v in state_size._asdict().items()
        })

    return data_format
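
The (shape, dtype) pairs returned here are typically converted into TensorSpecs downstream, e.g. to type a tf.data pipeline. A sketch under that assumption (sample_fn is a hypothetical generator):

signature = {k: tf.TensorSpec(shape=s, dtype=d) for k, (s, d) in data_format.items()}
dataset = tf.data.Dataset.from_generator(sample_fn, output_signature=signature)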
Code example #6
File: gru.py Project: xlnwel/d2rl
 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
     state_size = self.state_size
     if inputs is not None:
         assert batch_size is None or batch_size == tf.shape(inputs)[0]
         batch_size = tf.shape(inputs)[0]
     if dtype is None:
         dtype = global_policy().compute_dtype
     return GRUState(h=tf.zeros([batch_size, state_size[0]], dtype))
Code example #7
File: nn.py Project: xlnwel/d2rl
    def call(self, x):
        x = convert_obs(x, [-.5, .5], global_policy().compute_dtype)
        B, T = x.shape[:2] if len(x.shape) == 5 else (x.shape[0], 1)
        x = tf.reshape(x, (-1, *x.shape[-3:]))
        x = self._conv1(x)
        x = self._conv2(x)
        x = self._conv3(x)
        x = self._conv4(x)
        x = tf.reshape(x, (B, T, tf.reduce_prod(tf.shape(x)[-3:])))

        return x
Code example #8
File: history.py Project: Vivek305/fastestimator
    def _get_features_in_use(self) -> List[Dict[str, str]]:
        """Determine which interesting FE features are being used by the current training.

        Returns:
            A list of entries which can be written into the 'features' db table.
        """
        features = []
        if sys.modules['fastestimator'].fe_deterministic_seed is not None:
            features.append({'feature': 'Deterministic', 'fk': self.pk})
        if any([len(mode_dict) > 1 for mode_dict in self.system.pipeline.data.values()]):
            features.append({'feature': 'MultiDataset', 'fk': self.pk})
        if mixed_precision.global_policy().compute_dtype == 'float16':
            features.append({'feature': 'MixedPrecision', 'fk': self.pk})
        return features
Code example #9
 def call(self, x, training=False):
     x = convert_obs(x, self._obs_range, global_policy().compute_dtype)
     if self._time_distributed:
         t = x.shape[1]
         x = tf.reshape(x, [-1, *x.shape[2:]])
     x = super().call(x, training=training)
     self.cnn_out = x
     x = self._flat(x)
     if self._time_distributed:
         x = tf.reshape(x, [-1, t, *x.shape[1:]])
     if self.out_size:
         x = self._dense(x)
     if self._deter_stoch:
         self.state = self._ds_layer.state
     return x
Code example #10
File: nn.py Project: xlnwel/d2rl
 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
     if inputs is not None:
         assert batch_size is None or batch_size == tf.shape(inputs)[0]
         batch_size = tf.shape(inputs)[0]
     assert batch_size is not None
     if dtype is None:
         dtype = global_policy().compute_dtype
     return RSSMState(mean=tf.zeros([batch_size, self._stoch_size],
                                    dtype=dtype),
                      std=tf.zeros([batch_size, self._stoch_size],
                                   dtype=dtype),
                      stoch=tf.zeros([batch_size, self._stoch_size],
                                     dtype=dtype),
                      deter=self._cell.get_initial_state(
                          inputs, batch_size, dtype))
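
Hypothetical usage, assuming rssm is an instance of the surrounding class with _stoch_size = 30:

state = rssm.get_initial_state(batch_size=16)
state.mean.shape   # (16, 30); std and stoch have the same shape
state.deter.shape  # determined by the wrapped cell's state size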
Code example #11
 def call(self, x, training=True, return_cnn_out=False):
     x = convert_obs(x, self._obs_range, global_policy().compute_dtype)
     if self._time_distributed:
         t = x.shape[1]
         x = tf.reshape(x, [-1, *x.shape[2:]])
     for l in self._convs:
         x = l(x)
     x = self._out_act(x)
     if self._time_distributed:
         x = tf.reshape(x, [-1, t, *x.shape[1:]])
     z = self._flat(x)
     if self.out_size:
         z = self._dense(z)
     if return_cnn_out:
         return z, x
     else:
         return z
Code example #12
def convert_input_precision(tensor: Tensor) -> Tensor:
    """
        Adjust the input data precision based of environment precision.

        Args:
            tensor: The input value.

        Returns:
            The precision adjusted data(16 bit for mixed precision, 32 bit otherwise).

    """
    precision = 'float32'

    if mixed_precision.global_policy().compute_dtype == 'float16':
        precision = 'float16'

    return cast(tensor, precision)
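
The result is policy-driven: the same call yields different dtypes depending on the active global policy. A sketch, assuming cast behaves like tf.cast:

from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
y = convert_input_precision(tf.ones((2, 3)))  # float16 under mixed precision

mixed_precision.set_global_policy('float32')
y = convert_input_precision(tf.ones((2, 3)))  # float32 otherwise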
Code example #13
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, replay_config):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config['precision'])

    use_ray = env_config.get('n_workers', 0) > 1
    if use_ray:
        import ray
        ray.init()
        sigint_shutdown_ray()

    env = create_env(env_config, make_env, force_envvec=True)
    eval_env_config = env_config.copy()
    eval_env_config['n_envs'] = 1
    eval_env_config['n_workers'] = 1
    eval_env = create_env(eval_env_config, make_env)

    replay_config['dir'] = agent_config['root_dir'].replace('logs', 'data')
    replay = create_replay(replay_config)
    replay.load_data()
    dtype = global_policy().compute_dtype
    data_format = pkg.import_module(
        'agent', config=agent_config).get_data_format(
            env=env,
            batch_size=agent_config['batch_size'],
            sample_size=agent_config['sample_size'],
            dtype=dtype)
    process = functools.partial(process_with_env,
                                env=env,
                                obs_range=[-.5, .5],
                                one_hot_action=True,
                                dtype=dtype)
    dataset = Dataset(replay, data_format, process)

    create_model, Agent = pkg.import_agent(config=agent_config)
    models = create_model(model_config, env)

    agent = Agent(config=agent_config, models=models, dataset=dataset, env=env)

    agent.save_config(
        dict(env=env_config,
             model=model_config,
             agent=agent_config,
             replay=replay_config))

    train(agent, env, eval_env, replay)
Code example #14
File: dqn_actor.py Project: xlnwel/d2rl
        def __init__(self,
                     name,
                     model_fn,
                     config,
                     model_config,
                     env_config,
                     replay_config):
            cpu_affinity('Learner')
            silence_tf_logs()
            configure_threads(config['n_cpus'], config['n_cpus'])
            configure_gpu()
            configure_precision(config['precision'])
            self._dtype = global_policy().compute_dtype

            self._envs_per_worker = env_config['n_envs']
            env_config['n_envs'] = 1
            env = create_env(env_config)
            assert env.obs_dtype == np.uint8, \
                f'Expect image observation of type uint8, but get {env.obs_dtype}'
            self._action_shape = env.action_shape
            self._action_dim = env.action_dim
            self._frame_skip = getattr(env, 'frame_skip', 1)

            self.models = Ensemble(
                model_fn=model_fn,
                config=model_config, 
                obs_shape=env.obs_shape,
                action_dim=env.action_dim, 
                is_action_discrete=env.is_action_discrete
            )

            super().__init__(
                name=name, 
                config=config, 
                models=self.models,
                dataset=None,
                env=env)

            replay_config['dir'] = config['root_dir'].replace('logs', 'data')
            self.replay = create_replay(replay_config)
            data_format = get_data_format(env, replay_config)
            process = functools.partial(process_with_env, env=env)
            self.dataset = Dataset(self.replay, data_format, process, prefetch=10)

            self._env_step = self.env_step()
Code example #15
File: dqn_actor.py Project: xlnwel/d2rl
        def __init__(self,
                     name,
                     model_fn,
                     config,
                     model_config,
                     env_config):
            cpu_affinity('Actor')
            silence_tf_logs()
            configure_threads(1, 1)
            configure_gpu()
            configure_precision(config['precision'])
            self._dtype = global_policy().compute_dtype

            self._envs_per_worker = env_config['n_envs']
            env_config['n_envs'] = config['action_batch']
            # keep the env on self: it is referenced as self.env below
            self.env = create_env(env_config)
            assert self.env.obs_dtype == np.uint8, \
                f'Expect image observation of type uint8, but get {self.env.obs_dtype}'
            self._action_shape = self.env.action_shape
            self._action_dim = self.env.action_dim

            self.models = Ensemble(
                model_fn=model_fn,
                config=model_config, 
                obs_shape=self.env.obs_shape,
                action_dim=self.env.action_dim, 
                is_action_discrete=self.env.is_action_discrete
            )

            super().__init__(
                name=name, 
                config=config, 
                models=self.models,
                dataset=None,
                env=self.env)
            
            # cache for episodes
            self._cache = collections.defaultdict(list)

            # agent's state
            self._state = collections.defaultdict(lambda:
                self.rssm.get_initial_state(batch_size=1, dtype=self._dtype))
            self._prev_action = collections.defaultdict(lambda:
                tf.zeros((1, self._action_dim), self._dtype))
Code example #16
 def preprocess(self, obs):
   dtype = prec.global_policy().compute_dtype
   obs = obs.copy()
   for key, value in obs.items():
     if key.startswith('log_'):
       continue
     if value.dtype == tf.int32:
       value = value.astype(dtype)
     if value.dtype == tf.uint8:
       value = value.astype(dtype) / 255.0 - 0.5
     obs[key] = value
   obs['reward'] = {
       'identity': tf.identity,
       'sign': tf.sign,
       'tanh': tf.tanh,
   }[self.config.clip_rewards](obs['reward'])
   obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype)
   obs['discount'] *= self.config.discount
   return obs
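
The uint8 branch applies the usual image normalization: map [0, 255] into [-0.5, 0.5] in the compute dtype. The same arithmetic in isolation:

import numpy as np

dtype = 'float16'  # e.g. the compute dtype under mixed precision
img = np.arange(256, dtype=np.uint8)
x = img.astype(dtype) / 255.0 - 0.5  # 0 -> -0.5, 255 -> 0.5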
Code example #17
    def wrapper(self, *, name=None, config, models, env, **kwargs):
        """
        Args:
            name: Agent's name
            config: configuration for the agent;
                should be read from config.yaml
            models: a dict of models
            kwargs: optional arguments for each specific agent
        """
        """ For the basic configuration, see config.yaml in algo/*/ """
        config_attr(self, config)

        # name is used in stdout/stderr as the agent's identifier
        # while model_name is used for logging and checkpoint
        # e.g., all workers share the same name, but with different model_names
        self.name = name or config["algorithm"]
        self._model_name = self._model_name or 'baseline'

        self._dtype = global_policy().compute_dtype

        self.model = models
        # track models and optimizers for Checkpoint
        self._ckpt_models = {}
        for name_, model in models.items():
            setattr(self, name_, model)
            if isinstance(model, tf.Module) or isinstance(model, tf.Variable):
                self._ckpt_models[name_] = model

        self._env_step = tf.Variable(0, trainable=False, dtype=tf.int64)
        self._train_step = tf.Variable(0, trainable=False, dtype=tf.int64)
        self.env_step = 0
        self.train_step = 0
        if config.get('writer', True):
            self._writer = setup_tensorboard(self._root_dir, self._model_name)
            tf.summary.experimental.set_step(0)

        # Agent initialization
        init_fn(self, env=env, **kwargs)

        # save optimizers
        for k, v in vars(self).items():
            if isinstance(v, Optimizer):
                self._ckpt_models[k[1:]] = v
        logger.info(f'ckpt models: {self._ckpt_models}')

        self.print_construction_complete()

        if config.get('display_var', True):
            display_model_var_info(self._ckpt_models)

        if config.get('save_code', True):
            save_code(self._root_dir, self._model_name)

        self._ckpt, self._ckpt_path, self._ckpt_manager = \
            setup_checkpoint(self._ckpt_models, self._root_dir,
                            self._model_name, self._env_step, self._train_step)

        self.restore()

        # to save stats to files, specify `logger: True` in config.yaml
        self._logger = setup_logger(
            config.get('logger', True) and self._root_dir, self._model_name)
Code example #18
 def __init__(self, logits=None, probs=None):
     self._dist = tfd.Categorical(logits=logits, probs=probs)
     self._num_classes = self.mean().shape[-1]
     self._dtype = global_policy().compute_dtype