Example #1
def restore(sess):
    """choose which param to restore"""
    if FLAGS.pretrained_model:
        if tf.gfile.IsDirectory(FLAGS.pretrained_model):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.pretrained_model)
        else:
            checkpoint_path = FLAGS.pretrained_model

        if FLAGS.checkpoint_exclude_scopes is None:
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
        if FLAGS.checkpoint_include_scopes is None:
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'

        vars_to_restore = get_var_list_to_restore()
        for var in vars_to_restore:
            print('restoring ', var.name)

        try:
            restorer = tf.train.Saver(vars_to_restore)
            restorer.restore(sess, checkpoint_path)
            print('Restored %d(%d) vars from %s' %
                  (len(vars_to_restore), len(
                      tf.global_variables()), checkpoint_path))
        except Exception:
            print('Checking your params %s' % (checkpoint_path))
            raise
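All of these snippets call a helper get_var_list_to_restore() that is not shown on this page. A minimal sketch, assuming it filters tf.global_variables() by the comma-separated checkpoint_include_scopes/checkpoint_exclude_scopes flags (the project's actual helper may differ):

def get_var_list_to_restore():
    """Sketch: pick variables by include/exclude scope prefixes."""
    variables = tf.global_variables()

    if FLAGS.checkpoint_include_scopes is not None:
        includes = [s.strip() for s in FLAGS.checkpoint_include_scopes.split(',')]
        # keep only variables under one of the included scopes
        variables = [v for v in variables
                     if any(v.name.startswith(s) for s in includes)]

    if FLAGS.checkpoint_exclude_scopes is not None:
        excludes = [s.strip() for s in FLAGS.checkpoint_exclude_scopes.split(',')]
        # drop variables under any excluded scope
        variables = [v for v in variables
                     if not any(v.name.startswith(s) for s in excludes)]

    return variables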
Example #2
def restore(sess):
    """Choose which params to restore."""
    if FLAGS.restore_previous_if_exists:
        checkpoint_path = None
        try:
            print(FLAGS.train_dir)
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            # Restore only the variables whose saved shape still matches the
            # current graph, so a changed head does not break the restore.
            reader = tf.train.NewCheckpointReader(checkpoint_path)
            saved_shapes = reader.get_variable_to_shape_map()
            var_names = sorted([(var.name, var.name.split(':')[0])
                                for var in tf.global_variables()
                                if var.name.split(':')[0] in saved_shapes])
            restore_vars = []
            name2var = dict(zip(
                [var.name.split(':')[0] for var in tf.global_variables()],
                tf.global_variables()))
            with tf.variable_scope('', reuse=True):
                for var_name, saved_var_name in var_names:
                    curr_var = name2var[saved_var_name]
                    var_shape = curr_var.get_shape().as_list()
                    if var_shape == saved_shapes[saved_var_name]:
                        restore_vars.append(curr_var)
            restorer = tf.train.Saver(restore_vars)
            restorer.restore(sess, checkpoint_path)
            print('restored previous model %s from %s'
                  % (checkpoint_path, FLAGS.train_dir))
            time.sleep(2)
            return
        except Exception:
            print('--restore_previous_if_exists is set, but failed to restore in %s %s'
                  % (FLAGS.train_dir, checkpoint_path))
            time.sleep(2)

    if FLAGS.pretrained_model:
        if tf.gfile.IsDirectory(FLAGS.pretrained_model):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.pretrained_model)
        else:
            checkpoint_path = FLAGS.pretrained_model

        if FLAGS.checkpoint_exclude_scopes is None:
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
        if FLAGS.checkpoint_include_scopes is None:
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'

        vars_to_restore = get_var_list_to_restore()
        for var in vars_to_restore:
            print('restoring ', var.name)

        try:
            restorer = tf.train.Saver(vars_to_restore)
            restorer.restore(sess, checkpoint_path)
            print('Restored %d(%d) vars from %s' % (
                len(vars_to_restore), len(tf.global_variables()),
                checkpoint_path))
        except Exception:
            print('Checking your params %s' % (checkpoint_path))
            raise
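The shape-matching logic in Example #2 silently skips any variable whose shape changed since the checkpoint was written. To see what a checkpoint actually contains, the same NewCheckpointReader API can be used directly (the path below is a placeholder):

import tensorflow as tf

reader = tf.train.NewCheckpointReader('/path/to/model.ckpt')  # placeholder path
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)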
Example #3
def restore(sess):
    """choose which param to restore"""
    if FLAGS.restore_previous_if_exists:
        checkpoint_path = None
        try:
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint_path)
            print('restored previous model %s from %s'
                  % (checkpoint_path, FLAGS.train_dir))
            time.sleep(2)
            return
        except Exception:
            print('--restore_previous_if_exists is set, but failed to restore in %s %s'
                  % (FLAGS.train_dir, checkpoint_path))
            time.sleep(2)

    if FLAGS.pretrained_model:
        if tf.gfile.IsDirectory(FLAGS.pretrained_model):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.pretrained_model)
        else:
            checkpoint_path = FLAGS.pretrained_model

        if FLAGS.checkpoint_exclude_scopes is None:
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
        if FLAGS.checkpoint_include_scopes is None:
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'

        vars_to_restore = get_var_list_to_restore()
        for var in vars_to_restore:
            print('restoring ', var.name)

        try:
            restorer = tf.train.Saver(vars_to_restore)
            restorer.restore(sess, checkpoint_path)
            print('Restored %d(%d) vars from %s' %
                  (len(vars_to_restore), len(
                      tf.global_variables()), checkpoint_path))
        except Exception:
            print('Checking your params %s' % (checkpoint_path))
            raise
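The FLAGS referenced throughout these examples are defined elsewhere in the project. A plausible set of definitions, assuming the standard tf.app.flags mechanism (names are taken from the code above; the default values are guesses):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('train_dir', 'output/training/',
                    'directory where training checkpoints are written')
flags.DEFINE_string('pretrained_model', 'data/pretrained_models/resnet_v1_50.ckpt',
                    'checkpoint file or directory of the pretrained backbone')
flags.DEFINE_boolean('restore_previous_if_exists', True,
                     'resume from the latest checkpoint in train_dir if one exists')
flags.DEFINE_string('checkpoint_exclude_scopes', None,
                    'comma-separated scopes to exclude when restoring')
flags.DEFINE_string('checkpoint_include_scopes', None,
                    'comma-separated scopes to include when restoring')
FLAGS = flags.FLAGS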
Example #4
def restore(sess):
    """Choose which params to restore."""
    if FLAGS.restore_previous_if_exists:
        checkpoint_path = None
        try:
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint_path)
            print('restored previous model %s from %s'
                  % (checkpoint_path, FLAGS.train_dir))
            time.sleep(2)
            return
        except Exception:
            print('--restore_previous_if_exists is set, but failed to restore in %s %s'
                  % (FLAGS.train_dir, checkpoint_path))
            time.sleep(2)

    if FLAGS.pretrained_model:
        if tf.gfile.IsDirectory(FLAGS.pretrained_model):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.pretrained_model)
        else:
            checkpoint_path = FLAGS.pretrained_model

        if FLAGS.checkpoint_exclude_scopes is None:
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
        if FLAGS.checkpoint_include_scopes is None:
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'

        vars_to_restore = get_var_list_to_restore()
        for var in vars_to_restore:
            print('restoring ', var.name)

        try:
            restorer = tf.train.Saver(vars_to_restore)
            restorer.restore(sess, checkpoint_path)
            print('Restored %d(%d) vars from %s' % (
                len(vars_to_restore), len(tf.global_variables()),
                checkpoint_path))
        except Exception:
            print('Checking your params %s' % (checkpoint_path))
            raise
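A minimal driver for any of these restore() variants might look like the following (a sketch, assuming the graph and the FLAGS above are already built):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # resume from train_dir if possible, else fall back to the pretrained backbone
    restore(sess)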
Example #5
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)

        ## restore pretrained model
        if FLAGS.pretrained_model:
            if tf.gfile.IsDirectory(FLAGS.pretrained_model):
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.pretrained_model)
            else:
                checkpoint_path = FLAGS.pretrained_model
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'
            vars_to_restore = get_var_list_to_restore()
            for var in vars_to_restore:
                print('restoring ', var.name)

            try:
                restorer = tf.train.Saver(vars_to_restore)
                restorer.restore(sess, checkpoint_path)
                print('Restored %d(%d) vars from %s' %
                      (len(vars_to_restore), len(
                          tf.global_variables()), checkpoint_path))
            except Exception:
                print('Checking your params %s' % (checkpoint_path))
                raise

def restore(sess):
    """choose which param to restore"""
    if FLAGS.restore_previous_if_exists:
        checkpoint_path = None
        try:
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            restorer = tf.train.Saver()

            # Alternative kept for reference: restore everything except the
            # 'pyramid' head variables and their Momentum slots.
            # not_restore = ['pyramid/fully_connected/weights:0',
            #                'pyramid/fully_connected/biases:0',
            #                'pyramid/fully_connected_1/weights:0',
            #                'pyramid/fully_connected_1/biases:0',
            #                'pyramid/fully_connected_2/weights:0',
            #                'pyramid/fully_connected_2/biases:0',
            #                'pyramid/fully_connected_3/weights:0',
            #                'pyramid/fully_connected_3/biases:0',
            #                'pyramid/Conv/weights:0',
            #                'pyramid/Conv/biases:0',
            #                'pyramid/Conv_1/weights:0',
            #                'pyramid/Conv_1/biases:0',
            #                'pyramid/Conv_2/weights:0',
            #                'pyramid/Conv_2/biases:0',
            #                'pyramid/Conv_3/weights:0',
            #                'pyramid/Conv_3/biases:0',
            #                'pyramid/Conv2d_transpose/weights:0',
            #                'pyramid/Conv2d_transpose/biases:0',
            #                'pyramid/Conv_4/weights:0',
            #                'pyramid/Conv_4/biases:0',
            #                'pyramid/fully_connected/weights/Momentum:0',
            #                'pyramid/fully_connected/biases/Momentum:0',
            #                'pyramid/fully_connected_1/weights/Momentum:0',
            #                'pyramid/fully_connected_1/biases/Momentum:0',
            #                'pyramid/fully_connected_2/weights/Momentum:0',
            #                'pyramid/fully_connected_2/biases/Momentum:0',
            #                'pyramid/fully_connected_3/weights/Momentum:0',
            #                'pyramid/fully_connected_3/biases/Momentum:0',
            #                'pyramid/Conv/weights/Momentum:0',
            #                'pyramid/Conv/biases/Momentum:0',
            #                'pyramid/Conv_1/weights/Momentum:0',
            #                'pyramid/Conv_1/biases/Momentum:0',
            #                'pyramid/Conv_2/weights/Momentum:0',
            #                'pyramid/Conv_2/biases/Momentum:0',
            #                'pyramid/Conv_3/weights/Momentum:0',
            #                'pyramid/Conv_3/biases/Momentum:0',
            #                'pyramid/Conv2d_transpose/weights/Momentum:0',
            #                'pyramid/Conv2d_transpose/biases/Momentum:0',
            #                'pyramid/Conv_4/weights/Momentum:0',
            #                'pyramid/Conv_4/biases/Momentum:0']
            # vars_to_restore = [v for v in tf.all_variables()
            #                    if v.name not in not_restore]
            # restorer = tf.train.Saver(vars_to_restore)
            # for var in vars_to_restore:
            #     print('restoring ', var.name)

            restorer.restore(sess, checkpoint_path)
            print('restored previous model %s from %s'
                  % (checkpoint_path, FLAGS.train_dir))
            time.sleep(2)
            return
        except Exception:
            print('--restore_previous_if_exists is set, but failed to restore in %s %s'
                  % (FLAGS.train_dir, checkpoint_path))
            time.sleep(2)

    if FLAGS.pretrained_model:
        if tf.gfile.IsDirectory(FLAGS.pretrained_model):
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.pretrained_model)
        else:
            checkpoint_path = FLAGS.pretrained_model

        if FLAGS.checkpoint_exclude_scopes is None:
            FLAGS.checkpoint_exclude_scopes = 'pyramid'
        if FLAGS.checkpoint_include_scopes is None:
            FLAGS.checkpoint_include_scopes = 'resnet_v1_50'

        vars_to_restore = get_var_list_to_restore()
        for var in vars_to_restore:
            print('restoring ', var.name)

        try:
            restorer = tf.train.Saver(vars_to_restore)
            restorer.restore(sess, checkpoint_path)
            print('Restored %d(%d) vars from %s' %
                  (len(vars_to_restore), len(
                      tf.global_variables()), checkpoint_path))
        except Exception:
            print('Checking your params %s' % (checkpoint_path))
            raise
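Example #5 starts queue runners but the matching shutdown falls outside the excerpt. The standard TF1 Coordinator pattern pairs them like this (a sketch: train_op is a placeholder for the real training step, and the threads returned by start_queue_runners must be captured, which the excerpt above does not do):

threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        sess.run(train_op)  # placeholder training step
except tf.errors.OutOfRangeError:
    print('input queue exhausted')
finally:
    coord.request_stop()  # signal the queue-runner threads to stop
    coord.join(threads)   # and wait for them to shut down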