def add_meta_information(self, outputs):
    """ Registers the model's features, placeholders, outputs and hparams as collections of the default
        graph, so they are exported with the meta-graph.

        :param outputs: A dictionary of outputs (e.g. {'training_op': ..., 'training_outputs': ...})
        :return: Nothing. Updates `self.outputs` and adds collections to the default graph.
    """
    from diplomacy_research.utils.tensorflow import tf

    def _add_if_missing(collection_name, value):
        """ Adds `value` to `collection_name` only if that collection is currently empty """
        if not tf.get_collection(collection_name):
            tf.add_to_collection(collection_name, value)

    # Features and placeholders are registered once; existing collections are left untouched
    for feature_name in self.features:
        _add_if_missing('feature_{}'.format(feature_name), self.features[feature_name])
    for placeholder_name in self.placeholders:
        _add_if_missing('placeholder_{}'.format(placeholder_name), self.placeholders[placeholder_name])

    # Merging the new outputs and the metadata tags into the model's outputs
    for output_name in outputs:
        self.outputs[output_name] = outputs[output_name]
    self.outputs['tag/commit_hash'] = GIT_COMMIT_HASH
    self.outputs['is_trainable'] = True
    self.outputs['iterator_resource'] = self.iterator_resource

    # Labeling outputs on the meta-graph:
    # - a key passed in `outputs` replaces whatever its collection already held
    # - None values, '_'-prefixed keys and '_ta'-suffixed keys are never stored
    for output_tag, output_value in self.outputs.items():
        if output_tag in outputs and tf.get_collection(output_tag):
            tf.get_default_graph().clear_collection(output_tag)
        is_exportable = (output_value is not None
                         and not output_tag.startswith('_')
                         and not output_tag.endswith('_ta'))
        if is_exportable:
            _add_if_missing(output_tag, output_value)

    # Hparams are stored as strings, once per name
    for hparam_name, hparam_value in self.hparams.items():
        hparam_key = 'tag/hparam/{}'.format(hparam_name)
        if not tf.get_collection(hparam_key):
            tf.add_to_collection(hparam_key, str(hparam_value))
def create_adapter(trainer):
    """ Builds the learner's adapter on top of the trainer's queue dataset and stores it on the trainer.

        The adapter is bound to the current default graph.

        :param trainer: A reinforcement learning trainer instance.
        :return: Nothing. Sets `trainer.adapter`.
        :type trainer: diplomacy_research.models.training.reinforcement.trainer.ReinforcementTrainer
    """
    from diplomacy_research.utils.tensorflow import tf
    build_adapter = trainer.adapter_constructor
    trainer.adapter = build_adapter(trainer.queue_dataset, graph=tf.get_default_graph())
def _get_version_step():
    """ Looks up the version step tensor in the default graph's VERSION_STEP collection.

        :return: The single tensor registered in the collection, or None when the collection is empty.
        :raises RuntimeError: If more than one tensor is registered in the collection.
    """
    from diplomacy_research.utils.tensorflow import tf
    registered = tf.get_default_graph().get_collection(VERSION_STEP)
    if len(registered) > 1:
        raise RuntimeError('Multiple version step tensors defined')
    return registered[0] if registered else None
def save_version_model(trainer, version_id=None, wait_for_grpc=False):
    """ Exports the current graph as a saved model (for inference) in the 'player' serving directory.

        Depending on the version and the training mode, this also creates an opponent from the player,
        writes a periodic checkpoint, and can block until TF Serving has loaded the version.

        :param trainer: A reinforcement learning trainer instance.
        :param version_id: Optional. Integer. The version id of the graph to save. Defaults to the current version.
        :param wait_for_grpc: If true, waits for the target to be loaded on the localhost TF Serving before returning.
        :return: Nothing
        :type trainer: diplomacy_research.models.training.reinforcement.trainer.ReinforcementTrainer
    """
    from diplomacy_research.models.training.reinforcement.serving import wait_for_version
    from diplomacy_research.utils.tensorflow import tf

    # In distributed mode only the chief writes to disk; without a cluster config (standalone) we always save
    cluster_config = trainer.cluster_config
    if cluster_config and not cluster_config.is_chief:
        return

    # Resolving the version and preparing the serving directory
    if version_id is None:
        version_id = get_version_id(trainer)
    player_dir = get_serving_directory(trainer, 'player')
    os.makedirs(player_dir, exist_ok=True)

    # Exporting the saved model through the raw session (hooks bypassed)
    proto_fields = trainer.dataset_builder.get_proto_fields()

    def _export_saved_model(raw_session):
        return build_saved_model(saved_model_dir=player_dir,
                                 version_id=version_id,
                                 signature=trainer.signature,
                                 proto_fields=proto_fields,
                                 graph=tf.get_default_graph(),
                                 session=raw_session)

    trainer.run_func_without_hooks(trainer.session, _export_saved_model)

    # Version 0 always seeds the opponent; in "staggered" mode it is also recreated every
    # `staggered_versions` versions (the modulo is only evaluated in staggered mode)
    refresh_opponent = version_id == 0 or (
        trainer.flags.mode == 'staggered'
        and version_id % trainer.flags.staggered_versions == 0)
    if refresh_opponent:
        create_opponent_from_player(trainer, version_id)

    # Periodic checkpoint - Once every `trainer.checkpoint_every` seconds (10 mins by default)
    checkpoint_due = (time.time() - trainer.checkpoint_every) >= trainer.last_checkpoint_time
    if checkpoint_due:
        trainer.last_checkpoint_time = int(time.time())
        checkpoint_path = os.path.join(trainer.flags.save_dir, 'rl_model.ckpt')

        def _write_checkpoint(raw_session):
            return trainer.saver.save(raw_session, save_path=checkpoint_path, global_step=version_id)

        trainer.run_func_without_hooks(trainer.session, _write_checkpoint)

    # Optionally blocking until the local TF Serving has loaded this version
    if wait_for_grpc:
        wait_for_version(trainer, model_name='player', version_id=version_id, timeout=10)
def convert_to_noisy_variables(variables, activation=None):
    """ Converts a list of variables to noisy variables, rewiring the graph in-place

        For each variable `w`, a tensor `mu + sigma * noise` is built, optionally passed through
        `activation`, and connected in place of `w`'s read op for the ops returned by
        `_get_variable_outputs` (a helper defined elsewhere - presumably the consumers of `w`). Here:
            - `mu` is `tf.identity(w)`, so `w` itself remains the mean,
            - `sigma` is a new variable `<w scope>_noisy/sigma` with the same shape as `w`,
              initialized to 0.5 / sqrt(fan_in), where fan_in is the first dimension of `w`,
            - `noise` is a `tf.random.normal` tensor with the full shape of `w`.

        :param variables: A list of variables to make noisy (any other iterable is converted to a list)
        :param activation: Optional. The activation function to use on the linear noisy transformation
        :return: Nothing, but modifies the graph in-place
        :raises RuntimeError: If the TRAIN_OP collection is not empty (i.e. an optimizer was already applied)

        Reference: 1706.10295 - Noisy Networks for exploration
    """
    # Only the TRAIN_OP collection is checked; an optimizer whose op is not in that collection is not detected
    if tf.get_collection(tf.GraphKeys.TRAIN_OP):
        raise RuntimeError(
            'You must call convert_to_noisy_variables before applying an optimizer on the graph.'
        )

    graph = tf.get_default_graph()
    if not isinstance(variables, list):
        variables = list(variables)

    # Replacing each variable
    for variable in variables:
        variable_read_op = _get_variable_read_op(variable, graph)
        variable_outputs = _get_variable_outputs(variable_read_op, graph)
        variable_scope = variable.name.split(':')[0]        # variable name without the ':<index>' suffix
        variable_shape = variable.shape.as_list()
        # NOTE(review): the first dimension is used as fan-in - presumably (inputs, outputs) kernels;
        # for a 1-D variable (e.g. a bias) this is its own size. Confirm for other ranks.
        fan_in = variable_shape[0]

        # Creating noisy variables
        with tf.variable_scope(variable_scope + '_noisy'):
            with tf.device(variable.device):
                s_init = tf.constant_initializer(0.5 / sqrt(fan_in))

                noisy_u = tf.identity(variable, name='mu')
                # NOTE(review): dtype is hard-coded to float32, regardless of the variable's dtype
                noisy_s = tf.get_variable(
                    name='sigma',
                    shape=variable.shape,
                    dtype=tf.float32,
                    initializer=s_init,
                    caching_device=variable._caching_device)  # pylint: disable=protected-access
                # NOTE(review): one sample per element (full variable shape), i.e. independent rather than
                # factorised noise, while the 0.5/sqrt(fan_in) sigma init is the one the paper pairs with
                # factorised noise - confirm intent
                noise = tf.random.normal(shape=variable_shape)

                replaced_var = noisy_u + noisy_s * noise
                replaced_var = activation(
                    replaced_var) if activation else replaced_var

        # Replacing in-place
        # Positions, among the inputs of the consumers' subgraph view, of the tensors produced by the
        # variable's read op (matched by op name, ignoring the ':<index>' output suffix)
        inputs_index = [
            var_index for var_index, var_input in enumerate(
                graph_editor.sgv(*variable_outputs).inputs)
            if var_input.name.split(':')[0] == variable_read_op.name.split(':')[0]
        ]
        # Feeding `replaced_var` into exactly those input positions; `disconnect_first=True` removes the
        # existing edges to those inputs before connecting
        graph_editor.connect(
            graph_editor.sgv(replaced_var.op),
            graph_editor.sgv(*variable_outputs).remap_inputs(inputs_index),
            disconnect_first=True)
def _create_version_step():
    """ Creates the version step variable in the default graph.

        :return: A non-trainable, zero-initialized int64 scalar variable named VERSION_STEP, registered in
                 both the GLOBAL_VARIABLES and the VERSION_STEP collections.
        :raises ValueError: If a version step tensor already exists.
    """
    from diplomacy_research.utils.tensorflow import tf
    already_defined = BaseAlgorithm._get_version_step() is not None
    if already_defined:
        raise ValueError('"version_step" already exists.')

    # Creating it under the root name scope (Graph.name_scope(None) resets to the top-level scope)
    default_graph = tf.get_default_graph()
    with default_graph.name_scope(None):
        version_step = tf.get_variable(VERSION_STEP,
                                       shape=(),
                                       dtype=tf.int64,
                                       initializer=tf.zeros_initializer(),
                                       trainable=False,
                                       collections=[tf.GraphKeys.GLOBAL_VARIABLES, VERSION_STEP])
    return version_step
def capture_ops():
    """ Generator-based context manager that captures the ops created inside its block

        Usage::

            with capture_ops() as ops:
                # create some ops
                ...
            print(ops)      # => prints ops created.

        The yielded list stays empty inside the block; it is filled once the block exits normally.

        NOTE(review): the `with` usage requires this generator to be wrapped with `contextlib.contextmanager`;
        no decorator is visible in this excerpt - confirm it is applied where the function is defined.
    """
    import re

    op_list = []
    # A microsecond timestamp, used as a scope name that is very unlikely to already exist
    scope_name = str(int(time.time() * 10 ** 6))
    with tf.name_scope(scope_name) as scope:
        yield op_list
    graph = tf.get_default_graph()

    # `scope` is the name TF actually opened. It is fully qualified by any enclosing scope, and it is
    # uniquified (e.g. '<timestamp>_1/') if `scope_name` was already taken. Anchoring on it avoids missing
    # the ops of a uniquified scope, and avoids capturing ops of an older scope with the same timestamp.
    if scope:
        pattern = '^' + re.escape(scope)
    else:
        pattern = scope_name + '/.*'        # Fallback: previous behaviour
    op_list.extend(graph_editor.select_ops(pattern, graph=graph))
def _load_features_placeholders(self):
    """ Loads the features, outputs, and placeholders nodes from the model's graph collections

        For every collection in `self.graph` (or the default graph if `self.graph` is not set):
            - keys starting with 'feature' are stored in `self.features` (with a leading 'feature_' removed),
            - keys starting with 'placeholder' are stored in `self.placeholders` (leading 'placeholder_' removed),
            - every other key is stored in `self.outputs` unchanged.
        A non-empty collection contributes its first element; an empty collection is stored as-is ([]).

        :return: Nothing. Updates `self.features`, `self.placeholders` and `self.outputs` in place.
    """
    def _without_prefix(name, prefix):
        """ Removes `prefix` from the start of `name` only (interior occurrences are kept) """
        return name[len(prefix):] if name.startswith(prefix) else name

    graph = self.graph
    if not graph:
        # Only needed when no graph was provided
        from diplomacy_research.utils.tensorflow import tf
        graph = tf.get_default_graph()

    for key in graph.get_all_collection_keys():
        # If list, getting first element
        key_value = graph.get_collection(key)
        if isinstance(key_value, list) and key_value:
            key_value = key_value[0]

        # Setting in self.
        if key.startswith('feature'):
            self.features[_without_prefix(key, 'feature_')] = key_value
        elif key.startswith('placeholder'):
            self.placeholders[_without_prefix(key, 'placeholder_')] = key_value
        else:
            self.outputs[key] = key_value
def initialize(self, session):
    """ Initializes the adapter: runs the initializers of still-uninitialized variables, then starts
        (or re-initializes) the feedable dataset.

        Does nothing when the dataset cannot support an iterator or when the adapter has no iterator.

        :param session: The session used to run initializers and to start the dataset (stored on `self.session`).
        :type session: tensorflow.python.client.session.Session
    """
    if not (self.feedable_dataset.can_support_iterator and self.iterator):
        return

    from diplomacy_research.utils.tensorflow import tf
    assert session, 'You must pass a session to initialize the adapter'
    assert isinstance(self.feedable_dataset, QueueDataset), 'The dataset must be a QueueDataset'
    self.session = session

    # Initializing only the global/local variables that are not initialized yet
    # (skipped entirely when the graph is finalized)
    graph = self.graph or tf.get_default_graph()
    if not graph.finalized:
        with graph.as_default():
            candidates = tf.global_variables() + tf.local_variables()
            init_flags = self.session.run([tf.is_variable_initialized(var) for var in candidates])
            pending = [var for var, flag in zip(candidates, init_flags) if not flag]
            if pending:
                LOGGER.info('Initialized %d variables.', len(pending))
                self.session.run(tf.variables_initializer(pending))

    # Starting the dataset if needed, otherwise initializing it if needed
    if self.session:
        if not self.feedable_dataset.is_started:
            self.feedable_dataset.start(self.session)
        elif not self.feedable_dataset.is_initialized:
            self.feedable_dataset.initialize(self.session)
def save_model(trainer, sess, start_of_epoch=False):
    """ Saves the current graph to a saved model checkpoint on disk (using a separate thread)

        :param trainer: A supervised trainer instance.
        :param sess: The TensorFlow session. If it is not a raw `tf.Session`, the call is re-issued through
                     `trainer.run_func_without_hooks` with the session that function provides.
        :param start_of_epoch: Boolean that indicates that we are saving the model at the start of a new epoch.
        :return: Nothing
        :type trainer: diplomacy_research.models.training.supervised.trainer.SupervisedTrainer
    """
    from diplomacy_research.utils.tensorflow import tf

    # Not a raw session: re-entering with the session provided by run_func_without_hooks.
    # `start_of_epoch` must be forwarded, otherwise the version id below would be computed as if it were False.
    if not isinstance(sess, tf.Session):
        trainer.run_func_without_hooks(
            sess, lambda _sess: save_model(trainer, _sess, start_of_epoch=start_of_epoch))
        return

    # Version id is progress[0] (presumably the epoch counter), +1 unless saving at the start of an epoch
    version_id = int(trainer.progress[0]) + (0 if start_of_epoch else 1)
    output_dir = os.path.join(trainer.flags.save_dir, 'history')
    graph = tf.get_default_graph()

    # Building the saved model in a separate daemon thread so the caller is not blocked.
    # NOTE(review): as a daemon thread, an in-progress save may be cut short if the process exits.
    model_thread = Thread(target=build_saved_model,
                          kwargs={
                              'saved_model_dir': output_dir,
                              'version_id': version_id,
                              'signature': trainer.signature,
                              'proto_fields': trainer.dataset_builder.get_proto_fields(),
                              'graph': graph,
                              'session': sess,
                              'history_saver': trainer.history_saver
                          },
                          daemon=True)
    model_thread.start()