def get_model_params(self, name_list, sess=None):
    """Fetch the current values of named variables from ``self.tf_variables``.

    Args:
        name_list: A single variable name (str), or a list/tuple of names.
        sess: Session to evaluate with; passed through ``assert_session``
            (presumably supplies a default session when None -- confirm).

    Returns:
        dict mapping each requested name to the value ``sess.run`` produced.

    Raises:
        TypeError: if ``name_list`` is not a str, list or tuple.
    """
    sess = assert_session(sess)
    if isinstance(name_list, str):
        names = [name_list]
    elif isinstance(name_list, (list, tuple)):
        names = name_list
    else:
        raise TypeError('name_list must be str, list or tuple')
    # Evaluate all requested variables in a single run call.
    fetches = [self.tf_variables[name] for name in names]
    values = sess.run(fetches)
    return {name: value for name, value in zip(names, values)}
def assign_to_variables(tf_assign_ops, tf_assign_phs, value_dict, sess=None):
    """Run the assign ops for every entry of ``value_dict`` in one session call.

    Args:
        tf_assign_ops: Mapping from key to the assign op for that key.
        tf_assign_phs: Mapping from key to the placeholder that op reads from.
        value_dict: Mapping from key to the value to feed into its placeholder.
        sess: Session to run with; passed through ``assert_session``.
    """
    sess = assert_session(sess)
    fetches = []
    feeds = {}
    for name, new_value in value_dict.items():
        fetches.append(tf_assign_ops[name])
        # Each assign op reads its new value from the matching placeholder.
        feeds[tf_assign_phs[name]] = new_value
    sess.run(fetches, feed_dict=feeds)
def compute_fisher_information(self, X_train, sess=None):
    """Estimate per-element mean squared gradients (diagonal Fisher) over ``X_train``.

    For every sample, the tensors in ``self.rv_grads`` are evaluated with that
    single sample fed to ``self.tf_tensors['input']``. Their element-wise squares
    are summed, and the sums are then divided by the number of samples.

    NOTE(review): the original author noted this may only be applicable when
    the last layer is trained with a softmax function.

    Args:
        X_train: Sliceable samples; ``X_train[i:i + 1]`` is fed as a batch of one.
        sess: Session to evaluate with; passed through ``assert_session``.

    Returns:
        dict mapping each key of ``self.rv_grads`` to a float64 numpy array
        holding the mean of the squared per-sample gradient values.

    Raises:
        ValueError: if ``X_train`` is empty (the mean would be 0/0 -> NaN).
    """
    n_samples = len(X_train)
    if n_samples == 0:
        # Previously this divided by zero and silently returned NaN arrays.
        raise ValueError('X_train must contain at least one sample')
    sess = assert_session(sess)
    grad_keys = list(self.rv_grads.keys())
    grad_tensors = [self.rv_grads[k] for k in grad_keys]
    input_tensor = self.tf_tensors['input']
    F_accum = {}
    for i in range(n_samples):
        # One run per sample: each sample's gradient is squared before summing,
        # which a single batched gradient could not reproduce.
        grad_vals = sess.run(grad_tensors, feed_dict={input_tensor: X_train[i:i + 1]})
        for key, grad in zip(grad_keys, grad_vals):
            if key not in F_accum:
                # Size the float64 accumulator from the evaluated value, so the
                # gradient tensor need not have a fully-defined static shape.
                F_accum[key] = np.zeros(np.shape(grad))
            F_accum[key] += np.square(grad)
    for key in F_accum:
        F_accum[key] /= n_samples
    return F_accum
def reset_ewc_variables(self, sess=None):
    """Re-run the initializers of every Fisher and previous-value variable.

    Args:
        sess: Session to run with; passed through ``assert_session``.
    """
    sess = assert_session(sess)
    ewc_vars = list(self.fisher_variables.values())
    ewc_vars.extend(self.prev_variables.values())
    # NOTE(review): a new initializer op is built on every call; repeated calls
    # keep adding ops to the graph -- consider caching the op.
    init_op = tf.variables_initializer(ewc_vars)
    sess.run(init_op)
def get_all_fisher_variables(self, sess=None):
    """Return the current values of all variables in ``self.fisher_variables``.

    Args:
        sess: Session to evaluate with; passed through ``assert_session``.

    Returns:
        dict mapping each key of ``self.fisher_variables`` to its evaluated value.
    """
    sess = assert_session(sess)
    names = list(self.fisher_variables)
    fetches = [self.fisher_variables[name] for name in names]
    return dict(zip(names, sess.run(fetches)))
def set_fisher_variables(self, fisher_var, sess=None):
    """Assign new values to the Fisher variables.

    Args:
        fisher_var: Mapping from Fisher-variable key to its new value.
        sess: Session to run with; passed through ``assert_session``.
    """
    session = assert_session(sess)
    assign_to_variables(
        self.fisher_variables_ops,
        self.fisher_variables_phs,
        fisher_var,
        sess=session,
    )
def get_all_regularizable_variables(self, sess=None):
    """Return the current values of every regularizable variable.

    The keys of ``self.regularizable_variables`` are looked up through
    ``self.get_model_params``, so they must also be names in that lookup.

    Args:
        sess: Session to evaluate with; passed through ``assert_session``.

    Returns:
        dict mapping each regularizable variable name to its evaluated value.
    """
    session = assert_session(sess)
    names = list(self.regularizable_variables.keys())
    return self.get_model_params(names, sess=session)
def set_prev_variables(self, prev_rv, sess=None):
    """Assign new values to the stored previous-value variables.

    Args:
        prev_rv: Mapping from previous-variable key to its new value.
        sess: Session to run with; passed through ``assert_session``.
    """
    session = assert_session(sess)
    assign_to_variables(
        self.prev_variables_ops,
        self.prev_variables_phs,
        prev_rv,
        sess=session,
    )
def load_model(self, file_name, sess=None):
    """Load pickled model parameters from ``file_name`` and assign them.

    The unpickled object is passed to ``self.set_model_params``, so it is
    expected to be a mapping from variable name to value (the same kind of
    object ``save_model`` pickles).

    Args:
        file_name: Path of the pickle file to read.
        sess: Session to run with; passed through ``assert_session``.
    """
    sess = assert_session(sess)
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    # The context manager closes the handle (the original leaked it).
    with open(file_name, 'rb') as f:
        model_params = pickle.load(f)
    self.set_model_params(model_params, sess)
def save_model(self, file_name, sess=None):
    """Pickle the result of ``self.get_all_model_params`` to ``file_name``.

    Args:
        file_name: Path of the file to (over)write.
        sess: Session to evaluate with; passed through ``assert_session``.
    """
    sess = assert_session(sess)
    # Fetch before opening, so a failed fetch does not truncate an existing file.
    model_params = self.get_all_model_params(sess)
    # The context manager guarantees the data is flushed and the handle closed
    # (the original relied on garbage collection to do this).
    with open(file_name, 'wb') as f:
        pickle.dump(model_params, f)
def set_model_params(self, value_dict, sess=None):
    """Assign new values to model variables.

    Args:
        value_dict: Mapping from variable name to new value; every key must be
            present in ``self.tf_variables_ops`` and ``self.tf_variables_phs``.
        sess: Session to run with; passed through ``assert_session``.
    """
    session = assert_session(sess)
    assign_to_variables(
        self.tf_variables_ops,
        self.tf_variables_phs,
        value_dict,
        sess=session,
    )