示例#1
0
    def add_meta_information(self, outputs):
        """ Registers the features, placeholders, outputs, and hparams as collections of the meta-graph
            :param outputs: A dictionary of outputs (e.g. {'training_op': ..., 'training_outputs': ...})
            :return: Nothing, but adds those items to the meta-graph
        """
        from diplomacy_research.utils.tensorflow import tf

        # Features - Stored under 'feature_<name>', unless that collection is already populated
        for feature_name in self.features:
            feature_key = 'feature_{}'.format(feature_name)
            if tf.get_collection(feature_key):
                continue
            tf.add_to_collection(feature_key, self.features[feature_name])

        # Placeholders - Stored under 'placeholder_<name>', unless that collection is already populated
        for placeholder_name in self.placeholders:
            placeholder_key = 'placeholder_{}'.format(placeholder_name)
            if tf.get_collection(placeholder_key):
                continue
            tf.add_to_collection(placeholder_key, self.placeholders[placeholder_name])

        # Outputs - Merging the provided outputs in the model, then setting the fixed entries
        for output_name, output_value in outputs.items():
            self.outputs[output_name] = output_value
        self.outputs['tag/commit_hash'] = GIT_COMMIT_HASH
        self.outputs['is_trainable'] = True
        self.outputs['iterator_resource'] = self.iterator_resource

        # Labeling outputs on meta-graph
        for output_tag in self.outputs:
            # A key provided in this call replaces what was previously stored for it
            if output_tag in outputs and tf.get_collection(output_tag):
                tf.get_default_graph().clear_collection(output_tag)

            # Skipping empty outputs, private outputs ('_' prefix), and '_ta' outputs
            if self.outputs[output_tag] is None:
                continue
            if output_tag.startswith('_') or output_tag.endswith('_ta'):
                continue

            # Keys that are still present in the graph are left as-is
            if tf.get_collection(output_tag):
                continue
            tf.add_to_collection(output_tag, self.outputs[output_tag])

        # Hparams - Stored as strings under 'tag/hparam/<name>', unless already present
        for hparam_name, hparam_value in self.hparams.items():
            hparam_key = 'tag/hparam/{}'.format(hparam_name)
            if tf.get_collection(hparam_key):
                continue
            tf.add_to_collection(hparam_key, str(hparam_value))
示例#2
0
def create_adapter(trainer):
    """ Builds the adapter (for the learner) and stores it on the trainer as `trainer.adapter`
        :param trainer: A reinforcement learning trainer instance.
        :type trainer: diplomacy_research.models.training.reinforcement.trainer.ReinforcementTrainer
    """
    from diplomacy_research.utils.tensorflow import tf
    build_adapter = trainer.adapter_constructor
    dataset = trainer.queue_dataset

    # The adapter is bound to the graph that is the default one at the time of the call
    trainer.adapter = build_adapter(dataset, graph=tf.get_default_graph())
示例#3
0
 def _get_version_step():
     """ Retrieves the version step tensor from the VERSION_STEP collection of the default graph
         :return: The version step tensor, or None if the collection is empty.
         :raises RuntimeError: If the collection holds more than one tensor.
     """
     from diplomacy_research.utils.tensorflow import tf
     candidates = tf.get_default_graph().get_collection(VERSION_STEP)
     if len(candidates) > 1:
         raise RuntimeError('Multiple version step tensors defined')
     return candidates[0] if candidates else None
示例#4
0
def save_version_model(trainer, version_id=None, wait_for_grpc=False):
    """ Exports the current graph as a saved model (for inference) in the 'player' serving directory
        :param trainer: A reinforcement learning trainer instance.
        :param version_id: Optional. Integer. The version id of the graph to save. Defaults to the current version.
        :param wait_for_grpc: If true, waits for the target to be loaded on the localhost TF Serving before returning.
        :return: Nothing
        :type trainer: diplomacy_research.models.training.reinforcement.trainer.ReinforcementTrainer
    """
    from diplomacy_research.models.training.reinforcement.serving import wait_for_version
    from diplomacy_research.utils.tensorflow import tf

    # Only a standalone trainer, or the chief of a cluster, is allowed to save
    cluster_config = trainer.cluster_config
    if cluster_config and not cluster_config.is_chief:
        return

    # Resolving the version and making sure the output directory exists
    version_id = get_version_id(trainer) if version_id is None else version_id
    output_dir = get_serving_directory(trainer, 'player')
    os.makedirs(output_dir, exist_ok=True)

    # Exporting the player model
    proto_fields = trainer.dataset_builder.get_proto_fields()

    def _export_player(_sess):
        """ Builds the saved model using the session received from the trainer """
        return build_saved_model(saved_model_dir=output_dir,
                                 version_id=version_id,
                                 signature=trainer.signature,
                                 proto_fields=proto_fields,
                                 graph=tf.get_default_graph(),
                                 session=_sess)

    trainer.run_func_without_hooks(trainer.session, _export_player)

    # Exporting the opponent
    # Always for version 0, otherwise only in "staggered" mode on every `staggered_versions`-th version
    save_opponent = version_id == 0
    if not save_opponent and trainer.flags.mode == 'staggered':
        save_opponent = version_id % trainer.flags.staggered_versions == 0
    if save_opponent:
        create_opponent_from_player(trainer, version_id)

    # Writing a checkpoint when at least `checkpoint_every` seconds have passed since the last one
    checkpoint_threshold = time.time() - trainer.checkpoint_every
    if checkpoint_threshold >= trainer.last_checkpoint_time:
        trainer.last_checkpoint_time = int(time.time())
        checkpoint_save_path = os.path.join(trainer.flags.save_dir, 'rl_model.ckpt')

        def _write_checkpoint(_sess):
            """ Saves the checkpoint using the session received from the trainer """
            return trainer.saver.save(_sess,
                                      save_path=checkpoint_save_path,
                                      global_step=version_id)

        trainer.run_func_without_hooks(trainer.session, _write_checkpoint)

    # Optionally blocking until the version is served over gRPC
    if wait_for_grpc:
        wait_for_version(trainer, model_name='player', version_id=version_id, timeout=10)
示例#5
0
def convert_to_noisy_variables(variables, activation=None):
    """ Rewires the graph so that the consumers of each variable read a noisy version of it
        :param variables: A list of variables to make noisy
        :param activation: Optional. The activation function to use on the linear noisy transformation
        :return: Nothing, but modifies the graph in-place

        For each variable, the replacement tensor is `mu + sigma * noise`, where `mu` is an identity of the
        variable, `sigma` is a new variable initialized to 0.5 / sqrt(fan_in), and `noise` is a normal sample.

        Reference: 1706.10295 - Noisy Networks for exploration
    """
    # The train op collection must still be empty (i.e. no optimizer has been applied yet)
    if tf.get_collection(tf.GraphKeys.TRAIN_OP):
        raise RuntimeError('You must call convert_to_noisy_variables before applying an optimizer on the graph.')

    graph = tf.get_default_graph()
    variables = variables if isinstance(variables, list) else list(variables)

    for variable in variables:
        read_op = _get_variable_read_op(variable, graph)
        consumers = _get_variable_outputs(read_op, graph)
        scope_name = variable.name.split(':')[0]
        shape_list = variable.shape.as_list()
        fan_in = shape_list[0]

        # Building the noisy replacement in its own scope, on the same device as the variable
        with tf.variable_scope(scope_name + '_noisy'), tf.device(variable.device):
            sigma_init = tf.constant_initializer(0.5 / sqrt(fan_in))

            mu = tf.identity(variable, name='mu')
            sigma = tf.get_variable(name='sigma',
                                    shape=variable.shape,
                                    dtype=tf.float32,
                                    initializer=sigma_init,
                                    caching_device=variable._caching_device)  # pylint: disable=protected-access
            noise = tf.random.normal(shape=shape_list)

            noisy_value = mu + sigma * noise
            if activation:
                noisy_value = activation(noisy_value)

        # Finding which inputs of the consumers are fed by the variable read op
        read_op_name = read_op.name.split(':')[0]
        matching_inputs = [input_ix
                           for input_ix, input_tensor in enumerate(graph_editor.sgv(*consumers).inputs)
                           if input_tensor.name.split(':')[0] == read_op_name]

        # Re-connecting those inputs to the noisy replacement
        graph_editor.connect(graph_editor.sgv(noisy_value.op),
                             graph_editor.sgv(*consumers).remap_inputs(matching_inputs),
                             disconnect_first=True)
示例#6
0
 def _create_version_step():
     """ Creates the version step variable, which must not exist yet
         :return: The new variable (scalar int64, zero-initialized, not trainable), which is added to the
                  GLOBAL_VARIABLES and VERSION_STEP collections.
         :raises ValueError: If a version step already exists.
     """
     from diplomacy_research.utils.tensorflow import tf
     existing_step = BaseAlgorithm._get_version_step()
     if existing_step is not None:
         raise ValueError('"version_step" already exists.')

     # The variable is created under a `None` name scope of the default graph
     with tf.get_default_graph().name_scope(None):
         version_step = tf.get_variable(VERSION_STEP,
                                        shape=(),
                                        dtype=tf.int64,
                                        initializer=tf.zeros_initializer(),
                                        trainable=False,
                                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, VERSION_STEP])
     return version_step
def capture_ops():
    """ Generator that captures the ops created in the block

        with capture_ops() as ops:
            # create some ops
        print(ops) # => prints ops created.

        The yielded list is empty inside the block, and is only filled once the block has exited.
        NOTE(review): the `with` usage above requires this generator to be wrapped as a context manager
        (e.g. contextlib.contextmanager) - no such wrapper is visible here, to be confirmed.
    """
    # The block is wrapped in a name scope named after the current time (in microseconds)
    capture_scope = str(int(time.time() * 10 ** 6))
    captured_ops = []
    with tf.name_scope(capture_scope):
        yield captured_ops

    # Selecting the ops of the default graph whose name matches the scope pattern
    scope_pattern = capture_scope + '/.*'
    captured_ops.extend(graph_editor.select_ops(scope_pattern, graph=tf.get_default_graph()))
示例#8
0
    def _load_features_placeholders(self):
        """ Loads the features, outputs, and placeholders nodes from the model

            Every collection of the graph (`self.graph`, or the default graph if not set) is read and stored
            in `self.features`, `self.placeholders`, or `self.outputs` according to the prefix of its key.
            :return: Nothing, but populates self.features, self.placeholders, and self.outputs
        """
        from diplomacy_research.utils.tensorflow import tf
        graph = self.graph or tf.get_default_graph()

        def _strip_leading(key, prefix):
            """ Removes `prefix` from the start of `key` only (later occurrences are kept) """
            return key[len(prefix):] if key.startswith(prefix) else key

        for key in graph.get_all_collection_keys():
            # If non-empty list, getting first element (an empty collection is stored as-is)
            key_value = graph.get_collection(key)
            if isinstance(key_value, list) and key_value:
                key_value = key_value[0]

            # Setting in self.
            # Only the leading prefix is removed: str.replace would also strip later occurrences,
            # so a name that itself contains 'feature_' / 'placeholder_' would not be restored correctly
            if key.startswith('feature'):
                self.features[_strip_leading(key, 'feature_')] = key_value
            elif key.startswith('placeholder'):
                self.placeholders[_strip_leading(key, 'placeholder_')] = key_value
            else:
                self.outputs[key] = key_value
示例#9
0
    def initialize(self, session):
        """ Initialize the adapter (init the uninitialized variables and the dataset)
            Does nothing if the dataset cannot support an iterator, or if the adapter has no iterator.
            :param session: The session used to initialize the variables and the dataset
            :type session: tensorflow.python.client.session.Session
        """
        dataset = self.feedable_dataset
        if not (dataset.can_support_iterator and self.iterator):
            return

        from diplomacy_research.utils.tensorflow import tf
        assert session, 'You must pass a session to initialize the adapter'
        assert isinstance(self.feedable_dataset, QueueDataset), 'The dataset must be a QueueDataset'
        self.session = session

        # Initializing the global and local variables that are not yet initialized
        # Skipped when the graph is finalized
        graph = self.graph or tf.get_default_graph()
        if not graph.finalized:
            with graph.as_default():
                candidate_vars = tf.global_variables() + tf.local_variables()
                status_ops = [tf.is_variable_initialized(var) for var in candidate_vars]
                init_flags = self.session.run(status_ops)
                pending_vars = [var for var, is_init in zip(candidate_vars, init_flags) if not is_init]
                if pending_vars:
                    LOGGER.info('Initialized %d variables.', len(pending_vars))
                    self.session.run(tf.variables_initializer(pending_vars))

        # Starting the dataset, or only initializing it if it has already been started
        if self.session:
            if not dataset.is_started:
                dataset.start(self.session)
            elif not dataset.is_initialized:
                dataset.initialize(self.session)
示例#10
0
def save_model(trainer, sess, start_of_epoch=False):
    """ Saves the current graph to a saved model checkpoint on disk (using a separate thread)
        :param trainer: A supervised trainer instance.
        :param sess: The TensorFlow session. If it is not a raw `tf.Session`, the function is re-run through
                     `trainer.run_func_without_hooks` with the session provided by that call.
        :param start_of_epoch: Boolean that indicates that we are saving the model at the start of a new epoch.
                               The version id is `int(trainer.progress[0])` if set, one more otherwise.
        :return: Nothing. The model is built in a daemon thread that is started, but not joined.
        :type trainer: diplomacy_research.models.training.supervised.trainer.SupervisedTrainer
    """
    from diplomacy_research.utils.tensorflow import tf

    # Not a raw session - Re-entering with the session provided by the trainer
    # `start_of_epoch` must be forwarded, otherwise the version id computed below would be off by one
    if not isinstance(sess, tf.Session):
        trainer.run_func_without_hooks(
            sess, lambda _sess: save_model(trainer, _sess, start_of_epoch=start_of_epoch))
        return

    version_id = int(trainer.progress[0]) + (0 if start_of_epoch else 1)
    output_dir = os.path.join(trainer.flags.save_dir, 'history')
    graph = tf.get_default_graph()

    # Building the saved model in the background, so the caller is not blocked
    model_thread = Thread(target=build_saved_model,
                          kwargs={'saved_model_dir': output_dir,
                                  'version_id': version_id,
                                  'signature': trainer.signature,
                                  'proto_fields': trainer.dataset_builder.get_proto_fields(),
                                  'graph': graph,
                                  'session': sess,
                                  'history_saver': trainer.history_saver},
                          daemon=True)
    model_thread.start()