def load(self, data_path, session, ignore_missing=False):
    '''Load serialized network weights into the current graph's variables.

    The file is loaded with ``np.load(...).item()``, so it presumably holds a
    0-d object array wrapping a dict (as written by ``np.save(path, a_dict)``).
    Each dict entry is one of:

      * ``name -> ndarray``: a single variable, looked up by name in the root
        variable scope after stripping a trailing ``:0`` output suffix.
        Entries whose name contains ``RMSProp`` are skipped.
      * ``op_name -> {param_name: ndarray}``: the parameters of one layer,
        looked up inside the ``op_name`` variable scope.

    Args:
        data_path: The path to the numpy-serialized network weights.
        session: The current TensorFlow session, used to run the assign ops.
        ignore_missing: If true, serialized weights that have no matching
            variable in the graph are reported and skipped instead of raising.

    Raises:
        ValueError: A serialized weight has no matching variable (or cannot
            be assigned, for per-layer parameters) and ``ignore_missing`` is
            false.
        SystemExit: Creating or running the assign op for a top-level array
            fails; the offending name and error are printed first.
    '''
    # allow_pickle=True is required (NumPy >= 1.16.3) to read a dict stored
    # in an object array. Unpickling can execute arbitrary code, so only
    # load weight files from trusted sources.
    data_dict = np.load(data_path, allow_pickle=True, encoding='bytes').item()
    for raw_name, value in data_dict.items():
        # encoding='bytes' can produce bytes keys (e.g. for files pickled
        # under Python 2); normalize before any str operation on the name.
        op_name = to_str(raw_name)
        if isinstance(value, np.ndarray):
            if 'RMSProp' in op_name:
                continue
            # Strip only a trailing tensor-output suffix, e.g. 'w:0' -> 'w'.
            var_name = op_name[:-2] if op_name.endswith(':0') else op_name
            with tf.variable_scope('', reuse=True):
                try:
                    var = tf.get_variable(var_name)
                except ValueError as e:
                    # With reuse=True, get_variable raises ValueError when
                    # the variable does not exist in the graph.
                    print(e)
                    if not ignore_missing:
                        raise
                    continue
                try:
                    session.run(var.assign(value))
                except Exception as e:
                    print(op_name)
                    print(e)
                    sys.exit(-1)
        else:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in value.items():
                    try:
                        var = tf.get_variable(to_str(param_name))
                        session.run(var.assign(data))
                    except ValueError as e:
                        print(e)
                        if not ignore_missing:
                            raise
def load_from_numpy(sess, net, data_path, logger=None, ignore_missing=False):
    '''Load serialized network weights from numpy into the graph's variables.

    The file is loaded with ``np.load(...).item()``, so it presumably holds a
    0-d object array wrapping a dict. Each dict entry is one of:

      * ``name -> ndarray``: a single variable. It is restored only when
        ``name`` appears in ``net.restorable_variables()``. It is looked up in
        the root variable scope after stripping a trailing ``:0`` suffix.
      * ``op_name -> {param_name: ndarray}``: the parameters of one layer,
        looked up inside the ``op_name`` variable scope.

    Args:
        sess: The current TensorFlow session, used to run the assign ops.
        net: The network object. Only its ``restorable_variables()`` method
            is used, as a filter for top-level arrays.
        data_path: The path to the numpy-serialized network weights.
        logger: Optional logger. When given, errors are reported through
            ``logger.error``; otherwise they are printed.
        ignore_missing: If true, serialized weights that have no matching
            variable in the graph are reported and skipped instead of raising.

    Raises:
        ValueError: A serialized weight has no matching variable (or cannot
            be assigned, for per-layer parameters) and ``ignore_missing`` is
            false.
        SystemExit: Creating or running the assign op for a top-level array
            fails; the offending name and error are reported first.
    '''
    def report(*items):
        # Emit each item as its own error line, via logger when available.
        for item in items:
            if logger is not None:
                logger.error(item)
            else:
                print(item)

    # allow_pickle=True is required (NumPy >= 1.16.3) to read a dict stored
    # in an object array. Unpickling can execute arbitrary code, so only
    # load weight files from trusted sources.
    data_dict = np.load(data_path, allow_pickle=True, encoding='bytes').item()
    # Query once rather than per entry. This assumes the restorable set does
    # not change while assign ops are being added; verify if
    # restorable_variables() is dynamic.
    restorable = net.restorable_variables()
    for raw_name, value in data_dict.items():
        # encoding='bytes' can produce bytes keys (e.g. for files pickled
        # under Python 2); normalize so the membership test below can match.
        op_name = to_str(raw_name)
        if isinstance(value, np.ndarray):
            if op_name not in restorable:
                continue
            # Strip only a trailing tensor-output suffix, e.g. 'w:0' -> 'w'.
            var_name = op_name[:-2] if op_name.endswith(':0') else op_name
            with tf.variable_scope('', reuse=True):
                try:
                    var = tf.get_variable(var_name)
                except ValueError as ex:
                    # With reuse=True, get_variable raises ValueError when
                    # the variable does not exist in the graph.
                    report(op_name, ex)
                    if not ignore_missing:
                        raise
                    continue
                try:
                    sess.run(var.assign(value))
                except Exception as ex:
                    report(op_name, ex)
                    sys.exit(-1)
        else:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in value.items():
                    name = to_str(param_name)
                    try:
                        var = tf.get_variable(name)
                        sess.run(var.assign(data))
                    except ValueError as ex:
                        report(name, ex)
                        if not ignore_missing:
                            raise