class BeholderHook(tf.train.SessionRunHook):
    """SessionRunHook implementation that runs Beholder every step.

    Convenient when using tf.train.MonitoredSession:

    ```python
    beholder_hook = BeholderHook(LOG_DIRECTORY)
    with MonitoredSession(..., hooks=[beholder_hook]) as sess:
        sess.run(train_op)
    ```
    """

    def __init__(self, logdir, list_of_np_ndarrays=None, frame=None):
        """Creates new Hook instance.

        Args:
            logdir: Directory where Beholder should write data.
            list_of_np_ndarrays: Optional fetches requested alongside every
                hooked ``session.run`` call. They are handed to
                ``tf.train.SessionRunArgs``, so despite the name they
                presumably must be graph elements (e.g. tensors) whose
                evaluated values are numpy arrays. The fetched results are
                passed to ``Beholder.update`` as ``arrays``. If None,
                nothing extra is fetched and ``arrays`` is None.
            frame: Optional value forwarded unchanged to
                ``Beholder.update`` as ``frame`` on every step.
        """
        self._logdir = logdir
        self.beholder = None
        self.list_of_np_ndarrays = list_of_np_ndarrays
        self.frame = frame

    def begin(self):
        # Build the Beholder when the hooked session begins.
        self.beholder = Beholder(self._logdir)

    def before_run(self, run_context):
        # Returning None requests no extra fetches; run_values.results in
        # after_run is then None, which Beholder treats as "no arrays".
        if self.list_of_np_ndarrays is None:
            return None
        return tf.train.SessionRunArgs(fetches=self.list_of_np_ndarrays)

    def after_run(self, run_context, run_values):
        self.beholder.update(
            session=run_context.session,
            arrays=run_values.results,
            frame=self.frame)
class BeholderCallback(tf.keras.callbacks.Callback):
    """Keras callback that pushes a tensor's value to Beholder each epoch.

    Args:
        tensor: Graph element evaluated with ``sess.run`` at the end of every
            epoch; the result is passed to ``Beholder.update`` as ``frame``.
            Depending on the tensor, evaluating it may require a feed_dict,
            which this callback does not supply.
        logdir: TensorBoard log directory that Beholder writes to.
        sess: TensorFlow session used for evaluation and for
            ``Beholder.update``; defaults to ``K.get_session()``.
    """

    def __init__(self, tensor, logdir, sess=None):
        # The Keras base initialiser sets up base callback state (e.g. the
        # ``model`` attribute); it must run for every Callback subclass.
        super(BeholderCallback, self).__init__()
        self.visualizer = Beholder(logdir=logdir)
        self.sess = sess if sess is not None else K.get_session()
        self.tensor = tensor

    def on_epoch_end(self, epoch, logs=None):
        super(BeholderCallback, self).on_epoch_end(epoch, logs)
        # depending on the tensor, this might require a feed_dict
        frame = self.sess.run(self.tensor)
        self.visualizer.update(session=self.sess, frame=frame)
class BeholderCB(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the TensorBoard Beholder plugin.

    At the end of every epoch it calls ``Beholder.update`` with the session
    given at construction time. Plugin source:
    https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/beholder

    Args:
        logdir (str): path to the tensorboard log directory.
        sess: tensorflow session.
    """

    def __init__(self, logdir, sess):
        super(BeholderCB, self).__init__()
        self.session = sess
        self.beholder = Beholder(logdir=logdir)

    def on_epoch_end(self, epoch, logs=None):
        parent = super(BeholderCB, self)
        parent.on_epoch_end(epoch, logs)
        self.beholder.update(session=self.session)
class RolloutWorker(EvaluatorInterface):
    """Common experience collection class.

    This class wraps a policy instance and an environment class to
    collect experiences from the environment. You can create many replicas of
    this class as Ray actors to scale RL training.

    This class supports vectorized and multi-agent policy evaluation (e.g.,
    VectorEnv, MultiAgentEnv, etc.)

    Examples:
        >>> # Create a rollout worker and use it to collect experiences.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: gym.make("CartPole-v0"),
        ...   policy=PGTFPolicy)
        >>> print(worker.sample())
        SampleBatch({
            "obs": [[...]],
            "actions": [[...]],
            "rewards": [[...]],
            "dones": [[...]],
            "new_obs": [[...]]})

        >>> # Creating a multi-agent rollout worker
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
        ...   policy={
        ...       # Use an ensemble of two policies for car agents
        ...       "car_policy1":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
        ...       "car_policy2":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
        ...       # Use a single shared policy for all traffic lights
        ...       "traffic_light_policy":
        ...         (PGTFPolicy, Box(...), Discrete(...), {}),
        ...   },
        ...   policy_mapping_fn=lambda agent_id:
        ...     random.choice(["car_policy1", "car_policy2"])
        ...     if agent_id.startswith("car_") else "traffic_light_policy")
        >>> print(worker.sample())
        MultiAgentBatch({
            "car_policy1": SampleBatch(...),
            "car_policy2": SampleBatch(...),
            "traffic_light_policy": SampleBatch(...)})
    """

    @DeveloperAPI
    @classmethod
    def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
        """Return this class wrapped as a Ray remote actor class."""
        return ray.remote(
            num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls)

    @DeveloperAPI
    def __init__(self,
                 env_creator,
                 policy,
                 policy_mapping_fn=None,
                 policies_to_train=None,
                 tf_session_creator=None,
                 batch_steps=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="deepmind",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_filter="NoFilter",
                 clip_rewards=None,
                 clip_actions=True,
                 env_config=None,
                 model_config=None,
                 policy_config=None,
                 worker_index=0,
                 monitor_path=None,
                 log_dir=None,
                 log_level=None,
                 callbacks=None,
                 input_creator=lambda ioctx: ioctx.default_sampler_input(),
                 input_evaluation=frozenset([]),
                 output_creator=lambda ioctx: NoopOutput(),
                 remote_worker_envs=False,
                 remote_env_batch_wait_ms=0,
                 soft_horizon=False,
                 _fake_sampler=False):
        """Initialize a rollout worker.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                EnvContext wrapped configuration.
            policy (class|dict): Either a class implementing
                Policy, or a dictionary of policy id strings to
                (Policy, obs_space, action_space, config) tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn should also be set.
            policy_mapping_fn (func): A function that maps agent ids to
                policy ids in multi-agent mode. This function will be called
                each time a new agent appears in an episode, to bind that agent
                to a policy for the duration of the episode.
            policies_to_train (list): Optional whitelist of policies to train,
                or None for all policies.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicy.
            batch_steps (int): The target number of env transitions to include
                in each sample batch returned from this worker.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of at most `batch_steps * num_envs` in size. The batch will
                    be exactly `batch_steps * num_envs` in size if
                    postprocessing does not change batch sizes. Episodes may be
                    truncated in order to meet this size requirement.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `batch_steps * num_envs` in size. Episodes will
                    not be truncated, but multiple episodes may be packed
                    within one batch to meet the batch size. Note that when
                    `num_envs > 1`, episode steps will be buffered until the
                    episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): If set, stop episodes after this many
                steps.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool|str): If true, compress the
                observations. If "bulk", compress them in bulk. They can be
                decompressed with rllib/utils/compression.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect if
                the env already implements VectorEnv.
            observation_filter (str): Name of observation filter to use.
            clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
                experience postprocessing. Setting to None means clip for Atari
                only.
            clip_actions (bool): Whether to clip action values to the range
                specified by the policy action space.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy. In the
                multi-agent case, this config will be merged with the
                per-policy configs specified by `policy`.
            worker_index (int): For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (dict): Dict of custom debug callbacks.
            input_creator (func): Function that returns an InputReader object
                for loading previous generated experiences.
            input_evaluation (list): How to evaluate the policy performance.
                This only makes sense to set when the input is reading offline
                data. The possible values include:
                  - "is": the step-wise importance sampling estimator.
                  - "wis": the weighted step-wise is estimator.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
            output_creator (func): Function that returns an OutputWriter object
                for saving generated experiences.
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                can take much time to step / reset.
            remote_env_batch_wait_ms (float): Timeout that remote workers
                are waiting when polling environments. 0 (continue when at
                least one env is ready) is a reasonable default, but optimal
                value could be obtained by measuring your environment
                step / reset and model inference perf.
            soft_horizon (bool): Calculate rewards but don't reset the
                environment when the horizon is hit.
            _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """

        global _global_worker
        _global_worker = self

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(env_config or {}, worker_index)
        policy_config = policy_config or {}
        self.policy_config = policy_config
        self.callbacks = callbacks or {}
        self.worker_index = worker_index
        model_config = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        if not callable(policy_mapping_fn):
            raise ValueError(
                "Policy mapping function not callable. If you're using Tune, "
                "make sure to escape the function with tune.function() "
                "to prevent it from being evaluated as an expression.")
        self.env_creator = env_creator
        self.sample_batch_size = batch_steps * num_envs
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations
        self.preprocessing_enabled = True
        self.last_batch = None
        self._fake_sampler = _fake_sampler
        # Lazily created in learn_on_batch() when beholder is enabled.
        self._beholder = None

        self.env = _validate_env(env_creator(env_context))
        if isinstance(self.env, (MultiAgentEnv, BaseEnv)):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing
            self.preprocessing_enabled = False

            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack"))
                if monitor_path:
                    env = _monitor(env, monitor_path)
                return env
        else:

            def wrap(env):
                if monitor_path:
                    env = _monitor(env, monitor_path)
                return env

        self.env = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        vector_index=vector_index, remote=remote_worker_envs)))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy, self.env)
        self.policies_to_train = policies_to_train or list(policy_dict.keys())
        if _has_tensorflow_graph(policy_dict):
            if (ray.is_initialized()
                    and ray.worker._mode() != ray.worker.LOCAL_MODE
                    and not ray.get_gpu_ids()):
                logger.info("Creating policy evaluation worker {}".format(
                    worker_index) +
                            " on CPU (please ignore any CUDA init errors)")
            if not tf:
                raise ImportError("Could not import tensorflow")
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf.Session(
                        config=tf.ConfigProto(
                            gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    self.policy_map, self.preprocessors = \
                        self._build_policy_map(policy_dict, policy_config)
        else:
            self.policy_map, self.preprocessors = self._build_policy_map(
                policy_dict, policy_config)

        self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent:
            if not isinstance(
                    self.env, (MultiAgentEnv, ExternalMultiAgentEnv, BaseEnv)):
                raise ValueError(
                    "Have multiple policies {}, but the env ".format(
                        self.policy_map) +
                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                    "ExternalMultiAgentEnv?".format(self.env))

        self.filters = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # Always use vector env for consistency even if num_envs = 1
        self.async_env = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)
        self.num_envs = num_envs

        if self.batch_mode == "truncate_episodes":
            unroll_length = batch_steps
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            unroll_length = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        self.io_context = IOContext(log_dir, policy_config, worker_index, self)
        self.reward_estimators = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.create(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.create(
                    self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        if sample_async:
            self.sampler = AsyncSampler(
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                unroll_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                blackhole_outputs="simulation" in input_evaluation,
                soft_horizon=soft_horizon)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                unroll_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                soft_horizon=soft_horizon)

        self.input_reader = input_creator(self.io_context)
        assert isinstance(self.input_reader, InputReader), self.input_reader
        self.output_writer = output_creator(self.io_context)
        assert isinstance(self.output_writer, OutputWriter), \
            self.output_writer

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))

    @override(EvaluatorInterface)
    def sample(self):
        """Evaluate the current policies and return a batch of experiences.

        Return:
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        if self._fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.sample_batch_size))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while (steps_so_far < self.sample_batch_size
               and len(batches) < max_batches):
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batch = batches[0].concat_samples(batches)

        if self.callbacks.get("on_sample_end"):
            self.callbacks["on_sample_end"]({"worker": self, "samples": batch})

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation if needed
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self._fake_sampler:
            self.last_batch = batch
        return batch

    @DeveloperAPI
    @ray.method(num_return_vals=2)
    def sample_with_count(self):
        """Same as sample() but returns the count as a separate future."""
        batch = self.sample()
        return batch, batch.count

    @override(EvaluatorInterface)
    def get_weights(self, policies=None):
        """Return {policy_id: weights}, optionally restricted to `policies`."""
        if policies is None:
            policies = self.policy_map.keys()
        return {
            pid: policy.get_weights()
            for pid, policy in self.policy_map.items() if pid in policies
        }

    @override(EvaluatorInterface)
    def set_weights(self, weights):
        """Set weights from a {policy_id: weights} dict."""
        for pid, w in weights.items():
            self.policy_map[pid].set_weights(w)

    @override(EvaluatorInterface)
    def compute_gradients(self, samples):
        """Compute gradients on `samples` without applying them.

        For a MultiAgentBatch, only policies in `policies_to_train` are used
        and both returned values are dicts keyed by policy id.
        """
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                # Batch all policies' TF ops into a single session run.
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out

    @override(EvaluatorInterface)
    def apply_gradients(self, grads):
        """Apply gradients; `grads` is a dict keyed by policy id if multi."""
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid: self.policy_map[pid]._build_apply_gradients(
                        builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    @override(EvaluatorInterface)
    def learn_on_batch(self, samples):
        """Compute and apply gradients on `samples`; return learner stats.

        In the single-policy case, a policy's learn_on_batch() may return an
        (info, beholder_arrays) tuple; the arrays are sent to Beholder when
        `policy_config["evaluation_config"]["beholder"]` is truthy and the
        worker has a TF session.
        """
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            # to_fetch is only populated when builder is not None.
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            learn_on_batch_outputs = self.policy_map[
                DEFAULT_POLICY_ID].learn_on_batch(samples)
            if isinstance(learn_on_batch_outputs, tuple):
                info_out, beholder_arrays = learn_on_batch_outputs
            else:
                info_out, beholder_arrays = learn_on_batch_outputs, {}
            self._update_beholder(beholder_arrays)
        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out

    def _update_beholder(self, beholder_arrays):
        """Push `beholder_arrays` to Beholder if enabled in policy_config."""
        # policy_config defaults to {} and may lack these keys, so look them
        # up defensively instead of indexing (which raised KeyError).
        eval_config = self.policy_config.get("evaluation_config") or {}
        if not eval_config.get("beholder"):
            return
        if self.tf_sess is None:
            # Beholder needs a TF session/graph; non-TF policies have none.
            if log_once("beholder_requires_tf_session"):
                logger.warning(
                    "Beholder was requested but this worker has no TF "
                    "session; skipping Beholder updates.")
            return
        with self.tf_sess.graph.as_default():
            if self._beholder is None:
                self._beholder = Beholder(self.io_context.log_dir)
            self._beholder.update(
                self.tf_sess, arrays=beholder_arrays or None)

    @DeveloperAPI
    def get_metrics(self):
        """Returns a list of new RolloutMetric objects from evaluation."""

        out = self.sampler.get_metrics()
        for m in self.reward_estimators:
            out.extend(m.get_metrics())
        return out

    @DeveloperAPI
    def foreach_env(self, func):
        """Apply the given function to each underlying env instance."""

        envs = self.async_env.get_unwrapped()
        if not envs:
            return [func(self.async_env)]
        else:
            return [func(e) for e in envs]

    @DeveloperAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy for the specified id, or None.

        Arguments:
            policy_id (str): id of policy to return.
        """

        return self.policy_map.get(policy_id)

    @DeveloperAPI
    def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
        """Apply the given function to the specified policy."""

        return func(self.policy_map[policy_id])

    @DeveloperAPI
    def foreach_policy(self, func):
        """Apply the given function to each (policy, policy_id) tuple."""

        return [func(policy, pid) for pid, policy in self.policy_map.items()]

    @DeveloperAPI
    def foreach_trainable_policy(self, func):
        """Apply the given function to each (policy, policy_id) tuple.

        This only applies func to policies in `self.policies_to_train`."""

        return [
            func(policy, pid) for pid, policy in self.policy_map.items()
            if pid in self.policies_to_train
        ]

    @DeveloperAPI
    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    @DeveloperAPI
    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters

    @DeveloperAPI
    def save(self):
        """Pickle filters (flushing their buffers) and all policy states."""
        filters = self.get_filters(flush_after=True)
        state = {
            pid: self.policy_map[pid].get_state()
            for pid in self.policy_map
        }
        return pickle.dumps({"filters": filters, "state": state})

    @DeveloperAPI
    def restore(self, objs):
        """Restore filters and policy states from the output of save()."""
        # SECURITY NOTE: pickle.loads can execute arbitrary code; only pass
        # data produced by save() on a trusted worker or checkpoint.
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        for pid, state in objs["state"].items():
            self.policy_map[pid].set_state(state)

    @DeveloperAPI
    def set_global_vars(self, global_vars):
        """Forward `global_vars` to every policy's on_global_var_update()."""
        self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        """Export the given policy's model to `export_dir`."""
        self.policy_map[policy_id].export_model(export_dir)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        """Export the given policy's checkpoint to `export_dir`."""
        self.policy_map[policy_id].export_checkpoint(export_dir,
                                                     filename_prefix)

    @DeveloperAPI
    def stop(self):
        """Stop the underlying (possibly remote/vectorized) envs."""
        self.async_env.stop()

    def _build_policy_map(self, policy_dict, policy_config):
        """Instantiate each policy and its observation preprocessor.

        Returns:
            (policy_map, preprocessors): two dicts keyed by policy id.
        """
        policy_map = {}
        preprocessors = {}
        for name, (cls, obs_space, act_space,
                   conf) in sorted(policy_dict.items()):
            logger.debug("Creating policy for {}".format(name))
            merged_conf = merge_dicts(policy_config, conf)
            if self.preprocessing_enabled:
                preprocessor = ModelCatalog.get_preprocessor_for_space(
                    obs_space, merged_conf.get("model"))
                preprocessors[name] = preprocessor
                obs_space = preprocessor.observation_space
            else:
                preprocessors[name] = NoPreprocessor(obs_space)
            if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
                raise ValueError(
                    "Found raw Tuple|Dict space as input to policy. "
                    "Please preprocess these observations with a "
                    "Tuple|DictFlatteningPreprocessor.")
            if tf:
                # Separate variable scopes keep per-policy TF variables apart.
                with tf.variable_scope(name):
                    policy_map[name] = cls(obs_space, act_space, merged_conf)
            else:
                policy_map[name] = cls(obs_space, act_space, merged_conf)
        if self.worker_index == 0:
            logger.info("Built policy map: {}".format(policy_map))
            logger.info("Built preprocessor map: {}".format(preprocessors))
        return policy_map, preprocessors

    def __del__(self):
        # Signal a background AsyncSampler to stop when this worker dies.
        if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
            self.sampler.shutdown = True
def train(hps, files): ngpus = hps.ngpus config = tf.ConfigProto() if ngpus > 1: try: import horovod.tensorflow as hvd config = tf.ConfigProto() config.gpu_options.visible_device_list = str(hvd.local_rank()) except ImportError: hvd = None print("horovod not available, can only use 1 gpu") ngpus = 1 # todo: organize current_res_w = hps.current_res_w res_multiplier = current_res_w // hps.start_res_w current_res_h = hps.start_res_h * res_multiplier tfrecord_input = any('.tfrecords' in fname for fname in files) # if using tfrecord, assume dataset is duplicated across multiple resolutions if tfrecord_input: num_files = 0 for fname in [fname for fname in files if "res%d" % current_res_w in fname]: for record in tf.compat.v1.python_io.tf_record_iterator(fname): num_files += 1 else: num_files = len(files) label_list = [] total_classes = 0 if hps.label_file: do_cgan = True label_list, total_classes = build_label_list_from_file(hps.label_file) else: do_cgan = False print("dataset has %d files" % num_files) try: batch_size = int(hps.batch_size) try_schedule = False except ValueError: try_schedule = True if try_schedule: batch_schedule = ast.literal_eval(hps.batch_size) else: batch_schedule = None # always generate 32 sample images (should be feasible at high resolutions due to no training) # will probably need to edit for > 128x128 sample_batch = 32 sample_latent_numpy = np.random.normal(0., 1., [sample_batch, 512]) if do_cgan: examples_per_class = sample_batch // total_classes remainder = sample_batch % total_classes sample_cgan_latent_numpy = None for i in range(0, total_classes): class_vector = [0.] * total_classes class_vector[i] = 1. 
if sample_cgan_latent_numpy is None: sample_cgan_latent_numpy = [class_vector] * (examples_per_class + remainder) else: sample_cgan_latent_numpy += [class_vector] * examples_per_class sample_cgan_latent_numpy = np.array(sample_cgan_latent_numpy) use_beholder = hps.use_beholder if use_beholder: try: from tensorboard.plugins.beholder import Beholder except ImportError: print("Could not import beholder") use_beholder = False while current_res_w <= hps.res_w: if ngpus > 1: hvd.init() print("building graph") if batch_schedule is not None: batch_size = batch_schedule[current_res_w] print("res %d batch size is now %d" % (current_res_w, batch_size)) gen_model, mapping_network, dis_model, sampling_model = \ build_models(hps, current_res_w, use_ema_sampling=True, num_classes=total_classes, label_list=label_list if hps.conditional_type == "acgan" else None) with tf.name_scope("optimizers"): optimizer_d, optimizer_g, optimizer_m = build_optimizers(hps) if ngpus > 1: optimizer_d = hvd.DistributedOptimizer(optimizer_d) optimizer_g = hvd.DistributedOptimizer(optimizer_g) optimizer_m = hvd.DistributedOptimizer(optimizer_m) with tf.name_scope("data"): num_shards = None if ngpus == 1 else ngpus shard_index = None if ngpus == 1 else hvd.rank() it = build_data_iterator(hps, files, current_res_h, current_res_w, batch_size, label_list=label_list, num_shards=num_shards, shard_index=shard_index) next_batch = it.get_next() real_image = next_batch['data'] fake_latent1 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent") fake_latent2 = tf.random_normal([batch_size, 512], 0., 1., name="fake_latent") fake_label_dict = None real_label_dict = None if do_cgan: fake_label_dict = {} real_label_dict = {} for label in label_list: if hps.cond_uniform_fake: distribution = np.ones_like([label.probabilities]) else: distribution = np.log([label.probabilities]) fake_labels = tf.random.categorical(distribution, batch_size) if label.multi_dim is False: normalized_labels = (fake_labels - 
tf.reduce_min(fake_labels)) / \ (tf.reduce_max(fake_labels) - tf.reduce_min(fake_labels)) fake_labels = tf.reshape(normalized_labels, [batch_size, 1]) else: fake_labels = tf.reshape(tf.one_hot(fake_labels, label.num_classes), [batch_size, label.num_classes]) fake_label_dict[label.name] = fake_labels real_label_dict[label.name] = next_batch[label.name] #fake_label_list.append(fake_labels) # ideally would handle one dimensional labels differently, theory isn't well supported # for that though (example: categorical values of short, medium, tall are on one dimension) # real_labels = tf.reshape(tf.one_hot(tf.cast(next_batch[label.name], tf.int32), num_classes), # [batch_size, num_classes]) #real_label_list.append(real_labels) fake_label_tensor = tf.concat([fake_label_dict[l] for l in fake_label_dict.keys()], axis=-1) real_label_tensor = tf.concat([real_label_dict[l] for l in real_label_dict.keys()], axis=-1) sample_latent = tf.constant(sample_latent_numpy, dtype=tf.float32, name="sample_latent") if do_cgan: sample_cgan_w = tf.constant(sample_cgan_latent_numpy, dtype=tf.float32, name="sample_cgan_latent") alpha_ph = tf.placeholder(shape=(), dtype=tf.float32, name="alpha") # From Fig 2: "During a resolution transition, # we interpolate between two resolutions of the real images" real_image = real_image*alpha_ph + \ (1-alpha_ph)*upsample(downsample_nv(real_image), method="nearest_neighbor") real_image = upsample(real_image, method='nearest_neighbor', factor=hps.res_w//current_res_w) if do_cgan: with tf.name_scope("gen_synthesis"): fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network, cgan_w=fake_label_tensor, random_crossover=True) real_logit, real_class_logits = dis_model(real_image, alpha_ph, real_label_tensor if hps.conditional_type == "proj" else None) fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph, fake_label_tensor if hps.conditional_type == "proj" else None) else: with tf.name_scope("gen_synthesis"): 
fake_image = gen_model(alpha_ph, zs=[fake_latent1, fake_latent2], mapping_network=mapping_network, random_crossover=True) real_logit, real_class_logits = dis_model(real_image, alpha_ph) # todo: make work with other labels fake_logit, fake_class_logits = dis_model(fake_image, alpha_ph) with tf.name_scope("gen_sampling"): average_latent = tf.constant(np.random.normal(0., 1., [10000, 512]), dtype=tf.float32) low_psi = 0.20 if hps.map_cond: class_vector = [0.] * total_classes class_vector[0] = 1. # one hot encoding average_w = tf.reduce_mean(mapping_network(tf.concat([average_latent, [class_vector]*10000], axis=-1)), axis=0) sample_latent_lowpsi = average_w + low_psi * \ (mapping_network(tf.concat([sample_latent, [class_vector]*sample_batch], axis=-1)) - average_w) else: average_w = tf.reduce_mean(mapping_network(average_latent), axis=0) sample_latent_lowpsi = average_w + low_psi * (mapping_network(sample_latent) - average_w) average_w_batch = tf.tile(tf.reshape(average_w, [1, 512]), [sample_batch, 1]) if do_cgan: sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi, cgan_w=sample_cgan_w) sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network, cgan_w=sample_cgan_w) sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch, cgan_w=sample_cgan_w) sample_img_mode = tf.concat([sample_img_mode[0:2] + sample_img_mode[-3:-1]], axis=0) else: sample_img_lowpsi = sampling_model(alpha_ph, intermediate_ws=sample_latent_lowpsi) sample_img_base = sampling_model(alpha_ph, zs=sample_latent, mapping_network=mapping_network) sample_img_mode = sampling_model(alpha_ph, intermediate_ws=average_w_batch)[0:4] sample_images = tf.concat([sample_img_lowpsi, sample_img_mode, sample_img_base], axis=0) sampling_model_init_ops = weight_following_ema_ops(average_model=sampling_model, reference_model=gen_model) #sample_img_base = gen_model(sample_latent, alpha_ph, mapping_network) with tf.name_scope("loss"): 
loss_discriminator, loss_generator = hps.loss_fn(real_logit, fake_logit) if real_class_logits is not None: for label in label_list: label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=next_batch[label.name], logits=real_class_logits[label.name]) loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list)) tf.summary.scalar("label_loss_real", tf.reduce_mean(label_loss)) if fake_class_logits is not None: for label in label_list: label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=fake_label_dict[label.name], logits=fake_class_logits[label.name]) loss_discriminator += label_loss * hps.cond_weight * 1./(len(label_list)) tf.summary.scalar("label_loss_fake", tf.reduce_mean(label_loss)) loss_generator += label_loss * hps.cond_weight * 1./(len(label_list)) if hps.gp_fn: gp = hps.gp_fn(fake_image, real_image, dis_model, alpha_ph, real_label_dict, conditional_type=hps.conditional_type) tf.summary.scalar("gradient_penalty", tf.reduce_mean(gp)) loss_discriminator += hps.lambda_gp*gp dp = drift_penalty(real_logit) tf.summary.scalar("drift_penalty", tf.reduce_mean(dp)) if hps.lambda_drift != 0.: loss_discriminator = tf.expand_dims(loss_discriminator, -1) + hps.lambda_drift * dp loss_discriminator_avg = tf.reduce_mean(loss_discriminator) loss_generator_avg = tf.reduce_mean(loss_generator) with tf.name_scope("train"): train_step_d = optimizer_d.minimize(loss_discriminator_avg, var_list=dis_model.trainable_variables) # todo: test this with tf.control_dependencies(weight_following_ema_ops(average_model=sampling_model, reference_model=gen_model)): train_step_g = [optimizer_g.minimize(loss_generator_avg, var_list=gen_model.trainable_variables)] if hps.do_mapping_network: train_step_g.append( optimizer_m.minimize(loss_generator_avg, var_list=mapping_network.trainable_variables)) with tf.name_scope("summary"): tf.summary.histogram("real_scores", real_logit) tf.summary.scalar("loss_discriminator", loss_discriminator_avg) tf.summary.scalar("loss_generator", 
loss_generator_avg) tf.summary.scalar("real_logit", tf.reduce_mean(real_logit)) tf.summary.scalar("fake_logit", tf.reduce_mean(fake_logit)) tf.summary.histogram("real_logit", real_logit) tf.summary.histogram("fake_logit", fake_logit) tf.summary.scalar("alpha", alpha_ph) merged = tf.summary.merge_all() image_summary_real = generate_image_summary(real_image, "real") image_summary_fake_avg = generate_image_summary(sample_images, "fake_avg") #image_summary_fake = generate_image_summary(sample_img_base, "fake") global_step = tf.train.get_or_create_global_step() if hps.profile: builder = tf.profiler.ProfileOptionBuilder opts = builder(builder.time_and_memory()).order_by('micros').build() with tf.contrib.tfprof.ProfileContext(hps.model_dir, trace_steps=[], dump_steps=[]) as pctx: with tf.Session(config=config) as sess: #if hps.tboard_debug: # sess = tf_debug.TensorBoardDebugWrapperSession(sess, "localhost:6064") #elif hps.cli_debug: # sess = tf_debug.LocalCLIDebugWrapperSession(sess) sess.run(tf.global_variables_initializer()) sess.run(sampling_model_init_ops) alpha = 1. step = 0 if os.path.exists(hps.save_paths.gen_model) and os.path.exists(hps.save_paths.dis_model): if ngpus == 1 or hvd.rank() == 0: print("restoring") restore_models_and_optimizers(sess, gen_model, dis_model, mapping_network, sampling_model, optimizer_g, optimizer_d, optimizer_m, hps.save_paths) if os.path.exists(hps.save_paths.alpha) and os.path.exists(hps.save_paths.step): alpha, step = restore_alpha_and_step(hps.save_paths) print("alpha") print(alpha) if alpha != 1.: alpha_inc = 1. / (hps.epochs_per_res * (num_files / batch_size)) else: alpha_inc = 0. 
writer_path = \ os.path.join(hps.model_dir, "summary_%d" % current_res_w, "alpha_start_%d" % alpha) if use_beholder: beholder = Beholder(writer_path) writer = tf.summary.FileWriter(writer_path, sess.graph) writer.add_summary(image_summary_real.eval(feed_dict={alpha_ph: alpha}), step) print("Starting res %d training" % current_res_w) t = trange(hps.epochs_per_res * num_files // batch_size, desc='Training') if ngpus > 1: sess.run(hvd.broadcast_global_variables(0)) for phase_step in t: try: for i in range(0, hps.ncritic): if hps.profile: pctx.trace_next_step() pctx.dump_next_step() if step % 5 == 0: summary, ld, _ = sess.run([merged, loss_discriminator_avg, train_step_d if not hps.no_train else tf.no_op()], feed_dict={alpha_ph: alpha}) writer.add_summary(summary, step) else: ld, _ = sess.run([loss_discriminator_avg, train_step_d if not hps.no_train else tf.no_op()], feed_dict={alpha_ph: alpha}) if hps.profile: pctx.profiler.profile_operations(options=opts) if hps.profile: pctx.trace_next_step() pctx.dump_next_step() lg, _ = sess.run([loss_generator_avg, train_step_g if not hps.no_train else tf.no_op()], feed_dict={alpha_ph: alpha}) if hps.profile: pctx.profiler.profile_operations(options=opts) alpha = min(alpha+alpha_inc, 1.) 
#print("step: %d" % step) #print("loss_d: %f" % ld) #print("loss_g: %f\n" % lg) t.set_description('Overall step %d, loss d %f, loss g %f' % (step+1, ld, lg)) if use_beholder: try: beholder.update(session=sess) except Exception as e: print("Beholder failed: " + str(e)) use_beholder = False if phase_step < 5 or (phase_step < 500 and phase_step % 10 == 0) or (step % 1000 == 0): writer.add_summary(image_summary_fake_avg.eval( feed_dict={alpha_ph: alpha}), step) #writer.add_summary(image_summary_fake.eval( # feed_dict={alpha_ph: alpha}), step) if hps.steps_per_save is not None and step % hps.steps_per_save == 0 and (ngpus == 1 or hvd.rank() == 0): save_models_and_optimizers(sess, gen_model, dis_model, mapping_network, sampling_model, optimizer_g, optimizer_d, optimizer_m, hps.save_paths) save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths) step += 1 except tf.errors.OutOfRangeError: break assert (abs(alpha - 1.) < .1), "Alpha should be close to 1., not %f" % alpha # alpha close to 1. (dataset divisible by batch_size for small sets) if ngpus == 1 or hvd.rank() == 0: print(1. if alpha_inc != 0. else 0.) save_models_and_optimizers(sess, gen_model, dis_model, mapping_network, sampling_model, optimizer_g, optimizer_d, optimizer_m, hps.save_paths) backup_model_for_this_phase(hps.save_paths, writer_path) save_alpha_and_step(1. if alpha_inc != 0. else 0., step, hps.save_paths) # Will generate Out of range errors, see if it's easy to save a tensor so get_next() doesn't need # a new value #writer.add_summary(image_summary_real.eval(feed_dict={alpha_ph: 1.}), step) #writer.add_summary(image_summary_fake.eval(feed_dict={alpha_ph: 1.}), step) tf.reset_default_graph() if alpha_inc == 0: current_res_h *= 2 current_res_w *= 2
class C51Agent():
    """Categorical (C51) distributional DQN agent with a train and a target network.

    Both networks share the architecture built by ``C51Agent.Model``; the
    target network additionally builds the projected target distribution
    ``m`` that the train network is fitted to with a cross-entropy loss.
    """

    class Model():
        """Builds one network (train or target) inside the caller's variable scope."""

        def __init__(self, session, num_actions, train_net):
            """Build the network graph.

            Args:
              session: tf.Session stored on the instance.
              num_actions: number of discrete actions (output heads).
              train_net: True builds the loss/optimizer side; False builds the
                target-distribution projection side.
            """
            self.sess = session

            # Input: uint8 frame stack, scaled to [0, 1] floats.
            self.x = tf.placeholder(
                name="state",
                dtype=tf.uint8,
                shape=(None, params.STATE_DIMENSIONS[0],
                       params.STATE_DIMENSIONS[1], params.HISTORY_LEN))
            self.normalized_x = tf.cast(self.x, dtype=tf.float32) / 255.0

            # Every trainable layer lives under "common", which is the scope
            # C51Agent.copy_operation copies from train_net to target_net.
            with tf.variable_scope("common"):
                # Convolutional layers, each fed by the previous one.
                self.conv_outputs = []
                layer_input = self.normalized_x
                for layer_idx, conv_spec in enumerate(params.CONVOLUTIONAL_LAYERS_SPEC, start=1):
                    layer_input = tf.layers.conv2d(
                        name="conv_layer_" + str(layer_idx),
                        inputs=layer_input,
                        filters=conv_spec["filters"],
                        kernel_size=conv_spec["kernel_size"],
                        strides=conv_spec["strides"],
                        activation=tf.nn.relu)
                    self.conv_outputs.append(layer_input)

                self.flattened_conv_output = tf.layers.flatten(
                    name="conv_output_flattener", inputs=self.conv_outputs[-1])

                # Fully connected hidden layers.
                self.dense_outputs = []
                layer_input = self.flattened_conv_output
                for layer_idx, dense_units in enumerate(params.DENSE_LAYERS_SPEC, start=1):
                    layer_input = tf.layers.dense(
                        name="dense_layer_" + str(layer_idx),
                        inputs=layer_input,
                        units=dense_units,
                        activation=tf.nn.relu)
                    self.dense_outputs.append(layer_input)

                # Logits of all per-action value distributions, flattened.
                self.flattened_q_dist = tf.layers.dense(
                    name="flattened_action_value_dist_logits",
                    inputs=self.dense_outputs[-1],
                    units=num_actions * params.NB_ATOMS)

            # [batch, num_actions, NB_ATOMS] logits -> per-action softmax.
            self.q_dist_logits = tf.reshape(
                self.flattened_q_dist, [-1, num_actions, params.NB_ATOMS],
                name="reshape_q_dist_logits")
            self.q_dist = tf.nn.softmax(self.q_dist_logits, name="action_value_dist", axis=-1)

            # Support atoms evenly spaced over [V_MIN, V_MAX]. tf.linspace yields
            # exactly NB_ATOMS values; a float tf.range ending at V_MAX + delta_z
            # can emit an extra atom due to rounding.
            self.delta_z = (params.V_MAX - params.V_MIN) / (params.NB_ATOMS - 1)
            self.Z = tf.linspace(float(params.V_MIN), float(params.V_MAX), params.NB_ATOMS)

            # Expected state-action value per action: sum_i p_i * z_i.
            self.post_mul = self.q_dist * tf.reshape(self.Z, [1, 1, params.NB_ATOMS])
            self.actions = tf.reduce_sum(self.post_mul, axis=2)

            self.batch_size_range = tf.range(start=0, limit=tf.shape(self.x)[0])

            if not train_net:
                self._build_target_ops()
            else:
                self._build_train_ops()

        def _build_target_ops(self):
            """Build the projected Bellman target distribution ``m`` and its summaries."""
            self.targ_q_net_max = tf.summary.scalar(
                "targ_q_net_max", tf.reduce_max(self.actions))
            self.targ_q_net_mean = tf.summary.scalar(
                "targ_q_net_mean", tf.reduce_mean(self.actions))
            self.targ_q_net_min = tf.summary.scalar(
                "targ_q_net_min", tf.reduce_min(self.actions))

            # Greedy next-state action and its value distribution (the target).
            self.argmax_action = tf.argmax(self.actions, axis=-1, output_type=tf.int32)
            self.argmax_action_distribution = tf.gather_nd(
                self.q_dist,
                tf.stack((self.batch_size_range, self.argmax_action), axis=1))  # [N, 2] indices
            self.mean_argmax_next_state_value = tf.summary.scalar(
                "mean_argmax_q_target",
                tf.reduce_mean(self.Z * self.argmax_action_distribution))

            # Reward and continuation flag. C51Agent.add stores `not t`, and that
            # value is what update() feeds here, so 1 means "not terminal".
            self.r = tf.placeholder(name="reward", dtype=tf.float32, shape=(None, ))
            self.t = tf.placeholder(name="terminal", dtype=tf.uint8, shape=(None, ))

            # Bellman operator on every atom: r + 0.99 * t * z, clipped to [V_MIN, V_MAX].
            self.Tz = tf.clip_by_value(
                tf.reshape(self.r, [-1, 1])
                + 0.99 * tf.cast(tf.reshape(self.t, [-1, 1]), tf.float32) * self.Z,
                clip_value_min=params.V_MIN,
                clip_value_max=params.V_MAX)

            # Fractional bin index of each projected atom. Clipping keeps float
            # rounding from pushing ceil(b) past the last atom index.
            self.b = tf.clip_by_value((self.Tz - params.V_MIN) / self.delta_z,
                                      0., float(params.NB_ATOMS - 1))
            self.l = tf.floor(self.b)
            self.u = tf.ceil(self.b)

            # Row index for every (sample, atom) pair: [[0, 0, ...], [1, 1, ...], ...].
            row_index = tf.reshape(self.batch_size_range, [-1, 1]) * tf.ones(
                (1, params.NB_ATOMS), dtype=tf.int32)

            # If b falls exactly on an atom, l == u and both (u - b) and (b - l)
            # are 0, which would silently drop that probability mass. Assign the
            # whole mass to the lower bin in that case.
            same_bin = tf.cast(tf.equal(self.l, self.u), tf.float32)

            # Lower bin gets p * (u - b): e.g. b = 0.3 -> bin 0 gets 0.7 * p.
            self.indexable_l = tf.stack((row_index, tf.cast(self.l, dtype=tf.int32)), axis=-1)
            self.m_l_vals = self.argmax_action_distribution * (self.u - self.b + same_bin)
            self.m_l = tf.scatter_nd(tf.reshape(self.indexable_l, [-1, 2]),
                                     tf.reshape(self.m_l_vals, [-1]),
                                     tf.shape(self.l))

            # Upper bin gets p * (b - l): e.g. b = 0.3 -> bin 1 gets 0.3 * p.
            self.indexable_u = tf.stack((row_index, tf.cast(self.u, dtype=tf.int32)), axis=-1)
            self.m_u_vals = self.argmax_action_distribution * (self.b - self.l)
            self.m_u = tf.scatter_nd(tf.reshape(self.indexable_u, [-1, 2]),
                                     tf.reshape(self.m_u_vals, [-1]),
                                     tf.shape(self.u))

            # Projected target distribution [batch, NB_ATOMS]. No squeeze, so a
            # batch of one keeps its batch dimension. stop_gradient keeps the
            # target network from being updated through m.
            self.m = tf.stop_gradient(self.m_l + self.m_u)

            self.weighted_m = tf.clip_by_value(self.m * self.Z,
                                               clip_value_min=params.V_MIN,
                                               clip_value_max=params.V_MAX)
            self.weighted_m_mean = tf.summary.scalar(
                "mean_q_target", tf.reduce_mean(self.weighted_m))
            self.targ_dist = tf.summary.histogram("target_distribution", self.weighted_m)
            self.targn_summary = tf.summary.merge([
                self.targ_dist, self.weighted_m_mean, self.targ_q_net_max,
                self.targ_q_net_mean, self.targ_q_net_min,
                self.mean_argmax_next_state_value
            ])

        def _build_train_ops(self):
            """Build the cross-entropy loss against a fed target ``m`` and the Adam step."""
            self.trn_q_net_max = tf.summary.scalar(
                "trn_q_net_max", tf.reduce_max(self.actions))
            self.trn_q_net_mean = tf.summary.scalar(
                "trn_q_net_mean", tf.reduce_mean(self.actions))
            self.trn_q_net_min = tf.summary.scalar(
                "trn_q_net_min", tf.reduce_min(self.actions))

            # Value distribution of the action actually taken.
            self.action_placeholder = tf.placeholder(name="action", dtype=tf.int32, shape=[None, ])
            self.action_q_dist = tf.gather_nd(
                self.q_dist,
                tf.stack((self.batch_size_range, self.action_placeholder), axis=1))
            self.weighted_q_dist = tf.clip_by_value(self.action_q_dist * self.Z,
                                                    clip_value_min=params.V_MIN,
                                                    clip_value_max=params.V_MAX)
            tnd_summary = tf.summary.histogram("training_net_distribution", self.weighted_q_dist)
            tnd_mean_summary = tf.summary.scalar(
                "training_net_distribution_mean", tf.reduce_mean(self.weighted_q_dist))

            # Target distribution computed by the target network, fed in by update().
            self.m_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None, params.NB_ATOMS),
                                                name="m_placeholder")
            # Cross-entropy; 1e-5 guards against log(0).
            self.loss_sum = -tf.reduce_sum(
                self.m_placeholder * tf.log(self.action_q_dist + 1e-5), axis=-1)
            self.loss = tf.reduce_mean(self.loss_sum)
            l_summary = tf.summary.scalar("loss", self.loss)

            self.optimizer = tf.train.AdamOptimizer(learning_rate=params.LEARNING_RATE,
                                                    epsilon=params.EPSILON_ADAM)
            gradients, variables = zip(*self.optimizer.compute_gradients(self.loss))
            grad_norm_summary = tf.summary.histogram("grad_norm", tf.global_norm(gradients))
            gradients, _ = tf.clip_by_global_norm(gradients, params.GRAD_NORM_CLIP)
            self.train_step = self.optimizer.apply_gradients(zip(gradients, variables))
            self.trnn_summary = tf.summary.merge([
                tnd_mean_summary, tnd_summary, l_summary, grad_norm_summary,
                self.trn_q_net_max, self.trn_q_net_mean, self.trn_q_net_min
            ])

    def __init__(self):
        """Create the session, both networks, the target-copy op, saver and Beholder."""
        self.num_actions = len(params.GLOBAL_MANAGER.actions)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.experience_replay = deque(maxlen=params.EXPERIENCE_REPLAY_SIZE)

        with tf.variable_scope("train_net"):
            self.train_net = self.Model(self.sess, num_actions=self.num_actions, train_net=True)
        with tf.variable_scope("target_net"):
            self.target_net = self.Model(self.sess, num_actions=self.num_actions, train_net=False)

        self.summary = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter("TensorBoardDir")
        init = tf.global_variables_initializer()
        self.sess.run(init)

        main_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='train_net/common')
        target_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net/common')
        # Pair variables by sorted name instead of trusting collection order;
        # the assert verifies each pair refers to the same layer.
        assign_ops = []
        for main_var, target_var in zip(sorted(main_variables, key=lambda v: v.name),
                                        sorted(target_variables, key=lambda v: v.name)):
            assert (main_var.name.replace("train_net", "") == target_var.name.replace("target_net", ""))
            assign_ops.append(tf.assign(target_var, main_var))
        self.copy_operation = tf.group(*assign_ops)
        # Start the target network as an exact copy of the freshly initialized
        # train network instead of an independent random initialization.
        self.sess.run(self.copy_operation)

        self.saver = tf.train.Saver(
            max_to_keep=params.MAX_MODELS_TO_KEEP,
            keep_checkpoint_every_n_hours=params.MIN_MODELS_EVERY_N_HOURS)
        # self.profiler = tf.profiler.Profiler(self.sess.graph)
        self.beholder = Beholder("./TensorBoardDir")

    def act(self, x):
        """Epsilon-greedy action for state batch ``x``.

        Epsilon anneals linearly from EPSILON_START to EPSILON_END over
        EPSILON_FINAL_STEP timesteps and then stays at EPSILON_END.
        """
        progress = params.GLOBAL_MANAGER.timestep / params.EPSILON_FINAL_STEP
        epsilon = max(params.EPSILON_END,
                      params.EPSILON_START - progress * (params.EPSILON_START - params.EPSILON_END))
        if np.random.random() < epsilon:
            return np.random.randint(0, self.num_actions)
        actions = self.sess.run(fetches=self.train_net.actions,
                                feed_dict={self.train_net.x: x})
        return np.argmax(actions)

    def add(self, x, a, r, x_p, t):
        """Store a transition; the terminal flag is stored inverted (``not t``)."""
        assert (np.issubdtype(x.dtype, np.integer))
        self.experience_replay.appendleft([x, a, r, x_p, not t])

    def update(self, x, a, r, x_p, t):
        """Store the transition, then run one C51 training step on a replay minibatch.

        Does nothing beyond storing the transition while the replay buffer
        holds fewer than one minibatch (32 transitions).
        """
        self.add(x, a, r, x_p, t)
        batch_size = 32
        if len(self.experience_replay) < batch_size:
            return

        batch_data = random.sample(self.experience_replay, batch_size)
        batch_x = np.array([item[0] for item in batch_data])
        batch_a = [item[1] for item in batch_data]
        # Next state: drop the oldest frame of the stack and append the new
        # observation, element-wise max'ed with the most recent frame.
        batch_x_p = np.array([
            np.dstack((state[:, :, 1:], np.maximum(next_obs, state[:, :, -1])))
            for state, _, _, next_obs, _ in batch_data
        ])
        batch_r = [item[2] for item in batch_data]
        batch_t = [item[4] for item in batch_data]  # already `not terminal`

        targn_summary, m, Tz, b, u, l, indexable_u, indexable_l, m_u_vals, m_l_vals, m_u, m_l = self.sess.run(
            [
                self.target_net.targn_summary, self.target_net.m,
                self.target_net.Tz, self.target_net.b, self.target_net.u,
                self.target_net.l, self.target_net.indexable_u,
                self.target_net.indexable_l, self.target_net.m_u_vals,
                self.target_net.m_l_vals, self.target_net.m_u,
                self.target_net.m_l
            ],
            feed_dict={
                self.target_net.x: batch_x_p,
                self.target_net.r: batch_r,
                self.target_net.t: batch_t
            })

        trnn_summary, loss, _ = self.sess.run(
            [
                self.train_net.trnn_summary, self.train_net.loss,
                self.train_net.train_step
            ],
            feed_dict={
                self.train_net.x: batch_x,
                self.train_net.action_placeholder: batch_a,
                self.train_net.m_placeholder: m
            })

        self.writer.add_summary(targn_summary, params.GLOBAL_MANAGER.num_updates)
        self.writer.add_summary(trnn_summary, params.GLOBAL_MANAGER.num_updates)
        total_loss = loss

        self.beholder.update(self.sess,
                             frame=batch_x[0],
                             arrays=[
                                 m, Tz, b, u, l, indexable_u, indexable_l,
                                 m_u_vals, m_l_vals, m_u, m_l
                             ])

        num_updates = params.GLOBAL_MANAGER.num_updates
        if num_updates > 0 and num_updates % params.COPY_TARGET_FREQ == 0:
            self.sess.run(self.copy_operation)
            print("Copied to target. Current Loss: ", total_loss)

        if num_updates > 0 and num_updates % params.MODEL_SAVE_FREQ == 0:
            self.saver.save(self.sess,
                            "Models/model",
                            global_step=num_updates,
                            write_meta_graph=(num_updates <= params.MODEL_SAVE_FREQ))
# every 10 steps check accuracy if step_count % 10 == 0: # get Batch of test data batch_test_data, batch_test_labels = dataUtils.getCIFAR10Batch( is_eval=True, batch_size=100) # do eval step to test accuracy test_accuracy, test_loss, summary = sess.run( [accuracy, loss, summary_tensor], feed_dict={ input_placeholder: batch_test_data, label_placeholder: batch_test_labels }) # write data to tensorboard test_summary_writer.add_summary(summary, step_count) print("Step Count:{}".format(step_count)) print("Training accuracy: {:.6f} loss: {:.6f}".format( training_accuracy, training_loss)) print("Test accuracy: {:.6f} loss: {:.6f}".format( test_accuracy, test_loss)) beholder.update(session=sess) if step_count % 100 == 0: save_path = saver.save(sess, "model/model.ckpt") # stop training after 1,000 steps if step_count > 10000: break
def _dump_run_trace(cfg, summary_writer, run_metadata, i):
    """Write step ``i``'s full trace as a Chrome timeline file and as TensorBoard run metadata."""
    fetched_timeline = timeline.Timeline(run_metadata.step_stats)
    chrome_trace = fetched_timeline.generate_chrome_trace_format(show_memory=True)
    with open(os.path.join(cfg.logdir, f'timeline_{i:05}.json'), 'w', encoding="utf-8") as f:
        f.write(chrome_trace)
    summary_writer.add_run_metadata(run_metadata, f"step_{i:05}", global_step=i)
    logger.info(
        f"Saved trace metadata both to timeline_{i:05}.json and step_{i:05} in tensorboard"
    )


def train_model(cfg: EmbeddingCfg, model: models.Model, loss: tf.Tensor):
    """Minimize ``loss`` with Adam, resuming from the latest checkpoint in ``cfg.logdir``.

    Uses an exponentially decaying learning rate and per-tensor gradient-norm
    clipping. Writes summaries every step, optional full traces every
    ``cfg.run_trace_every`` steps, checkpoints every ``cfg.save_every`` steps,
    and finally saves the Keras ``model`` to ``cfg.logdir/full-model.save``.
    """
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        learning_rate=cfg.init_learning_rate,
        global_step=global_step,
        decay_steps=cfg.lr_decay_steps,
        decay_rate=cfg.lr_decay_rate,
    )
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Variables the loss does not depend on get a None gradient, which
    # tf.clip_by_norm cannot accept; drop those pairs.
    grads_and_vars = [
        (grad, var) for grad, var in optimizer.compute_gradients(loss)
        if grad is not None
    ]
    # apply_gradients already increments global_step by one per run. A separate
    # tf.assign_add fetched in the same run advanced it by two per iteration,
    # doubling the LR-decay speed and skipping steps on resume.
    train_op = optimizer.apply_gradients(
        [(tf.clip_by_norm(grad, cfg.grad_clipping), var) for grad, var in grads_and_vars],
        global_step=global_step,
    )

    saver = tf.train.Saver(max_to_keep=10)
    init_op = tf.global_variables_initializer()

    # Basic train-only summaries
    summaries = [
        tf.summary.scalar("learning_rate", learning_rate),
        tf.summary.scalar("loss", loss),
    ]
    # Extended validation summaries
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        name = var.name.split(":")[0]
        summaries.extend(tensor_default_summaries(name, var))
    for grad, var in grads_and_vars:
        name = var.name.split(":")[0]
        summaries.extend(tensor_default_summaries(name + "/grad", grad))
    merged_summary = tf.summary.merge(summaries)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    beholder = Beholder(cfg.logdir)

    with tf.Session(config=config) as sess:
        if cfg.debug:
            if cfg.tensorboard_debug is None:
                sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            else:
                sess = tf_debug.TensorBoardDebugWrapperSession(
                    sess, cfg.tensorboard_debug
                )
        summary_writer = tf.summary.FileWriter(cfg.logdir, sess.graph)
        K.set_session(sess)

        last_check = tf.train.latest_checkpoint(cfg.logdir)
        if last_check is None:
            logger.info("Running new checkpoint")
            sess.run(init_op)
        else:
            logger.info(f"Restoring checkpoint {last_check}")
            saver.restore(sess=sess, save_path=last_check)

        gs = sess.run(global_step)
        pbar = trange(gs, cfg.train_steps)
        try:
            for i in pbar:
                # Train hook
                trace_this_step = cfg.run_trace_every > 0 and i % cfg.run_trace_every == 0
                opts = {}
                if trace_this_step:
                    opts['options'] = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE
                    )
                    opts['run_metadata'] = tf.RunMetadata()

                _, curr_loss, summary = sess.run(
                    [train_op, loss, merged_summary], **opts
                )
                summary_writer.add_summary(summary, i)
                pbar.set_postfix(loss=curr_loss)

                if trace_this_step:
                    _dump_run_trace(cfg, summary_writer, opts['run_metadata'], i)

                beholder.update(session=sess)

                # Save hook
                if i % cfg.save_every == 0:
                    saver.save(
                        sess=sess,
                        save_path=os.path.join(cfg.logdir, 'model.ckpt'),
                        global_step=global_step
                    )
                    logger.info("Saved new model checkpoint")
        finally:
            # Flush and release the event file even if training fails.
            summary_writer.close()

        p = os.path.join(cfg.logdir, "full-model.save")
        model.save(p, overwrite=True, include_optimizer=False)
        logger.info(f"Finished training saved model to {p}")
# Train for one epoch per run, generate text at temperature `temp`, and log the
# generated text to TensorBoard through `summary_op` / `valid_placeholder`.
for i in range(5):
    seed = random_sequence_from_textfile(path, maxlen)
    # '%' formats the run number; the previous '&' raised TypeError (str & int).
    print('-- STARTING RUN NUMBER %s --' % i)
    m.fit(X, Y, validation_set=0.2, batch_size=64, n_epoch=1,
          run_id=run_name, snapshot_epoch=True)
    print('-- TESTING WITH TEMPERATURE OF %s --' % temp)
    gentext = m.generate(6000, temperature=temp, seq_seed=seed)
    print('-- GENERATION COMPLETED --')

    # Add it to the summary placeholder
    _sess = m.session
    _graph = _sess.graph
    _logdir = m.trainer.summ_writer.get_logdir()
    _step = int(m.trainer.global_step.eval(session=_sess))
    _writer = tf.summary.FileWriter(_logdir, graph=_graph)
    try:
        output_summary = _sess.run(summary_op, feed_dict={valid_placeholder: [gentext]})
        _writer.add_summary(output_summary, global_step=_step)
        _writer.flush()
    finally:
        # Close the per-run writer even if the summary op fails.
        _writer.close()

    m.trainer.saver.save(_sess, './run/' + model_name + '.ckpt', _step)
    beholder.update(_sess)
    # NOTE(review): original indentation was lost; saving every run keeps the
    # latest model on disk either way — confirm whether a single final save
    # after the loop was intended.
    m.save(model_name)
def train_nn(sess, epochs, batch_size, get_batches_fn, train_op,
             cross_entropy_loss, input_image, correct_label, learning_rate,
             is_training, mean_iou, merged_summary, log_path, save_dir,
             lr_value=0.0002):
    """
    Train neural network and print out the loss during training.
    :param sess: TF Session
    :param epochs: Number of epochs
    :param batch_size: Batch size
    :param get_batches_fn: Function to get batches of training data.  Call using get_batches_fn(batch_size)
    :param train_op: TF Operation to train the neural network
    :param cross_entropy_loss: TF Tensor for the amount of loss
    :param input_image: TF Placeholder for input images
    :param correct_label: TF Placeholder for label images
    :param learning_rate: TF Placeholder for learning rate
    :param is_training: TF Placeholder, fed True for training steps and False for the IoU evaluation
    :param mean_iou: fetched as [mean_iou]; element [0][0] of the result is reported as the IoU
        (presumably the (value, update_op) pair from tf.metrics.mean_iou — confirm)
    :param merged_summary: merged summary op, written to TensorBoard after every batch
    :param log_path: directory for TensorBoard event files and Beholder data
    :param save_dir: checkpoint path prefix passed to tf.train.Saver.save
    :param lr_value: value fed into the learning_rate placeholder (default 0.0002)
    """
    # create tensorboard writer at location log_path and save the graph there
    writer = tf.summary.FileWriter(log_path, graph=sess.graph)
    beholder = Beholder(log_path)
    saver = tf.train.Saver()
    # Monotonic TensorBoard step across all epochs; (epoch + 1) * batch
    # repeated values and went backwards from one epoch to the next.
    global_step = 0

    print("Training")
    try:
        for epoch in range(epochs):
            # train with ALL the training data per epoch, in batch_size batches
            batch = 0
            for images, labels in get_batches_fn(batch_size):
                summary, _, loss = sess.run(
                    [merged_summary, train_op, cross_entropy_loss],
                    feed_dict={
                        input_image: images,
                        correct_label: labels,
                        is_training: True,
                        learning_rate: lr_value
                    })
                batch += 1
                global_step += 1

                # add summaries to tensorboard
                writer.add_summary(summary, global_step)
                print('Epoch {}, batch: {}, loss: {} '.format(
                    epoch + 1, batch, loss))

            # IoU is measured on the epoch's last training batch with
            # is_training=False; skipped when the epoch produced no batches
            # (previously it fed an empty or stale batch).
            if batch:
                iou = sess.run([mean_iou],
                               feed_dict={
                                   input_image: images,
                                   correct_label: labels,
                                   is_training: False
                               })
                iou_sum = iou[0][0]
                sys.stdout.write("EPOCH {}. IOU = {:.3f}\n".format(epoch + 1, iou_sum))

            beholder.update(session=sess)
            saver.save(sess, save_dir, epoch)

        saver_path = saver.save(sess, save_dir)
        print("Model saved in path: %s" % saver_path)
    finally:
        # Release the event file even if training raises.
        writer.close()