Exemple #1
0
def restore(sess, global_vars):
    """Initialise ``global_vars``, loading VGG-16 checkpoint values where possible.

    The checkpoint is read from ``model/vgg_16.ckpt`` under the current
    working directory. Every variable in ``global_vars`` whose name matches a
    checkpoint entry ending in ``weights`` is assigned the stored value; all
    remaining variables are run through ``tf.variables_initializer``.

    Args:
        sess: session used to run the assign and initializer ops.
        global_vars: iterable of variables to restore or initialise.
    """
    print('from vgg_16 pretrained model')

    reader = NewCheckpointReader(os.path.join(os.getcwd(),
                                              'model/vgg_16.ckpt'))

    # no batchnorm from vgg_16 pretrained model; conv biases are skipped too.
    # Checkpoint keys carry no output suffix, graph variable names end in
    # ':0', so the suffix is appended before comparing. A set keeps the
    # membership test below O(1) per variable.
    ckpt_var_names = {
        name + ':0' for name in reader.get_variable_to_dtype_map()
        if re.match('^.*weights$', name)
    }

    restored_vars = [var for var in global_vars if var.name in ckpt_var_names]

    # One placeholder is shared by all assignments; each value is fed in turn.
    value_ph = tf.placeholder(dtype=tf.float32)

    for var in restored_vars:
        # var.name[:-2] strips the ':0' suffix to recover the checkpoint key.
        sess.run(
            tf.assign(var, value_ph),
            feed_dict={value_ph: reader.get_tensor(var.name[:-2])})

    # Everything that was not loaded from the checkpoint gets its own
    # initializer; order of global_vars is preserved.
    restored_names = {var.name for var in restored_vars}
    initialized_vars = [
        var for var in global_vars if var.name not in restored_names
    ]

    sess.run(tf.variables_initializer(initialized_vars))
Exemple #2
0
def restore(sess, global_vars):
    """Initialise ``global_vars``, loading ResNet-v2-50 checkpoint values where possible.

    The checkpoint is read from ``model/resnet_v2_50.ckpt`` under the current
    working directory. Every variable in ``global_vars`` whose name matches a
    checkpoint entry ending in ``weights`` or ``biases`` is assigned the
    stored value; all remaining variables are run through
    ``tf.variables_initializer``.

    Args:
        sess: session used to run the assign and initializer ops.
        global_vars: iterable of variables to restore or initialise.
    """
    print('from resnet_v2_50 pretrained model')

    reader = NewCheckpointReader(os.path.join(
        os.getcwd(), 'model/resnet_v2_50.ckpt'))

    # restore both weights and biases from conv and shortcut layers.
    # Checkpoint keys carry no output suffix, graph variable names end in
    # ':0', so the suffix is appended before comparing. A set keeps the
    # membership test below O(1) per variable.
    ckpt_var_names = {
        name + ':0'
        for name in reader.get_variable_to_dtype_map()
        if re.match('^.*weights$', name) or re.match('^.*biases$', name)
    }

    restored_vars = [var for var in global_vars
                     if var.name in ckpt_var_names]

    # One placeholder is shared by all assignments; each value is fed in turn.
    value_ph = tf.placeholder(dtype=tf.float32)

    for var in restored_vars:
        # var.name[:-2] strips the ':0' suffix to recover the checkpoint key.
        sess.run(tf.assign(var, value_ph),
                 feed_dict={value_ph: reader.get_tensor(var.name[:-2])})

    # Everything that was not loaded from the checkpoint gets its own
    # initializer; order of global_vars is preserved.
    restored_names = {var.name for var in restored_vars}
    initialized_vars = [var for var in global_vars
                        if var.name not in restored_names]

    sess.run(tf.variables_initializer(initialized_vars))
Exemple #3
0
    def load_ckpt(self, pretrained):
        """Restore the latest checkpoint, or initialise variables when training.

        First tries to restore the most recent checkpoint found in
        ``cfg.ckpt_dir`` through ``self.saver``. If that fails and
        ``self.is_training`` is true, the global variables are initialised
        instead; with ``pretrained`` set and the pretrained model present
        under ``cfg.workspace``, variables whose checkpoint name matches the
        regex ``premodel['rptn']`` are loaded from ``premodel['ckpt']`` and
        only the remaining ones are run through their initializers.

        When the restore fails and ``self.is_training`` is false, nothing
        else is done.

        Args:
            pretrained: truthy to load matching variables from the
                pretrained (tf-slim) checkpoint described by ``premodel``.
        """
        # restore model with ckpt/pretrain or init
        try:
            print('trying to restore last checkpoint')
            last_ckpt_path = tf.train.latest_checkpoint(
                checkpoint_dir=cfg.ckpt_dir)
            self.saver.restore(self.sess, save_path=last_ckpt_path)
        except Exception:
            # Best-effort restore: any failure falls through to the
            # initialisation path below. `Exception` (rather than a bare
            # `except:`) lets KeyboardInterrupt / SystemExit propagate.
            pass
        else:
            print('restored checkpoint from:', last_ckpt_path)
            return

        if not self.is_training:
            return

        print('init variables')
        restored_vars = []
        global_vars = tf.global_variables()

        # restore from tf-slim model
        if pretrained and os.path.exists(
                os.path.join(cfg.workspace, premodel['path'])):
            print('from ' + premodel['endp'])

            import re
            from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

            reader = NewCheckpointReader(
                os.path.join(cfg.workspace, premodel['ckpt']))

            # Checkpoint entries selected by the premodel['rptn'] regex.
            # Graph variable names end in ':0', so the suffix is appended
            # before comparing; a set keeps the lookup below O(1).
            ckpt_var_names = {
                name + ':0'
                for name in reader.get_variable_to_dtype_map()
                if re.match(premodel['rptn'], name)
            }

            # variables of the current graph that exist in the checkpoint
            restored_vars = [
                var for var in global_vars if var.name in ckpt_var_names
            ]

            # assign checkpoint values through one shared placeholder
            value_ph = tf.placeholder(tf.float32, shape=None)
            for var in restored_vars:
                # var.name[:-2] strips ':0' to recover the checkpoint key
                self.sess.run(
                    tf.assign(var, value_ph),
                    feed_dict={value_ph: reader.get_tensor(var.name[:-2])})

        # everything not loaded from the pretrained model is initialised
        initialized_vars = list(set(global_vars) - set(restored_vars))
        self.sess.run(tf.variables_initializer(initialized_vars))
Exemple #4
0
def get_list_of_variables_from_ckpt(ckpt_file):
    """Return the variable names stored in the checkpoint ``ckpt_file`` as a list."""
    # The reader exposes a name -> dtype mapping; only the names are needed.
    dtype_by_name = NewCheckpointReader(ckpt_file).get_variable_to_dtype_map()
    return list(dtype_by_name)