def get_variable_name_dict_for_attention_rnn(self):
  """Constructs a dict mapping the checkpoint variables to those in new graph.

  Only applies when `self.note_rnn_type` is 'attention_rnn'; for any other
  RNN type an empty dict is returned. Only the graph variables whose names
  contain 'fully_connected/bias' or 'fully_connected/weights' are mapped, to
  the checkpoint keys 'fully_connected/biases' and 'fully_connected/weights'
  respectively. Variables whose names contain '/Adam' are skipped.

  Returns:
    A dict mapping variable names in the checkpoint to variables in the graph.
  """
  var_dict = dict()
  if self.note_rnn_type != 'attention_rnn':
    return var_dict

  # Diagnostic dump of every global variable. Logged at debug level rather
  # than printed so it does not unconditionally spam stdout on each call.
  tf.logging.debug('Global vars : %s', tf.global_variables())
  for var in self.variables():
    if '/Adam' in var.name:
      # Presumably Adam optimizer slot variables, which must not be mapped.
      # TODO(lukaszkaiser): investigate the problem here and remove this hack.
      continue
    if 'fully_connected/bias' in var.name:
      # Matches graph names ending in either 'bias' or 'biases'; the
      # checkpoint stores this variable under 'fully_connected/biases'.
      var_dict['fully_connected/biases'] = var
    elif 'fully_connected/weights' in var.name:
      var_dict['fully_connected/weights'] = var
  return var_dict
def get_variable_name_dict(self):
  """Maps checkpoint variable names onto the variables of the current graph.

  Each graph variable name is reduced with `rl_tuner_ops.get_inner_scope`
  followed by `rl_tuner_ops.trim_variable_postfixes`. For a 'basic_rnn' model
  that reduced name is used directly as the checkpoint key. For every other
  model type it is prefixed with `self.checkpoint_scope` and a '/'.

  Returns:
    A dict mapping variable names in the checkpoint to variables in the graph.
  """
  mapping = {}
  for graph_var in self.variables():
    scope_name = rl_tuner_ops.get_inner_scope(graph_var.name)
    key = rl_tuner_ops.trim_variable_postfixes(scope_name)
    if self.note_rnn_type != 'basic_rnn':
      # Non-basic checkpoints nest their variables under an extra scope.
      key = self.checkpoint_scope + '/' + key
    mapping[key] = graph_var
  return mapping
def get_variable_name_dict(self):
  """Maps checkpoint variable names onto the variables of the current graph.

  Each graph variable name is reduced with `rl_tuner_ops.get_inner_scope`
  followed by `rl_tuner_ops.trim_variable_postfixes`. Variables whose full
  name contains '/Adam' are left out of the result. For a 'basic_rnn' model
  the reduced name is the checkpoint key. Otherwise it is prefixed with
  `self.checkpoint_scope` and a '/'.

  Returns:
    A dict mapping variable names in the checkpoint to variables in the graph.
  """
  mapping = {}
  for graph_var in self.variables():
    scope_name = rl_tuner_ops.get_inner_scope(graph_var.name)
    key = rl_tuner_ops.trim_variable_postfixes(scope_name)
    if '/Adam' in graph_var.name:
      # TODO(lukaszkaiser): investigate the problem here and remove this hack.
      continue
    if self.note_rnn_type != 'basic_rnn':
      # Non-basic checkpoints nest their variables under an extra scope.
      key = self.checkpoint_scope + '/' + key
    mapping[key] = graph_var
  return mapping
def get_variable_name_dict(self):
  """Constructs a dict mapping the checkpoint variables to those in new graph.

  Each graph variable name is reduced with `rl_tuner_ops.get_inner_scope`
  followed by `rl_tuner_ops.trim_variable_postfixes`.

  For a 'basic_rnn' model the reduced name is the checkpoint key, with one
  adjustment. In names containing 'fully_connected', a path component that is
  exactly 'bias' is renamed to 'biases'. Names that already use 'biases' are
  left untouched.

  For every other model type the reduced name is prefixed with
  `self.checkpoint_scope` and a '/'.

  Returns:
    A dict mapping variable names in the checkpoint to variables in the graph.
  """
  var_dict = dict()
  for var in self.variables():
    inner_name = rl_tuner_ops.get_inner_scope(var.name)
    inner_name = rl_tuner_ops.trim_variable_postfixes(inner_name)
    if self.note_rnn_type == 'basic_rnn':
      if 'fully_connected' in inner_name:
        # 'fully_connected/bias' has been changed to 'fully_connected/biases'
        # in newest checkpoints. Rename only an exact 'bias' component so a
        # graph name that is already 'biases' does not become 'biaseses', and
        # deeper names (e.g. '.../bias/<suffix>') are not corrupted.
        inner_name = '/'.join(
            'biases' if part == 'bias' else part
            for part in inner_name.split('/'))
      var_dict[inner_name] = var
    else:
      var_dict[self.checkpoint_scope + '/' + inner_name] = var
  return var_dict