Example #1
class BeholderHook(tf.train.SessionRunHook):
    """SessionRunHook implementation that runs Beholder every step.

    Convenient when using tf.train.MonitoredSession:
    ```python
    beholder_hook = BeholderHook(LOG_DIRECTORY, tensors_to_fetch, frame)
    with MonitoredSession(..., hooks=[beholder_hook]) as sess:
      sess.run(train_op)
    ```
    """
    def __init__(self, logdir, list_of_np_ndarrays, frame):
        """Creates new Hook instance

        Args:
          logdir: Directory where Beholder should write data.
        """
        self._logdir = logdir
        self.beholder = None
        self.list_of_np_ndarrays = list_of_np_ndarrays
        self.frame = frame

    def begin(self):
        self.beholder = Beholder(self._logdir)

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(fetches=self.list_of_np_ndarrays)

    def after_run(self, run_context, run_values):
        self.beholder.update(session=run_context.session,
                             arrays=run_values.results,
                             frame=self.frame)
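
A minimal usage sketch (TF 1.x assumed): the log directory, watched variable, and frame below are hypothetical stand-ins, and Beholder/BeholderHook are assumed to come from tensorboard.plugins.beholder.

import numpy as np
import tensorflow as tf
from tensorboard.plugins.beholder import Beholder  # used by BeholderHook.begin() above

LOG_DIRECTORY = "/tmp/beholder_demo"             # hypothetical log dir
weights = tf.get_variable("w", shape=[64, 64])   # something worth visualizing
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)         # stand-in train op
frame_fn = lambda: np.random.rand(128, 128)      # frame shown by Beholder

beholder_hook = BeholderHook(LOG_DIRECTORY, [weights], frame_fn)
stop_hook = tf.train.StopAtStepHook(last_step=10)
with tf.train.MonitoredTrainingSession(hooks=[beholder_hook, stop_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)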
Example #2
class BeholderCallback(tf.keras.callbacks.Callback):
    def __init__(self, tensor, logdir, sess=None):
        self.visualizer = Beholder(logdir=logdir)
        self.sess = sess
        if sess is None:
            self.sess = K.get_session()
        self.tensor = tensor

    def on_epoch_end(self, epoch, logs=None):
        frame = self.sess.run(
            self.tensor
        )  # depending on the tensor, this might require a feed_dict
        self.visualizer.update(session=self.sess, frame=frame)
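
A sketch of plugging this callback into model.fit (TF 1.x Keras assumed, since K.get_session is used above; the model, data, and logdir are hypothetical):

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K          # the K referenced above
from tensorboard.plugins.beholder import Beholder  # assumed source of Beholder

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

x_train = np.random.rand(256, 4).astype(np.float32)
y_train = np.random.rand(256, 1).astype(np.float32)

# Visualize the first layer's kernel at the end of every epoch.
beholder_cb = BeholderCallback(tensor=model.layers[0].kernel,
                               logdir="/tmp/beholder_demo")
model.fit(x_train, y_train, epochs=3, callbacks=[beholder_cb])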
Example #3
class BeholderCB(tf.keras.callbacks.Callback):
    """Keras callback for tensorboard beholder plugin: https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/beholder

    Args:
        logdir (str): path to the tensorboard log directory.
        sess: tensorflow session.
    """
    def __init__(self, logdir, sess):
        super(BeholderCB, self).__init__()
        self.beholder = Beholder(logdir=logdir)
        self.session = sess

    def on_epoch_end(self, epoch, logs=None):
        super(BeholderCB, self).on_epoch_end(epoch, logs)
        self.beholder.update(session=self.session)
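
Usage mirrors the previous sketch, except the session is passed in explicitly (TF 1.x graph mode assumed; model and training data as in the sketch above):

import tensorflow as tf

sess = tf.keras.backend.get_session()  # tf.compat.v1.keras.backend.get_session() on newer TF
beholder_cb = BeholderCB(logdir="/tmp/beholder_demo", sess=sess)
model.fit(x_train, y_train, epochs=3, callbacks=[beholder_cb])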
Example #4
class RolloutWorker(EvaluatorInterface):
    """Common experience collection class.

    This class wraps a policy instance and an environment class to
    collect experiences from the environment. You can create many replicas of
    this class as Ray actors to scale RL training.

    This class supports vectorized and multi-agent policy evaluation (e.g.,
    VectorEnv, MultiAgentEnv, etc.)

    Examples:
        >>> # Create a rollout worker and use it to collect experiences.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: gym.make("CartPole-v0"),
        ...   policy=PGTFPolicy)
        >>> print(worker.sample())
        SampleBatch({
            "obs": [[...]], "actions": [[...]], "rewards": [[...]],
            "dones": [[...]], "new_obs": [[...]]})

        >>> # Creating a multi-agent rollout worker
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
        ...   policies={
        ...       # Use an ensemble of two policies for car agents
        ...       "car_policy1":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
        ...       "car_policy2":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
        ...       # Use a single shared policy for all traffic lights
        ...       "traffic_light_policy":
        ...         (PGTFPolicy, Box(...), Discrete(...), {}),
        ...   },
        ...   policy_mapping_fn=lambda agent_id:
        ...     random.choice(["car_policy1", "car_policy2"])
        ...     if agent_id.startswith("car_") else "traffic_light_policy")
        >>> print(worker.sample())
        MultiAgentBatch({
            "car_policy1": SampleBatch(...),
            "car_policy2": SampleBatch(...),
            "traffic_light_policy": SampleBatch(...)})
    """
    @DeveloperAPI
    @classmethod
    def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
        return ray.remote(num_cpus=num_cpus,
                          num_gpus=num_gpus,
                          resources=resources)(cls)

    @DeveloperAPI
    def __init__(self,
                 env_creator,
                 policy,
                 policy_mapping_fn=None,
                 policies_to_train=None,
                 tf_session_creator=None,
                 batch_steps=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="deepmind",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_filter="NoFilter",
                 clip_rewards=None,
                 clip_actions=True,
                 env_config=None,
                 model_config=None,
                 policy_config=None,
                 worker_index=0,
                 monitor_path=None,
                 log_dir=None,
                 log_level=None,
                 callbacks=None,
                 input_creator=lambda ioctx: ioctx.default_sampler_input(),
                 input_evaluation=frozenset([]),
                 output_creator=lambda ioctx: NoopOutput(),
                 remote_worker_envs=False,
                 remote_env_batch_wait_ms=0,
                 soft_horizon=False,
                 _fake_sampler=False):
        """Initialize a rollout worker.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                EnvContext wrapped configuration.
            policy (class|dict): Either a class implementing
                Policy, or a dictionary of policy id strings to
                (Policy, obs_space, action_space, config) tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn should also be set.
            policy_mapping_fn (func): A function that maps agent ids to
                policy ids in multi-agent mode. This function will be called
                each time a new agent appears in an episode, to bind that agent
                to a policy for the duration of the episode.
            policies_to_train (list): Optional whitelist of policies to train,
                or None for all policies.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicy.
            batch_steps (int): The target number of env transitions to include
                in each sample batch returned from this worker.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of at most `batch_steps * num_envs` in size. The batch will
                    be exactly `batch_steps * num_envs` in size if
                    postprocessing does not change batch sizes. Episodes may be
                    truncated in order to meet this size requirement.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `batch_steps * num_envs` in size. Episodes will
                    not be truncated, but multiple episodes may be packed
                    within one batch to meet the batch size. Note that when
                    `num_envs > 1`, episode steps will be buffered until the
                    episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): Horizon at which to force-stop episodes,
                even if the env has not returned done.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations.
                They can be decompressed with rllib/utils/compression.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_filter (str): Name of observation filter to use.
            clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
                experience postprocessing. Setting to None means clip for Atari
                only.
            clip_actions (bool): Whether to clip action values to the range
                specified by the policy action space.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy. In the
                multi-agent case, this config will be merged with the
                per-policy configs specified by `policy`.
            worker_index (int): For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (dict): Dict of custom debug callbacks.
            input_creator (func): Function that returns an InputReader object
                for loading previously generated experiences.
            input_evaluation (list): How to evaluate the policy performance.
                This only makes sense to set when the input is reading offline
                data. The possible values include:
                  - "is": the step-wise importance sampling estimator.
                  - "wis": the weighted step-wise is estimator.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
            output_creator (func): Function that returns an OutputWriter object
                for saving generated experiences.
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                are very slow to step or reset.
            remote_env_batch_wait_ms (float): How long (in ms) remote workers
                wait when polling environments for results. 0 (continue as
                soon as at least one env is ready) is a reasonable default,
                but the optimal value can be found by measuring your
                environment step / reset and model inference perf.
            soft_horizon (bool): Calculate rewards but don't reset the
                environment when the horizon is hit.
            _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """

        global _global_worker
        _global_worker = self

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(env_config or {}, worker_index)
        policy_config = policy_config or {}
        self.policy_config = policy_config
        self.callbacks = callbacks or {}
        self.worker_index = worker_index
        model_config = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        if not callable(policy_mapping_fn):
            raise ValueError(
                "Policy mapping function not callable. If you're using Tune, "
                "make sure to escape the function with tune.function() "
                "to prevent it from being evaluated as an expression.")
        self.env_creator = env_creator
        self.sample_batch_size = batch_steps * num_envs
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations
        self.preprocessing_enabled = True
        self.last_batch = None
        self._fake_sampler = _fake_sampler
        self._beholder = None

        self.env = _validate_env(env_creator(env_context))
        if isinstance(self.env, MultiAgentEnv) or \
                isinstance(self.env, BaseEnv):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing
            self.preprocessing_enabled = False

            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(env,
                                    dim=model_config.get("dim"),
                                    framestack=model_config.get("framestack"))
                if monitor_path:
                    env = _monitor(env, monitor_path)
                return env
        else:

            def wrap(env):
                if monitor_path:
                    env = _monitor(env, monitor_path)
                return env

        self.env = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        vector_index=vector_index, remote=remote_worker_envs)))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy, self.env)
        self.policies_to_train = policies_to_train or list(policy_dict.keys())
        if _has_tensorflow_graph(policy_dict):
            if (ray.is_initialized()
                    and ray.worker._mode() != ray.worker.LOCAL_MODE
                    and not ray.get_gpu_ids()):
                logger.info("Creating policy evaluation worker {}".format(
                    worker_index) +
                            " on CPU (please ignore any CUDA init errors)")
            if not tf:
                raise ImportError("Could not import tensorflow")
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf.Session(config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    self.policy_map, self.preprocessors = \
                        self._build_policy_map(policy_dict, policy_config)
        else:
            self.policy_map, self.preprocessors = self._build_policy_map(
                policy_dict, policy_config)

        self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent:
            if not ((isinstance(self.env, MultiAgentEnv)
                     or isinstance(self.env, ExternalMultiAgentEnv))
                    or isinstance(self.env, BaseEnv)):
                raise ValueError(
                    "Have multiple policies {}, but the env ".format(
                        self.policy_map) +
                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                    "ExternalMultiAgentEnv?".format(self.env))

        self.filters = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # Always use vector env for consistency even if num_envs = 1
        self.async_env = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)
        self.num_envs = num_envs

        if self.batch_mode == "truncate_episodes":
            unroll_length = batch_steps
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            unroll_length = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        self.io_context = IOContext(log_dir, policy_config, worker_index, self)
        self.reward_estimators = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.create(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.create(
                    self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        if sample_async:
            self.sampler = AsyncSampler(self.async_env,
                                        self.policy_map,
                                        policy_mapping_fn,
                                        self.preprocessors,
                                        self.filters,
                                        clip_rewards,
                                        unroll_length,
                                        self.callbacks,
                                        horizon=episode_horizon,
                                        pack=pack_episodes,
                                        tf_sess=self.tf_sess,
                                        clip_actions=clip_actions,
                                        blackhole_outputs="simulation"
                                        in input_evaluation,
                                        soft_horizon=soft_horizon)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(self.async_env,
                                       self.policy_map,
                                       policy_mapping_fn,
                                       self.preprocessors,
                                       self.filters,
                                       clip_rewards,
                                       unroll_length,
                                       self.callbacks,
                                       horizon=episode_horizon,
                                       pack=pack_episodes,
                                       tf_sess=self.tf_sess,
                                       clip_actions=clip_actions,
                                       soft_horizon=soft_horizon)

        self.input_reader = input_creator(self.io_context)
        assert isinstance(self.input_reader, InputReader), self.input_reader
        self.output_writer = output_creator(self.io_context)
        assert isinstance(self.output_writer, OutputWriter), self.output_writer

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))

    @override(EvaluatorInterface)
    def sample(self):
        """Evaluate the current policies and return a batch of experiences.

        Returns:
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        if self._fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.sample_batch_size))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while steps_so_far < self.sample_batch_size and len(
                batches) < max_batches:
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batch = batches[0].concat_samples(batches)

        if self.callbacks.get("on_sample_end"):
            self.callbacks["on_sample_end"]({"worker": self, "samples": batch})

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation if needed
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self._fake_sampler:
            self.last_batch = batch
        return batch

    @DeveloperAPI
    @ray.method(num_return_vals=2)
    def sample_with_count(self):
        """Same as sample() but returns the count as a separate future."""
        batch = self.sample()
        return batch, batch.count

    @override(EvaluatorInterface)
    def get_weights(self, policies=None):
        if policies is None:
            policies = self.policy_map.keys()
        return {
            pid: policy.get_weights()
            for pid, policy in self.policy_map.items() if pid in policies
        }

    @override(EvaluatorInterface)
    def set_weights(self, weights):
        for pid, w in weights.items():
            self.policy_map[pid].set_weights(w)

    @override(EvaluatorInterface)
    def compute_gradients(self, samples):
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out

    @override(EvaluatorInterface)
    def apply_gradients(self, grads):
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid:
                    self.policy_map[pid]._build_apply_gradients(builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    @override(EvaluatorInterface)
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            learn_on_batch_outputs = self.policy_map[
                DEFAULT_POLICY_ID].learn_on_batch(samples)

            if isinstance(learn_on_batch_outputs, tuple):
                info_out, beholder_arrays = learn_on_batch_outputs
            else:
                info_out, beholder_arrays = learn_on_batch_outputs, {}

            if self.policy_config["evaluation_config"]["beholder"]:
                with self.tf_sess.graph.as_default():
                    if self._beholder is None:
                        self._beholder = Beholder(self.io_context.log_dir)
                    self._beholder.update(self.tf_sess,
                                          arrays=beholder_arrays or None)

        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out

    @DeveloperAPI
    def get_metrics(self):
        """Returns a list of new RolloutMetric objects from evaluation."""

        out = self.sampler.get_metrics()
        for m in self.reward_estimators:
            out.extend(m.get_metrics())
        return out

    @DeveloperAPI
    def foreach_env(self, func):
        """Apply the given function to each underlying env instance."""

        envs = self.async_env.get_unwrapped()
        if not envs:
            return [func(self.async_env)]
        else:
            return [func(e) for e in envs]

    @DeveloperAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy for the specified id, or None.

        Arguments:
            policy_id (str): id of policy to return.
        """

        return self.policy_map.get(policy_id)

    @DeveloperAPI
    def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
        """Apply the given function to the specified policy."""

        return func(self.policy_map[policy_id])

    @DeveloperAPI
    def foreach_policy(self, func):
        """Apply the given function to each (policy, policy_id) tuple."""

        return [func(policy, pid) for pid, policy in self.policy_map.items()]

    @DeveloperAPI
    def foreach_trainable_policy(self, func):
        """Apply the given function to each (policy, policy_id) tuple.

        This only applies func to policies in `self.policies_to_train`."""

        return [
            func(policy, pid) for pid, policy in self.policy_map.items()
            if pid in self.policies_to_train
        ]

    @DeveloperAPI
    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    @DeveloperAPI
    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters

    @DeveloperAPI
    def save(self):
        filters = self.get_filters(flush_after=True)
        state = {
            pid: self.policy_map[pid].get_state()
            for pid in self.policy_map
        }
        return pickle.dumps({"filters": filters, "state": state})

    @DeveloperAPI
    def restore(self, objs):
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        for pid, state in objs["state"].items():
            self.policy_map[pid].set_state(state)

    @DeveloperAPI
    def set_global_vars(self, global_vars):
        self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_model(export_dir)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_checkpoint(export_dir,
                                                     filename_prefix)

    @DeveloperAPI
    def stop(self):
        self.async_env.stop()

    def _build_policy_map(self, policy_dict, policy_config):
        policy_map = {}
        preprocessors = {}
        for name, (cls, obs_space, act_space,
                   conf) in sorted(policy_dict.items()):
            logger.debug("Creating policy for {}".format(name))
            merged_conf = merge_dicts(policy_config, conf)
            if self.preprocessing_enabled:
                preprocessor = ModelCatalog.get_preprocessor_for_space(
                    obs_space, merged_conf.get("model"))
                preprocessors[name] = preprocessor
                obs_space = preprocessor.observation_space
            else:
                preprocessors[name] = NoPreprocessor(obs_space)
            if isinstance(obs_space, gym.spaces.Dict) or \
                    isinstance(obs_space, gym.spaces.Tuple):
                raise ValueError(
                    "Found raw Tuple|Dict space as input to policy. "
                    "Please preprocess these observations with a "
                    "Tuple|DictFlatteningPreprocessor.")
            if tf:
                with tf.variable_scope(name):
                    policy_map[name] = cls(obs_space, act_space, merged_conf)
            else:
                policy_map[name] = cls(obs_space, act_space, merged_conf)
        if self.worker_index == 0:
            logger.info("Built policy map: {}".format(policy_map))
            logger.info("Built preprocessor map: {}".format(preprocessors))
        return policy_map, preprocessors

    def __del__(self):
        if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
            self.sampler.shutdown = True
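
A sketch of the actor pattern this class is built for: a local worker holds the reference weights while remote replicas created via as_remote() collect batches in parallel. PGTFPolicy and CartPole are taken from the docstring example; the RLlib import path below is an assumption (it changes between versions):

import gym
import ray
from ray.rllib.agents.pg.pg_policy import PGTFPolicy  # path may differ by RLlib version

ray.init()

# Local worker. With batch_mode="truncate_episodes", each sample() call
# returns roughly batch_steps * num_envs transitions.
local_worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy=PGTFPolicy,
    batch_steps=100,
    batch_mode="truncate_episodes")

# Remote replicas; worker_index must be unique and non-zero for remote workers.
RemoteRolloutWorker = RolloutWorker.as_remote(num_cpus=1)
remote_workers = [
    RemoteRolloutWorker.remote(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy=PGTFPolicy,
        worker_index=i + 1)
    for i in range(2)
]

# Broadcast the local weights, then gather one batch from each replica.
weights = local_worker.get_weights()
for w in remote_workers:
    w.set_weights.remote(weights)
batches = ray.get([w.sample.remote() for w in remote_workers])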
Example #5
def train(hps, files):
    ngpus = hps.ngpus
    config = tf.ConfigProto()
    if ngpus > 1:
        try:
            import horovod.tensorflow as hvd
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(hvd.local_rank())
        except ImportError:
            hvd = None
            print("horovod not available, can only use 1 gpu")
            ngpus = 1

    # todo: organize
    current_res_w = hps.current_res_w
    res_multiplier = current_res_w // hps.start_res_w
    current_res_h = hps.start_res_h * res_multiplier

    tfrecord_input = any('.tfrecords' in fname for fname in files)
    # if using tfrecord, assume dataset is duplicated across multiple resolutions
    if tfrecord_input:
        num_files = 0
        for fname in [fname for fname in files if "res%d" % current_res_w in fname]:
            for record in tf.compat.v1.python_io.tf_record_iterator(fname):
                num_files += 1
    else:
        num_files = len(files)

    label_list = []
    total_classes = 0
    if hps.label_file:
        do_cgan = True
        label_list, total_classes = build_label_list_from_file(hps.label_file)
    else:
        do_cgan = False

    print("dataset has %d files" % num_files)
    try:
        batch_size = int(hps.batch_size)
        try_schedule = False
    except ValueError:
        try_schedule = True
    if try_schedule:
        batch_schedule = ast.literal_eval(hps.batch_size)
    else:
        batch_schedule = None

    #  always generate 32 sample images (should be feasible even at high resolutions, since no training is involved)
    #  will probably need to edit for > 128x128
    sample_batch = 32
    sample_latent_numpy = np.random.normal(0., 1., [sample_batch, 512])

    if do_cgan:
        examples_per_class = sample_batch // total_classes
        remainder = sample_batch % total_classes
        sample_cgan_latent_numpy = None
        for i in range(0, total_classes):
            class_vector = [0.] * total_classes
            class_vector[i] = 1.
            if sample_cgan_latent_numpy is None:
                sample_cgan_latent_numpy = [class_vector] * (examples_per_class + remainder)
            else:
                sample_cgan_latent_numpy += [class_vector] * examples_per_class
        sample_cgan_latent_numpy = np.array(sample_cgan_latent_numpy)

    use_beholder = hps.use_beholder
    if use_beholder:
        try:
            from tensorboard.plugins.beholder import Beholder
        except ImportError:
            print("Could not import beholder")
            use_beholder = False
    while current_res_w <= hps.res_w:
        if ngpus > 1:
            hvd.init()
        print("building graph")
        if batch_schedule is not None:
            batch_size = batch_schedule[current_res_w]
            print("res %d batch size is now %d" % (current_res_w, batch_size))
        gen_model, mapping_network, dis_model, sampling_model = \
            build_models(hps,
                         current_res_w,
                         use_ema_sampling=True,
                         num_classes=total_classes,
                         label_list=label_list if hps.conditional_type == "acgan" else None)
        with tf.name_scope("optimizers"):
            optimizer_d, optimizer_g, optimizer_m = build_optimizers(hps)
            if ngpus > 1:
                optimizer_d = hvd.DistributedOptimizer(optimizer_d)
                optimizer_g = hvd.DistributedOptimizer(optimizer_g)
                optimizer_m = hvd.DistributedOptimizer(optimizer_m)
        with tf.name_scope("data"):
            num_shards = None if ngpus == 1 else ngpus
            shard_index = None if ngpus == 1 else hvd.rank()
            it = build_data_iterator(hps, files, current_res_h, current_res_w, batch_size, label_list=label_list,
                                     num_shards=num_shards, shard_index=shard_index)
            next_batch = it.get_next()
            real_image = next_batch['data']

            fake_latent1 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent")
            fake_latent2 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent")

            fake_label_dict = None
            real_label_dict = None
            if do_cgan:
                fake_label_dict = {}
                real_label_dict = {}
                for label in label_list:
                    if hps.cond_uniform_fake:
                        distribution = np.ones_like([label.probabilities])
                    else:
                        distribution = np.log([label.probabilities])
                    fake_labels = tf.random.categorical(distribution, batch_size)
                    if label.multi_dim is False:
                        normalized_labels = (fake_labels - tf.reduce_min(fake_labels)) / \
                                            (tf.reduce_max(fake_labels) - tf.reduce_min(fake_labels))
                        fake_labels = tf.reshape(normalized_labels, [batch_size, 1])
                    else:
                        fake_labels = tf.reshape(tf.one_hot(fake_labels, label.num_classes),
                                                 [batch_size, label.num_classes])
                    fake_label_dict[label.name] = fake_labels
                    real_label_dict[label.name] = next_batch[label.name]
                    #fake_label_list.append(fake_labels)
                    # ideally would handle one dimensional labels differently, theory isn't well supported
                    # for that though (example: categorical values of short, medium, tall are on one dimension)
                    # real_labels = tf.reshape(tf.one_hot(tf.cast(next_batch[label.name], tf.int32), num_classes),
                    #                          [batch_size, num_classes])
                    #real_label_list.append(real_labels)
                fake_label_tensor = tf.concat([fake_label_dict[l] for l in fake_label_dict.keys()], axis=-1)
                real_label_tensor = tf.concat([real_label_dict[l] for l in real_label_dict.keys()], axis=-1)
            sample_latent = tf.constant(sample_latent_numpy, dtype=tf.float32, name="sample_latent")
            if do_cgan:
                sample_cgan_w = tf.constant(sample_cgan_latent_numpy, dtype=tf.float32, name="sample_cgan_latent")
            alpha_ph = tf.placeholder(shape=(), dtype=tf.float32, name="alpha")
            #  From Fig 2: "During a resolution transition,
            #  we interpolate between two resolutions of the real images"
            real_image = real_image*alpha_ph + \
                (1-alpha_ph)*upsample(downsample_nv(real_image),
                              method="nearest_neighbor")
            real_image = upsample(real_image, method='nearest_neighbor', factor=hps.res_w//current_res_w)
        if do_cgan:
            with tf.name_scope("gen_synthesis"):
                fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network,
                                       cgan_w=fake_label_tensor, random_crossover=True)
            real_logit, real_class_logits = dis_model(real_image, alpha_ph,
                                                      real_label_tensor if hps.conditional_type == "proj" else
                                                      None)
            fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph,
                                                      fake_label_tensor if hps.conditional_type == "proj" else
                                                      None)
        else:
            with tf.name_scope("gen_synthesis"):
                fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network,
                                       random_crossover=True)
            real_logit, real_class_logits = dis_model(real_image, alpha_ph)  # todo: make work with other labels
            fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph)

        with tf.name_scope("gen_sampling"):

            average_latent = tf.constant(np.random.normal(0., 1., [10000, 512]), dtype=tf.float32)
            low_psi = 0.20
            if hps.map_cond:
                class_vector = [0.] * total_classes
                class_vector[0] = 1. # one hot encoding
                average_w = tf.reduce_mean(mapping_network(tf.concat([average_latent,
                                                                      [class_vector]*10000], axis=-1)), axis=0)
                sample_latent_lowpsi = average_w + low_psi * \
                                       (mapping_network(tf.concat([sample_latent,
                                                                   [class_vector]*sample_batch], axis=-1)) - average_w)
            else:
                average_w = tf.reduce_mean(mapping_network(average_latent), axis=0)
                sample_latent_lowpsi = average_w + low_psi * (mapping_network(sample_latent) - average_w)
            average_w_batch = tf.tile(tf.reshape(average_w, [1, 512]), [sample_batch, 1])
            if do_cgan:
                sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi,
                                                   cgan_w=sample_cgan_w)
                sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network,
                                                 cgan_w=sample_cgan_w)
                sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch,
                                                 cgan_w=sample_cgan_w)
                sample_img_mode = tf.concat([sample_img_mode[0:2] + sample_img_mode[-3:-1]], axis=0)
            else:
                sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi)
                sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network)
                sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch)[0:4]
            sample_images = tf.concat([sample_img_lowpsi, sample_img_mode, sample_img_base], axis=0)
            sampling_model_init_ops = weight_following_ema_ops(average_model=sampling_model,
                                                               reference_model=gen_model)
            #sample_img_base = gen_model(sample_latent, alpha_ph, mapping_network)

        with tf.name_scope("loss"):
            loss_discriminator, loss_generator = hps.loss_fn(real_logit, fake_logit)
            if real_class_logits is not None:
                for label in label_list:
                    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=next_batch[label.name],
                                                                         logits=real_class_logits[label.name])
                    loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list))
                    tf.summary.scalar("label_loss_real", tf.reduce_mean(label_loss))
            if fake_class_logits is not None:
                for label in label_list:
                    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=fake_label_dict[label.name],
                                                                         logits=fake_class_logits[label.name])
                    loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list))
                    tf.summary.scalar("label_loss_fake", tf.reduce_mean(label_loss))

                    loss_generator += label_loss * hps.cond_weight * 1./(len(label_list))
            if hps.gp_fn:
                gp = hps.gp_fn(fake_image, real_image, dis_model, alpha_ph, real_label_dict,
                               conditional_type=hps.conditional_type)
                tf.summary.scalar("gradient_penalty", tf.reduce_mean(gp))
                loss_discriminator += hps.lambda_gp*gp
            dp = drift_penalty(real_logit)
            tf.summary.scalar("drift_penalty", tf.reduce_mean(dp))
            if hps.lambda_drift != 0.:
                loss_discriminator = tf.expand_dims(loss_discriminator, -1) + hps.lambda_drift * dp

            loss_discriminator_avg = tf.reduce_mean(loss_discriminator)
            loss_generator_avg = tf.reduce_mean(loss_generator)
        with tf.name_scope("train"):
            train_step_d = optimizer_d.minimize(loss_discriminator_avg, var_list=dis_model.trainable_variables)
            # todo: test this
            with tf.control_dependencies(weight_following_ema_ops(average_model=sampling_model,
                                                                  reference_model=gen_model)):
                train_step_g = [optimizer_g.minimize(loss_generator_avg, var_list=gen_model.trainable_variables)]
            if hps.do_mapping_network:
                train_step_g.append(
                    optimizer_m.minimize(loss_generator_avg, var_list=mapping_network.trainable_variables))
        with tf.name_scope("summary"):
            tf.summary.histogram("real_scores", real_logit)
            tf.summary.scalar("loss_discriminator", loss_discriminator_avg)
            tf.summary.scalar("loss_generator", loss_generator_avg)
            tf.summary.scalar("real_logit", tf.reduce_mean(real_logit))
            tf.summary.scalar("fake_logit", tf.reduce_mean(fake_logit))
            tf.summary.histogram("real_logit", real_logit)
            tf.summary.histogram("fake_logit", fake_logit)
            tf.summary.scalar("alpha", alpha_ph)
            merged = tf.summary.merge_all()
            image_summary_real = generate_image_summary(real_image, "real")
            image_summary_fake_avg = generate_image_summary(sample_images, "fake_avg")
            #image_summary_fake = generate_image_summary(sample_img_base, "fake")
        global_step = tf.train.get_or_create_global_step()
        if hps.profile:
            builder = tf.profiler.ProfileOptionBuilder
            opts = builder(builder.time_and_memory()).order_by('micros').build()

        with tf.contrib.tfprof.ProfileContext(hps.model_dir,
                                              trace_steps=[],
                                              dump_steps=[]) as pctx:
            with tf.Session(config=config) as sess:
                #if hps.tboard_debug:
                #    sess = tf_debug.TensorBoardDebugWrapperSession(sess, "localhost:6064")
                #elif hps.cli_debug:
                #    sess = tf_debug.LocalCLIDebugWrapperSession(sess)
                sess.run(tf.global_variables_initializer())
                sess.run(sampling_model_init_ops)
                alpha = 1.
                step = 0
                if os.path.exists(hps.save_paths.gen_model) and os.path.exists(hps.save_paths.dis_model):
                    if ngpus == 1 or hvd.rank() == 0:
                        print("restoring")
                        restore_models_and_optimizers(sess, gen_model, dis_model, mapping_network,
                                                      sampling_model,
                                                      optimizer_g, optimizer_d, optimizer_m, hps.save_paths)
                if os.path.exists(hps.save_paths.alpha) and os.path.exists(hps.save_paths.step):
                    alpha, step = restore_alpha_and_step(hps.save_paths)
                
                print("alpha")
                print(alpha)

                if alpha != 1.:
                    alpha_inc = 1. / (hps.epochs_per_res * (num_files / batch_size))
                else:
                    alpha_inc = 0.
                writer_path = \
                    os.path.join(hps.model_dir, "summary_%d" % current_res_w, "alpha_start_%d" % alpha)
                if use_beholder:
                    beholder = Beholder(writer_path)
                writer = tf.summary.FileWriter(writer_path, sess.graph)
                writer.add_summary(image_summary_real.eval(feed_dict={alpha_ph: alpha}), step)
                print("Starting res %d training" % current_res_w)
                t = trange(hps.epochs_per_res * num_files // batch_size, desc='Training')


                if ngpus > 1:
                    sess.run(hvd.broadcast_global_variables(0))
                for phase_step in t:
                    try:
                        for i in range(0, hps.ncritic):
                            if hps.profile:
                                pctx.trace_next_step()
                                pctx.dump_next_step()
                            if step % 5 == 0:
                                summary, ld, _ = sess.run([merged,
                                                           loss_discriminator_avg,
                                                           train_step_d if not hps.no_train else tf.no_op()],
                                                          feed_dict={alpha_ph: alpha})
                                writer.add_summary(summary, step)
                            else:

                                ld, _ = sess.run([loss_discriminator_avg,
                                                  train_step_d if not hps.no_train else tf.no_op()],
                                                  feed_dict={alpha_ph: alpha})
                            if hps.profile:
                                pctx.profiler.profile_operations(options=opts)
                        if hps.profile:
                            pctx.trace_next_step()
                            pctx.dump_next_step()
                        lg, _ = sess.run([loss_generator_avg,
                                          train_step_g if not hps.no_train else tf.no_op()],
                                         feed_dict={alpha_ph: alpha})
                        if hps.profile:
                            pctx.profiler.profile_operations(options=opts)
                        alpha = min(alpha+alpha_inc, 1.)

                        #print("step: %d" % step)
                        #print("loss_d: %f" % ld)
                        #print("loss_g: %f\n" % lg)
                        t.set_description('Overall step %d, loss d %f, loss g %f' % (step+1, ld, lg))
                        if use_beholder:
                            try:
                                beholder.update(session=sess)
                            except Exception as e:
                                print("Beholder failed: " + str(e))
                                use_beholder = False

                        if phase_step < 5 or (phase_step < 500 and phase_step % 10 == 0) or (step % 1000 == 0):
                            writer.add_summary(image_summary_fake_avg.eval(
                                feed_dict={alpha_ph: alpha}), step)
                            #writer.add_summary(image_summary_fake.eval(
                            #    feed_dict={alpha_ph: alpha}), step)
                        if hps.steps_per_save is not None and step % hps.steps_per_save == 0 and (ngpus == 1 or hvd.rank() == 0):
                            save_models_and_optimizers(sess,
                                                       gen_model, dis_model, mapping_network,
                                                       sampling_model,
                                                       optimizer_g, optimizer_d, optimizer_m,
                                                       hps.save_paths)
                            save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths)
                        step += 1
                    except tf.errors.OutOfRangeError:
                        break
                assert (abs(alpha - 1.) < .1), "Alpha should be close to 1., not %f" % alpha  # alpha close to 1. (dataset divisible by batch_size for small sets)
                if ngpus == 1 or hvd.rank() == 0:
                    print(1. if alpha_inc != 0. else 0.)
                    save_models_and_optimizers(sess,
                                               gen_model, dis_model, mapping_network, sampling_model,
                                               optimizer_g, optimizer_d, optimizer_m,
                                               hps.save_paths)
                    backup_model_for_this_phase(hps.save_paths, writer_path)
                save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths)
                #  Will generate Out of range errors, see if it's easy to save a tensor so get_next() doesn't need
                #  a new value
                #writer.add_summary(image_summary_real.eval(feed_dict={alpha_ph: 1.}), step)
                #writer.add_summary(image_summary_fake.eval(feed_dict={alpha_ph: 1.}), step)

        tf.reset_default_graph()
        if alpha_inc == 0:
            current_res_h *= 2
            current_res_w *= 2
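
The batch_size hyperparameter handled above may be either a plain integer or a string encoding a per-resolution schedule; a small illustration of that parsing logic (the schedule values are illustrative, not from the source):

import ast
from types import SimpleNamespace

# Plain integer: int() succeeds, so no schedule is used.
hps = SimpleNamespace(batch_size="32")
batch_size = int(hps.batch_size)

# Per-resolution schedule: int() raises ValueError, so the string is parsed
# with ast.literal_eval and indexed by current_res_w as the resolution doubles.
hps = SimpleNamespace(batch_size="{4: 128, 8: 64, 16: 32, 32: 16}")
batch_schedule = ast.literal_eval(hps.batch_size)
batch_size = batch_schedule[16]  # batch size used while training at width 16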
Example #6
class C51Agent():
    class Model():
        def __init__(self, session, num_actions, train_net):
            self.sess = session

            # Input
            self.x = tf.placeholder(name="state",
                                    dtype=tf.uint8,
                                    shape=(None, params.STATE_DIMENSIONS[0],
                                           params.STATE_DIMENSIONS[1],
                                           params.HISTORY_LEN))

            self.normalized_x = tf.cast(self.x, dtype=tf.float32) / 255.0

            with tf.variable_scope("common"):
                # Convolutional Layers
                self.conv_outputs = []
                for CONV_LAYER_SPEC in params.CONVOLUTIONAL_LAYERS_SPEC:
                    self.conv_outputs.append(
                        tf.layers.conv2d(
                            name="conv_layer_" +
                            str(len(self.conv_outputs) + 1),
                            inputs=self.normalized_x if len(self.conv_outputs)
                            == 0 else self.conv_outputs[-1],
                            filters=CONV_LAYER_SPEC["filters"],
                            kernel_size=CONV_LAYER_SPEC["kernel_size"],
                            strides=CONV_LAYER_SPEC["strides"],
                            activation=tf.nn.relu))

                # Flatten
                self.flattened_conv_output = tf.layers.flatten(
                    name="conv_output_flattener", inputs=self.conv_outputs[-1])

                # Hidden Layer
                self.dense_outputs = []
                for DENSE_LAYER_SPEC in params.DENSE_LAYERS_SPEC:
                    self.dense_outputs.append(
                        tf.layers.dense(name="dense_layer_" +
                                        str(len(self.dense_outputs) + 1),
                                        inputs=self.flattened_conv_output
                                        if len(self.dense_outputs) == 0 else
                                        self.dense_outputs[-1],
                                        units=DENSE_LAYER_SPEC,
                                        activation=tf.nn.relu))

                # State-Action-Value Distributions (as a flattened vector)
                self.flattened_q_dist = tf.layers.dense(
                    name="flattened_action_value_dist_logits",
                    inputs=self.dense_outputs[-1],
                    units=num_actions * params.NB_ATOMS)

                # Unflatten
                self.q_dist_logits = tf.reshape(
                    self.flattened_q_dist, [-1, num_actions, params.NB_ATOMS],
                    name="reshape_q_dist_logits")

                # Softmax State-Action-Value Distributions (per action)
                self.q_dist = tf.nn.softmax(self.q_dist_logits,
                                            name="action_value_dist",
                                            axis=-1)

                # Multiply bin probabilities by value
                self.delta_z = (params.V_MAX -
                                params.V_MIN) / (params.NB_ATOMS - 1)
                self.Z = tf.range(start=params.V_MIN,
                                  limit=params.V_MAX + self.delta_z,
                                  delta=self.delta_z)
                self.post_mul = self.q_dist * tf.reshape(
                    self.Z, [1, 1, params.NB_ATOMS])

                # Take sum to get the expected state-action values for each action
                self.actions = tf.reduce_sum(self.post_mul, axis=2)

                self.batch_size_range = tf.range(start=0,
                                                 limit=tf.shape(self.x)[0])

            if not train_net:
                self.targ_q_net_max = tf.summary.scalar(
                    "targ_q_net_max", tf.reduce_max(self.actions))
                self.targ_q_net_mean = tf.summary.scalar(
                    "targ_q_net_mean", tf.reduce_mean(self.actions))
                self.targ_q_net_min = tf.summary.scalar(
                    "targ_q_net_min", tf.reduce_min(self.actions))

                # Find argmax action given expected state-action values at next state
                self.argmax_action = tf.argmax(self.actions,
                                               axis=-1,
                                               output_type=tf.int32)

                # Get its corresponding distribution (this is the target distribution)
                self.argmax_action_distribution = tf.gather_nd(
                    self.q_dist,
                    tf.stack((self.batch_size_range, self.argmax_action),
                             axis=1))  # Axis = 1 => [N, 2]

                self.mean_argmax_next_state_value = tf.summary.scalar(
                    "mean_argmax_q_target",
                    tf.reduce_mean(self.Z * self.argmax_action_distribution))

                # Placeholder for reward
                self.r = tf.placeholder(name="reward",
                                        dtype=tf.float32,
                                        shape=(None, ))
                self.t = tf.placeholder(name="terminal",
                                        dtype=tf.uint8,
                                        shape=(None, ))

                # Apply the Bellman operator to every atom of the support:
                # Tz = r + gamma * z, clipped to [V_MIN, V_MAX] (gamma is
                # hard-coded to 0.99 here).
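                # Note: add() below stores `not terminal`, so this placeholder
                # receives 1 for non-terminal transitions and 0 for terminal
                # ones, which zeroes out the bootstrap term when an episode ends.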
                self.Tz = tf.clip_by_value(
                    tf.reshape(self.r, [-1, 1]) + 0.99 *
                    tf.cast(tf.reshape(self.t, [-1, 1]), tf.float32) * self.Z,
                    clip_value_min=params.V_MIN,
                    clip_value_max=params.V_MAX)

                # Compute bin number (will be floating point).
                self.b = (self.Tz - params.V_MIN) / self.delta_z

                # Lower and Upper Bins.
                self.l = tf.floor(self.b)
                self.u = tf.ceil(self.b)

                # Distribute each atom's probability between the two nearest
                # bins according to its distance from them: e.g. if b = 0.3,
                # the lower bin (0) receives p(Z = z_b) * 0.7 and the upper
                # bin (1) receives p(Z = z_b) * 0.3.
                self.indexable_l = tf.stack(
                    (
                        tf.reshape(self.batch_size_range, [-1, 1]) * tf.ones(
                            (1, params.NB_ATOMS), dtype=tf.int32),
                        # BATCH_SIZE_RANGE x NB_ATOMS [[0, ...], [1, ...], ...]
                        tf.cast(self.l, dtype=tf.int32)),
                    axis=-1)
                self.m_l_vals = self.argmax_action_distribution * (self.u -
                                                                   self.b)
                self.m_l = tf.scatter_nd(tf.reshape(self.indexable_l, [-1, 2]),
                                         tf.reshape(self.m_l_vals, [-1]),
                                         tf.shape(self.l))

                # Add weight to the upper bin based on distance from the lower
                # bin to the approximate bin index b.
                self.indexable_u = tf.stack(
                    (
                        tf.reshape(self.batch_size_range, [-1, 1]) * tf.ones(
                            (1, params.NB_ATOMS), dtype=tf.int32),
                        # BATCH_SIZE_RANGE x NB_ATOMS [[0, ...], [1, ...], ...]
                        tf.cast(self.u, dtype=tf.int32)),
                    axis=-1)
                self.m_u_vals = self.argmax_action_distribution * (self.b -
                                                                   self.l)
                self.m_u = tf.scatter_nd(tf.reshape(self.indexable_u, [-1, 2]),
                                         tf.reshape(self.m_u_vals, [-1]),
                                         tf.shape(self.u))

                # Add the contributions of the lower and upper parts and stop
                # the gradient so the target network is not updated.
                self.m = tf.stop_gradient(tf.squeeze(self.m_l + self.m_u))
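                # `m` is the projected target distribution; update() feeds it
                # to the training network via `m_placeholder`.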

                self.weighted_m = tf.clip_by_value(self.m * self.Z,
                                                   clip_value_min=params.V_MIN,
                                                   clip_value_max=params.V_MAX)

                self.weighted_m_mean = tf.summary.scalar(
                    "mean_q_target", tf.reduce_mean(self.weighted_m))

                self.targ_dist = tf.summary.histogram("target_distribution",
                                                      self.weighted_m)

                self.targn_summary = tf.summary.merge([
                    self.targ_dist, self.weighted_m_mean, self.targ_q_net_max,
                    self.targ_q_net_mean, self.targ_q_net_min,
                    self.mean_argmax_next_state_value
                ])
            else:
                self.trn_q_net_max = tf.summary.scalar(
                    "trn_q_net_max", tf.reduce_max(self.actions))
                self.trn_q_net_mean = tf.summary.scalar(
                    "trn_q_net_mean", tf.reduce_mean(self.actions))
                self.trn_q_net_min = tf.summary.scalar(
                    "trn_q_net_min", tf.reduce_min(self.actions))

                # Placeholder for the action that was actually taken.
                self.action_placeholder = tf.placeholder(name="action",
                                                         dtype=tf.int32,
                                                         shape=[
                                                             None,
                                                         ])

                # Compute Q-Dist. for the action.
                self.action_q_dist = tf.gather_nd(
                    self.q_dist,
                    tf.stack((self.batch_size_range, self.action_placeholder),
                             axis=1))

                self.weighted_q_dist = tf.clip_by_value(
                    self.action_q_dist * self.Z,
                    clip_value_min=params.V_MIN,
                    clip_value_max=params.V_MAX)

                tnd_summary = tf.summary.histogram("training_net_distribution",
                                                   self.weighted_q_dist)

                tnd_mean_summary = tf.summary.scalar(
                    "training_net_distribution_mean",
                    tf.reduce_mean(self.weighted_q_dist))

                # Get target distribution.
                self.m_placeholder = tf.placeholder(dtype=tf.float32,
                                                    shape=(None,
                                                           params.NB_ATOMS),
                                                    name="m_placeholder")
                self.loss_sum = -tf.reduce_sum(
                    self.m_placeholder * tf.log(self.action_q_dist + 1e-5),
                    axis=-1)

                self.loss = tf.reduce_mean(self.loss_sum)

                l_summary = tf.summary.scalar("loss", self.loss)

                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=params.LEARNING_RATE,
                    epsilon=params.EPSILON_ADAM)
                gradients, variables = zip(
                    *self.optimizer.compute_gradients(self.loss))
                grad_norm_summary = tf.summary.histogram(
                    "grad_norm", tf.global_norm(gradients))
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      params.GRAD_NORM_CLIP)
                self.train_step = self.optimizer.apply_gradients(
                    zip(gradients, variables))
                self.trnn_summary = tf.summary.merge([
                    tnd_mean_summary, tnd_summary, l_summary,
                    grad_norm_summary, self.trn_q_net_max, self.trn_q_net_mean,
                    self.trn_q_net_min
                ])

    def __init__(self):
        self.num_actions = len(params.GLOBAL_MANAGER.actions)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.experience_replay = deque(maxlen=params.EXPERIENCE_REPLAY_SIZE)
        with tf.variable_scope("train_net"):
            self.train_net = self.Model(self.sess,
                                        num_actions=self.num_actions,
                                        train_net=True)
        with tf.variable_scope("target_net"):
            self.target_net = self.Model(self.sess,
                                         num_actions=self.num_actions,
                                         train_net=False)
        self.summary = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter("TensorBoardDir")
        init = tf.global_variables_initializer()
        self.sess.run(init)

        main_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='train_net/common')
        target_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             scope='target_net/common')

        # The variables are sorted by name and the names are asserted to
        # match, so the train/target pairs line up regardless of the order
        # get_collection returns them in.

        assign_ops = []
        for main_var, target_var in zip(
                sorted(main_variables, key=lambda x: x.name),
                sorted(target_variables, key=lambda x: x.name)):
            assert (main_var.name.replace("train_net",
                                          "") == target_var.name.replace(
                                              "target_net", ""))
            assign_ops.append(tf.assign(target_var, main_var))

        self.copy_operation = tf.group(*assign_ops)

        self.saver = tf.train.Saver(
            max_to_keep=params.MAX_MODELS_TO_KEEP,
            keep_checkpoint_every_n_hours=params.MIN_MODELS_EVERY_N_HOURS)
        # self.profiler = tf.profiler.Profiler(self.sess.graph)

        self.beholder = Beholder("./TensorBoardDir")

    def act(self, x):
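        # Epsilon-greedy exploration: the threshold starts at EPSILON_START and
        # decreases linearly by (1 - EPSILON_END) over EPSILON_FINAL_STEP
        # timesteps (it reaches EPSILON_END when EPSILON_START is 1.0 and is
        # not clamped afterwards).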
        if np.random.random() < params.EPSILON_START - \
                (params.GLOBAL_MANAGER.timestep / params.EPSILON_FINAL_STEP) * \
                (1 - params.EPSILON_END):
            return np.random.randint(0, self.num_actions)
        else:
            actions = self.sess.run(fetches=self.train_net.actions,
                                    feed_dict={self.train_net.x: x})
            return np.argmax(actions)

    def add(self, x, a, r, x_p, t):
        assert (np.issubdtype(x.dtype, np.integer))
        self.experience_replay.appendleft([x, a, r, x_p, not t])

    def update(self, x, a, r, x_p, t):
        self.add(x, a, r, x_p, t)

        total_loss = 0
        batch_data = random.sample(self.experience_replay, 32)
        batch_x = np.array([i[0] for i in batch_data])
        batch_a = [i[1] for i in batch_data]
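        # Build the next-state stacks: drop the oldest frame from the stored
        # stack and append the new frame max-pooled with the last stored frame
        # (channel index 3, assuming HISTORY_LEN == 4), a common trick for
        # suppressing Atari frame flicker.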
        batch_x_p = np.array([
            np.array(
                np.dstack((i[0][:, :, 1:], np.maximum(i[3], i[0][:, :, 3]))))
            for i in batch_data
        ])
        batch_r = [i[2] for i in batch_data]
        batch_t = [i[4] for i in batch_data]

        targn_summary, m, Tz, b, u, l, indexable_u, indexable_l, m_u_vals, m_l_vals, m_u, m_l = self.sess.run(
            [
                self.target_net.targn_summary, self.target_net.m,
                self.target_net.Tz, self.target_net.b, self.target_net.u,
                self.target_net.l, self.target_net.indexable_u,
                self.target_net.indexable_l, self.target_net.m_u_vals,
                self.target_net.m_l_vals, self.target_net.m_u,
                self.target_net.m_l
            ],
            feed_dict={
                self.target_net.x: batch_x_p,
                self.target_net.r: batch_r,
                self.target_net.t: batch_t
            })

        trnn_summary, loss, _ = self.sess.run(
            [
                self.train_net.trnn_summary, self.train_net.loss,
                self.train_net.train_step
            ],
            feed_dict={
                self.train_net.x: batch_x,
                self.train_net.action_placeholder: batch_a,
                self.train_net.m_placeholder: m
            })

        self.writer.add_summary(targn_summary,
                                params.GLOBAL_MANAGER.num_updates)
        self.writer.add_summary(trnn_summary,
                                params.GLOBAL_MANAGER.num_updates)

        total_loss += loss

        self.beholder.update(self.sess,
                             frame=batch_x[0],
                             arrays=[
                                 m, Tz, b, u, l, indexable_u, indexable_l,
                                 m_u_vals, m_l_vals, m_u, m_l
                             ])

        if params.GLOBAL_MANAGER.num_updates > 0 and \
                params.GLOBAL_MANAGER.num_updates % params.COPY_TARGET_FREQ == 0:
            self.sess.run(self.copy_operation)
            print("Copied to target. Current Loss: ", total_loss)

        if params.GLOBAL_MANAGER.num_updates > 0 and \
                params.GLOBAL_MANAGER.num_updates % params.MODEL_SAVE_FREQ == 0:
            self.saver.save(self.sess,
                            "Models/model",
                            global_step=params.GLOBAL_MANAGER.num_updates,
                            write_meta_graph=(params.GLOBAL_MANAGER.num_updates
                                              <= params.MODEL_SAVE_FREQ))
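
The scatter_nd block above is the categorical (C51) projection: each atom of the Bellman-updated support r + gamma * z is split between its two nearest bins on the fixed support Z. A minimal NumPy sketch of the same projection for a single transition, with V_MIN/V_MAX/NB_ATOMS and gamma chosen here as stand-ins for the values in params, might look like this:

import numpy as np

V_MIN, V_MAX, NB_ATOMS = -10.0, 10.0, 51   # assumed values, not taken from params
GAMMA = 0.99

def project_distribution(p_next, reward, not_terminal):
    """Project r + GAMMA * Z onto the fixed support Z for one transition."""
    delta_z = (V_MAX - V_MIN) / (NB_ATOMS - 1)
    z = np.linspace(V_MIN, V_MAX, NB_ATOMS)
    tz = np.clip(reward + GAMMA * not_terminal * z, V_MIN, V_MAX)
    b = (tz - V_MIN) / delta_z                  # fractional bin index
    l, u = np.floor(b).astype(int), np.ceil(b).astype(int)
    m = np.zeros(NB_ATOMS)
    # Split each atom's probability mass between its lower and upper neighbours.
    np.add.at(m, l, p_next * (u - b))
    np.add.at(m, u, p_next * (b - l))
    # Note: when b lands exactly on an atom (l == u) both weights are zero,
    # mirroring the corner case in the graph code above.
    return m
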
Example #7
        # check accuracy every 10 steps
        if step_count % 10 == 0:
            # get a batch of test data
            batch_test_data, batch_test_labels = dataUtils.getCIFAR10Batch(
                is_eval=True, batch_size=100)

            # do eval step to test accuracy
            test_accuracy, test_loss, summary = sess.run(
                [accuracy, loss, summary_tensor],
                feed_dict={
                    input_placeholder: batch_test_data,
                    label_placeholder: batch_test_labels
                })

            # write data to tensorboard
            test_summary_writer.add_summary(summary, step_count)

            print("Step Count:{}".format(step_count))
            print("Training accuracy: {:.6f} loss: {:.6f}".format(
                training_accuracy, training_loss))
            print("Test accuracy: {:.6f} loss: {:.6f}".format(
                test_accuracy, test_loss))
            beholder.update(session=sess)

        if step_count % 100 == 0:
            save_path = saver.save(sess, "model/model.ckpt")

        # stop training after 10,000 steps
        if step_count > 10000:
            break
Example #8
def train_model(cfg: EmbeddingCfg, model: models.Model, loss: tf.Tensor):
    global_step = tf.train.get_or_create_global_step()
    step = tf.assign_add(global_step, 1)

    learning_rate = tf.train.exponential_decay(
        learning_rate=cfg.init_learning_rate,
        global_step=global_step,
        decay_steps=cfg.lr_decay_steps,
        decay_rate=cfg.lr_decay_rate,
    )
    optimizer = tf.train.AdamOptimizer(learning_rate)
    grads_and_vars = optimizer.compute_gradients(loss)

    train_op = optimizer.apply_gradients(
        [(tf.clip_by_norm(grad, cfg.grad_clipping), var)
         for grad, var in grads_and_vars],
        global_step=global_step
    )

    saver = tf.train.Saver(max_to_keep=10)
    init_op = tf.global_variables_initializer()

    # Basic training-only summaries
    summaries = [
        tf.summary.scalar("learning_rate", learning_rate),
        tf.summary.scalar("loss", loss),
    ]

    # Extended summaries: per-variable and per-gradient statistics
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        name = var.name.split(":")[0]
        summaries.extend(tensor_default_summaries(name, var))

    for grad, var in grads_and_vars:
        if grad is not None:
            name = var.name.split(":")[0]
            summaries.extend(tensor_default_summaries(name + "/grad", grad))

    merged_summary = tf.summary.merge(summaries)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    beholder = Beholder(cfg.logdir)
    with tf.Session(config=config) as sess:
        if cfg.debug:
            if cfg.tensorboard_debug is None:
                sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            else:
                sess = tf_debug.TensorBoardDebugWrapperSession(
                    sess, cfg.tensorboard_debug
                )
        summary_writer = tf.summary.FileWriter(
            os.path.join(cfg.logdir), sess.graph
        )
        K.set_session(sess)
        last_check = tf.train.latest_checkpoint(cfg.logdir)
        if last_check is None:
            logger.info("No checkpoint found, initializing variables from scratch")
            sess.run(init_op)
        else:
            logger.info(f"Restoring checkpoint {last_check}")
            saver.restore(sess=sess, save_path=last_check)

        gs = sess.run(global_step)
        pbar = trange(gs, cfg.train_steps)
        for i in pbar:
            #  Train hook
            opts = {}
            if cfg.run_trace_every > 0 and i % cfg.run_trace_every == 0:
                opts['options'] = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE
                )
                opts['run_metadata'] = tf.RunMetadata()

            _, _, curr_loss, summary = sess.run([
                step,
                train_op,
                loss,
                merged_summary,
            ], **opts)
            summary_writer.add_summary(summary, i)
            pbar.set_postfix(loss=curr_loss)

            if cfg.run_trace_every > 0 and i % cfg.run_trace_every == 0:
                fetched_timeline = timeline.Timeline(
                    opts['run_metadata'].step_stats
                )
                chrome_trace = fetched_timeline.generate_chrome_trace_format(
                    show_memory=True
                )
                with open(
                    os.path.join(cfg.logdir, f'timeline_{i:05}.json'), 'w'
                ) as f:
                    f.write(chrome_trace)
                summary_writer.add_run_metadata(
                    opts['run_metadata'], f"step_{i:05}", global_step=i
                )
                logger.info(
                    f"Saved trace metadata both to timeline_{i:05}.json and step_{i:05} in tensorboard"
                )

            beholder.update(session=sess)

            #  Save hook
            if i % cfg.save_every == 0:
                saver.save(
                    sess=sess,
                    save_path=os.path.join(cfg.logdir, 'model.ckpt'),
                    global_step=global_step
                )
                logger.info("Saved new model checkpoint")
        p = os.path.join(cfg.logdir, "full-model.save")
        model.save(p, overwrite=True, include_optimizer=False)
        logger.info(f"Finished training; saved model to {p}")
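
The helper tensor_default_summaries used above is referenced but not defined in this snippet. A plausible sketch, assuming it simply emits the usual per-tensor statistics (the real helper may differ), is:

import tensorflow as tf

def tensor_default_summaries(name, tensor):
    """Hypothetical helper: scalar statistics plus a histogram for `tensor`."""
    mean = tf.reduce_mean(tensor)
    stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean)))
    return [
        tf.summary.scalar(name + "/mean", mean),
        tf.summary.scalar(name + "/stddev", stddev),
        tf.summary.scalar(name + "/max", tf.reduce_max(tensor)),
        tf.summary.scalar(name + "/min", tf.reduce_min(tensor)),
        tf.summary.histogram(name + "/histogram", tensor),
    ]
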
Example #9
for i in range(5):
    seed = random_sequence_from_textfile(path, maxlen)
    print('-- STARTING RUN NUMBER %s --' % i)
    m.fit(X,
          Y,
          validation_set=0.2,
          batch_size=64,
          n_epoch=1,
          run_id=run_name,
          snapshot_epoch=True)

    print('-- TESTING WITH TEMPERATURE OF %s --' % temp)
    gentext = m.generate(6000, temperature=temp, seq_seed=seed)
    print('-- GENERATION COMPLETED --')

    # Add it to the summary placeholder
    _sess = m.session
    _graph = _sess.graph
    _logdir = m.trainer.summ_writer.get_logdir()
    _step = int(m.trainer.global_step.eval(session=_sess))
    _writer = tf.summary.FileWriter(_logdir, graph=_graph)
    output_summary = _sess.run(summary_op,
                               feed_dict={valid_placeholder: [gentext]})
    _writer.add_summary(output_summary, global_step=_step)
    _writer.flush()
    _writer.close()
    m.trainer.saver.save(_sess, './run/' + model_name + '.ckpt', _step)
    beholder.update(_sess)

m.save(model_name)
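
This example assumes summary_op, valid_placeholder, temp, beholder and the TFLearn model m are created earlier. A minimal sketch of the text-summary and Beholder plumbing it relies on (the names match the snippet; everything else is an assumption) could be:

import tensorflow as tf
from tensorboard.plugins.beholder import Beholder

# Placeholder fed with the generated text, plus a text summary for TensorBoard.
valid_placeholder = tf.placeholder(tf.string, shape=[None], name="generated_text")
summary_op = tf.summary.text("generated_text", valid_placeholder)

beholder = Beholder("./run")   # assumed log directory
temp = 1.0                     # assumed sampling temperature for m.generate
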
Example #10
def train_nn(sess, epochs, batch_size, get_batches_fn, train_op,
             cross_entropy_loss, input_image, correct_label, learning_rate,
             is_training, mean_iou, merged_summary, log_path, save_dir):
    """
    Train the neural network and print out the loss during training.
    :param sess: TF Session
    :param epochs: Number of epochs
    :param batch_size: Batch size
    :param get_batches_fn: Function to get batches of training data.  Call using get_batches_fn(batch_size)
    :param train_op: TF Operation to train the neural network
    :param cross_entropy_loss: TF Tensor for the amount of loss
    :param input_image: TF Placeholder for input images
    :param correct_label: TF Placeholder for label images
    :param learning_rate: TF Placeholder for learning rate
    :param is_training: TF Placeholder for the training/inference flag
    :param mean_iou: TF Tensor for the mean intersection-over-union metric
    :param merged_summary: Merged TF summary tensor
    :param log_path: Directory where TensorBoard logs are written
    :param save_dir: Path used when saving model checkpoints
    """

    # create a TensorBoard summary writer at log_path and save the graph there
    writer = tf.summary.FileWriter(log_path, graph=sess.graph)
    beholder = Beholder(log_path)

    saver = tf.train.Saver()

    images = []
    labels = []

    # Train the model
    print("Training")
    for epoch in range(epochs):
        # train on all of the training data each epoch, processing it in
        # batches of batch_size examples
        batch = 0
        for images, labels in get_batches_fn(batch_size):
            summary, _, loss = sess.run(
                [merged_summary, train_op, cross_entropy_loss],
                feed_dict={
                    input_image: images,
                    correct_label: labels,
                    is_training: True,
                    learning_rate: 0.0002
                })
            batch += 1

            # add summaries to tensorboard
            writer.add_summary(summary, (epoch + 1) * batch)

            print('Epoch {}, batch: {}, loss: {} '.format(
                epoch + 1, batch, loss))

        # check the accuracy of the model against the validation set
        # validation_accuracy = sess.run(accuracy, feed_dict={x: x_valid_reshape, y:one_hot_valid})
        iou = sess.run([mean_iou],
                       feed_dict={
                           input_image: images,
                           correct_label: labels,
                           is_training: False
                       })
        iou_sum = iou[0][0]

        # print out the model's accuracy
        # (to print on the same line, prefix the string with \r)
        sys.stdout.write("EPOCH {}. IOU = {:.3f}\n".format(epoch + 1, iou_sum))

        beholder.update(session=sess)
        saver.save(sess, save_dir, epoch)

    saver_path = saver.save(sess, save_dir)
    print("Model saved in path: %s" % saver_path)

    writer.close()
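
The mean_iou tensor handed to train_nn is typically built with tf.metrics.mean_iou, which also returns an update op and keeps its counts in local variables. A hedged sketch of that wiring (the tensor names in the original project may differ) is:

import tensorflow as tf

def build_mean_iou(logits, correct_label, num_classes):
    """Hypothetical metric setup for the mean_iou consumed by train_nn."""
    predictions = tf.argmax(logits, axis=-1)
    labels = tf.argmax(correct_label, axis=-1)
    mean_iou, iou_update_op = tf.metrics.mean_iou(labels, predictions,
                                                  num_classes)
    return mean_iou, iou_update_op

# tf.metrics.* accumulate state in local variables, so run
# tf.local_variables_initializer() once, and make sure the update op is
# executed (fetched alongside mean_iou, or folded into it upstream) before
# reading the metric; note that train_nn above fetches only mean_iou itself.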