Exemplo n.º 1
0
    def get_variable_name_dict_for_attention_rnn(self):
        """Maps checkpoint names of the fully connected layer to graph variables.

        Only applies when `self.note_rnn_type` is 'attention_rnn'; for any other
        type an empty dict is returned before any variable is inspected.

        Returns:
          A dict with at most the keys 'fully_connected/biases' and
          'fully_connected/weights', each mapped to the matching variable from
          `self.variables()`. Variables whose name contains '/Adam' are skipped.
        """
        var_dict = {}
        if self.note_rnn_type != 'attention_rnn':
            return var_dict

        # NOTE(review): looks like leftover debugging output; kept so that the
        # method's stdout behaviour is unchanged -- confirm whether it can go.
        print("Global vars : {0}".format(tf.global_variables()))

        for var in self.variables():
            name = var.name
            if '/Adam' in name:
                # TODO(lukaszkaiser): investigate the problem here and remove this hack.
                continue
            if 'fully_connected/bias' in name:
                # The substring matches both 'bias' and 'biases'; either way the
                # variable is stored under the plural key.
                var_dict['fully_connected/biases'] = var
            elif 'fully_connected/weights' in name:
                var_dict['fully_connected/weights'] = var

        return var_dict
Exemplo n.º 2
0
  def get_variable_name_dict(self):
    """Builds the lookup from checkpoint variable names to graph variables.

    Each variable name is reduced with `rl_tuner_ops.get_inner_scope` and then
    `rl_tuner_ops.trim_variable_postfixes`. For a 'basic_rnn' the reduced name
    is used directly as the key; for every other type it is prefixed with
    `self.checkpoint_scope` and a '/'.

    Returns:
      A dict mapping variable names in the checkpoint to variables in the graph.
    """

    def checkpoint_key(graph_var):
      # Reduce the graph name first, then decide whether it needs the scope.
      reduced = rl_tuner_ops.trim_variable_postfixes(
          rl_tuner_ops.get_inner_scope(graph_var.name))
      if self.note_rnn_type == 'basic_rnn':
        return reduced
      return self.checkpoint_scope + '/' + reduced

    return {checkpoint_key(graph_var): graph_var
            for graph_var in self.variables()}
Exemplo n.º 3
0
    def get_variable_name_dict(self):
        """Builds the lookup from checkpoint variable names to graph variables.

        Variables whose name contains '/Adam' are left out. For a 'basic_rnn'
        the key is the reduced variable name; otherwise the reduced name is
        prefixed with `self.checkpoint_scope` and a '/'.

        Returns:
          A dict mapping variable names in the checkpoint to variables in the
          graph.
        """
        mapping = {}
        for graph_var in self.variables():
            full_name = graph_var.name
            scoped = rl_tuner_ops.trim_variable_postfixes(
                rl_tuner_ops.get_inner_scope(full_name))

            if '/Adam' in full_name:
                # TODO(lukaszkaiser): investigate the problem here and remove this hack.
                continue

            key = (scoped if self.note_rnn_type == 'basic_rnn'
                   else self.checkpoint_scope + '/' + scoped)
            mapping[key] = graph_var

        return mapping
Exemplo n.º 4
0
  def get_variable_name_dict(self):
    """Builds the lookup from checkpoint variable names to graph variables.

    Variables whose name contains '/Adam' are left out. Keys are the reduced
    variable names, prefixed with `self.checkpoint_scope` and a '/' unless the
    network type is 'basic_rnn'.

    Returns:
      A dict mapping variable names in the checkpoint to variables in the graph.
    """
    lookup = dict()
    for tensor in self.variables():
      key = rl_tuner_ops.get_inner_scope(tensor.name)
      key = rl_tuner_ops.trim_variable_postfixes(key)
      # TODO(lukaszkaiser): investigate the problem here and remove this hack
      # (optimizer '/Adam' variables are not entered into the mapping).
      if '/Adam' not in tensor.name:
        if self.note_rnn_type != 'basic_rnn':
          key = self.checkpoint_scope + '/' + key
        lookup[key] = tensor

    return lookup
Exemplo n.º 5
0
    def get_variable_name_dict(self):
        """Constructs a dict mapping the checkpoint variables to those in new graph.

        Each variable name is reduced with `rl_tuner_ops.get_inner_scope` and
        `rl_tuner_ops.trim_variable_postfixes`. For a 'basic_rnn' the reduced
        name is the key, except that a fully connected name ending in 'bias' is
        pluralized to 'biases'. For every other type the reduced name is
        prefixed with `self.checkpoint_scope` and a '/'.

        Returns:
          A dict mapping variable names in the checkpoint to variables in the
          graph.
        """
        var_dict = {}
        for var in self.variables():
            inner_name = rl_tuner_ops.get_inner_scope(var.name)
            inner_name = rl_tuner_ops.trim_variable_postfixes(inner_name)

            if self.note_rnn_type != 'basic_rnn':
                var_dict[self.checkpoint_scope + '/' + inner_name] = var
                continue

            if 'fully_connected' in inner_name and inner_name.endswith('bias'):
                # 'fully_connected/bias' has been changed to
                # 'fully_connected/biases' in newest checkpoints. Only pluralize
                # a name that actually ends in 'bias', so an already-plural
                # 'biases' (or a suffixed name) is not mangled by the 'es'.
                inner_name += 'es'
            var_dict[inner_name] = var

        return var_dict