Exemplo n.º 1
0
 def load(self, data_path, session, ignore_missing=False):
     '''
     Load network weights from a numpy-serialized dict into the current graph.

     data_path: Path to a numpy file holding a pickled dict, read via
         ``np.load(...).item()``. Each entry is either
         ``variable_name -> ndarray`` (assigned to the variable of that name,
         with any ':0' suffix stripped) or ``op_name -> {param_name: ndarray}``
         (assigned to the variables named ``param_name`` inside the variable
         scope ``op_name``).
     session: The current TensorFlow session, used to run the assign ops.
     ignore_missing: If true, serialized weights whose target variable cannot
         be fetched are reported and skipped instead of raising.

     Raises ValueError (from ``tf.get_variable``) for a missing variable when
     ``ignore_missing`` is false. Calls ``sys.exit(-1)`` if assigning a flat
     (ndarray) entry fails.
     '''
     # allow_pickle is required because the file stores a pickled dict; numpy
     # >= 1.16.3 refuses to unpickle unless it is passed explicitly.
     # encoding='bytes' allows Python 2 pickles, whose keys may come back as bytes.
     data_dict = np.load(data_path, allow_pickle=True, encoding='bytes').item()
     for op_name, param_dict in data_dict.items():
         # Normalize possible bytes keys in both branches: the str membership
         # test and .replace() below fail on bytes.
         # Assumes to_str returns str input unchanged -- TODO confirm.
         op_name = to_str(op_name)
         if isinstance(param_dict, np.ndarray):
             # Skip entries whose name contains 'RMSProp' (presumably RMSProp
             # optimizer slot variables).
             if 'RMSProp' in op_name:
                 continue
             with tf.variable_scope('', reuse=True):
                 try:
                     # Strip the ':0' tensor-output suffix to get the variable name.
                     var = tf.get_variable(op_name.replace(':0', ''))
                 except ValueError as e:
                     print(e)
                     if not ignore_missing:
                         raise
                     continue
                 try:
                     session.run(var.assign(param_dict))
                 except Exception as e:
                     print(op_name)
                     print(e)
                     sys.exit(-1)
         else:
             with tf.variable_scope(op_name, reuse=True):
                 for param_name, data in param_dict.items():
                     try:
                         var = tf.get_variable(to_str(param_name))
                         session.run(var.assign(data))
                     except ValueError as e:
                         # May also come from var.assign (e.g. a shape
                         # mismatch), which is likewise skipped when
                         # ignore_missing is true.
                         print(e)
                         if not ignore_missing:
                             raise
Exemplo n.º 2
0
def load_from_numpy(sess, net, data_path, logger=None, ignore_missing=False):
    '''
    Load network weights from a numpy-serialized dict into the current graph.

    sess: The current TensorFlow session, used to run the assign ops.
    net: Network object. Flat (ndarray) entries are restored only when their
        name appears in ``net.restorable_variables()``.
    data_path: Path to a numpy file holding a pickled dict, read via
        ``np.load(...).item()``. Each entry is either
        ``variable_name -> ndarray`` or ``op_name -> {param_name: ndarray}``
        (the latter assigned inside the variable scope ``op_name``).
    logger: Optional logger. When given, diagnostics go to ``logger.error``;
        otherwise they are printed.
    ignore_missing: If true, serialized weights whose target variable cannot
        be fetched are reported and skipped instead of raising.

    Raises ValueError (from ``tf.get_variable``) for a missing variable when
    ``ignore_missing`` is false. Calls ``sys.exit(-1)`` if assigning a flat
    (ndarray) entry fails.
    '''
    def report(*items):
        # Send each diagnostic item to the logger when one is supplied,
        # otherwise to stdout.
        for item in items:
            if logger:
                logger.error(item)
            else:
                print(item)

    # allow_pickle is required because the file stores a pickled dict.
    # encoding='bytes' allows Python 2 pickles, whose keys may come back as bytes.
    data_dict = np.load(data_path, allow_pickle=True, encoding='bytes').item()
    for op_name, param_dict in data_dict.items():
        # Normalize possible bytes keys so they can match str variable names
        # and support .replace() below.
        # Assumes to_str returns str input unchanged -- TODO confirm.
        op_name = to_str(op_name)
        if isinstance(param_dict, np.ndarray):
            if op_name not in net.restorable_variables():
                continue
            with tf.variable_scope('', reuse=True):
                try:
                    # Strip the ':0' tensor-output suffix to get the variable name.
                    var = tf.get_variable(op_name.replace(':0', ''))
                except ValueError as ex:
                    report(op_name, ex)
                    if not ignore_missing:
                        raise
                    continue
                try:
                    sess.run(var.assign(param_dict))
                except Exception as ex:
                    report(op_name, ex)
                    sys.exit(-1)
        else:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in param_dict.items():
                    param_name = to_str(param_name)
                    try:
                        var = tf.get_variable(param_name)
                        sess.run(var.assign(data))
                    except ValueError as ex:
                        # May also come from var.assign (e.g. a shape
                        # mismatch), which is likewise skipped when
                        # ignore_missing is true.
                        report(param_name, ex)
                        if not ignore_missing:
                            raise