Example #1
import tensorflow as tf  # TF 1.x (uses tf.contrib); `cfg` and the ckpt helpers come from the host project


def restore_darknet19_variables(sess, imdb, net_name='darknet19', save_epoch=True):
    """Initialize or restore the variables in darknet19."""

    sfiles = get_ordered_ckpts(sess, imdb, net_name, save_epoch=save_epoch)
    lsf = len(sfiles)

    if lsf == 0:
        # TODO: this path is unused because the current weight file has deprecated scope names. Need to fix this.
        # if os.path.isfile(cfg.darknet_imagenet_weight_path + ".meta"):
        #     # Try to load from weight file first
        #     print 'Restoring model from weight file {:s}'.format(cfg.darknet_imagenet_weight_path)
        #     saver = tf.train.Saver()
        #     saver.restore(sess, cfg.darknet_pascal_weight_path)
        #     print 'Restored.'

        # Load from an ImageNet-trained model;
        # currently assumes the model was previously trained on ImageNet.
        LOAD_CKPTS_DIR = cfg.get_ckpts_dir('darknet19', 'ilsvrc_2017_cls')
        ckpt_vars = [t[0] for t in tf.contrib.framework.list_variables(LOAD_CKPTS_DIR)]
        vars_to_restore = []
        vars_to_init = []
        for v in tf.global_variables():
            # v.name carries a ':0' tensor suffix; strip it to match checkpoint names
            if v.name[:-2] in ckpt_vars:
                vars_to_restore.append(v)
            else:
                vars_to_init.append(v)
        # print 'vars_to_restore:', len(vars_to_restore)
        # print 'vars_to_init:', len(vars_to_init)

        init_op = tf.variables_initializer(vars_to_init)
        saver = tf.train.Saver(vars_to_restore)

        print 'Initializing new variables to train from the ImageNet-trained model'
        sess.run(init_op)
        imagenet_sfiles = get_ordered_ckpts_by_dbname(sess, 'ilsvrc_2017_cls', net_name, save_epoch=True)
        print 'Restoring model snapshots from {:s}'.format(imagenet_sfiles[-1])
        saver.restore(sess, str(imagenet_sfiles[-1]))
        return 0

    else:
        print 'Restoring model snapshots from {:s}'.format(sfiles[-1])
        saver = tf.train.Saver()
        saver.restore(sess, str(sfiles[-1]))
        print 'Restored.'

        # Snapshot names end in '..._<interval>_<n>.ckpt'; strip the
        # 5-character '.ckpt' suffix to recover the epoch/iter number.
        fnames = sfiles[-1].split('_')
        return int(fnames[-1][:-5])
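
The restore logic above splits tf.global_variables() into variables present in the ImageNet checkpoint (restored) and the rest (freshly initialized), so new layers can be trained on top of pretrained weights. A minimal usage sketch, assuming the darknet19 graph has already been built and `sess`/`imdb` come from the host project:

# Hedged usage sketch (session/graph setup is assumed, not from the source)
sess = tf.Session()
# ... build the darknet19 graph here ...
last_epoch = restore_darknet19_variables(sess, imdb, net_name='darknet19',
                                         save_epoch=True)
# Returns 0 when bootstrapping from the ImageNet checkpoint, otherwise the
# epoch number parsed from the newest snapshot.
print 'resuming from epoch', last_epoch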
Example #2
import os
import glob
# `cfg` comes from the host project's config module


def get_ordered_ckpts_by_dbname(sess, imdb_name, net_name, save_epoch=True):
    """Get the ckpts for a particular network on a certain dataset (by name).
    The ckpts are ordered by ascending modification time.

    Returns: sorted list of ckpt names.
    """

    # Find previous snapshots if there is any to restore from
    ckpts_dir = cfg.get_ckpts_dir(net_name, imdb_name)
    if save_epoch:
        save_interval = 'epoch'
    else:
        save_interval = 'iter'
    sfiles = os.path.join(ckpts_dir,
                          cfg.TRAIN_SNAPSHOT_PREFIX + '_' + save_interval + '_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # TensorFlow's Saver expects the checkpoint path without the '.meta' suffix
    sfiles = [ss.replace('.meta', '') for ss in sfiles]

    return sfiles
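
Because the list is sorted by modification time, the newest snapshot is the last element. A short usage sketch (the dataset and net names match Example #1; `sess` is only passed through by this helper):

# Hedged usage sketch
sfiles = get_ordered_ckpts_by_dbname(sess, 'ilsvrc_2017_cls', 'darknet19',
                                     save_epoch=True)
if sfiles:
    latest = sfiles[-1]  # e.g. '<ckpts_dir>/<prefix>_epoch_12.ckpt'
    tf.train.Saver().restore(sess, str(latest))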
Example #3
from multiprocessing import Process, Queue


# (The original snippet starts mid-function; this wrapper is inferred from
#  the Process(target=get_validation_process, ...) call below.)
def get_validation_process(imdb, queue_in, queue_out):
    while True:
        queue_in.get()  # block until the parent requests a batch
        images, labels = imdb.get()
        queue_out.put([images, labels])


imdb = ilsvrc_cls('train', data_aug=True, multithread=cfg.MULTITHREAD)
val_imdb = ilsvrc_cls('val', batch_size=64)
# set up child process for getting validation data
queue_in = Queue()
queue_out = Queue()
val_data_process = Process(target=get_validation_process,
                           args=(val_imdb, queue_in, queue_out))
val_data_process.start()
queue_in.put(True)  # start getting the first batch

CKPTS_DIR = cfg.get_ckpts_dir('darknet19', imdb.name)
TENSORBOARD_TRAIN_DIR, TENSORBOARD_VAL_DIR = cfg.get_output_tb_dir(
    'darknet19', imdb.name)


input_data = tf.placeholder(tf.float32, [None, 224, 224, 3])
label_data = tf.placeholder(tf.int32, None)
is_training = tf.placeholder(tf.bool)

logits = darknet19(input_data, is_training=is_training)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=label_data, logits=logits)
loss = tf.reduce_mean(loss)

# Batch-norm updates live in UPDATE_OPS and must run together with the train step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # (example truncated in the source; a train op such as
    #  `train_op = optimizer.minimize(loss)` would normally be defined here)
    pass
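
For context, a hedged sketch of how the parent process would drive one training step and consume the validation queue set up above; the session wiring and fetches are assumptions, not from the original:

# Hedged sketch (the train op is omitted because the source is truncated)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    images, labels = imdb.get()
    train_loss = sess.run(loss, feed_dict={input_data: images,
                                           label_data: labels,
                                           is_training: True})
    # take the pre-requested validation batch, then ask the child for the next one
    val_images, val_labels = queue_out.get()
    queue_in.put(True)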
Example #4
# NOTE: check/fix the data imread and label handling
imdb = ilsvrc_cls('train',
                  multithread=cfg.MULTITHREAD,
                  batch_size=TRAIN_BATCH_SIZE,
                  image_size=299,
                  random_noise=True)
val_imdb = ilsvrc_cls('val', batch_size=18, image_size=299, random_noise=True)
# set up child process for getting validation data
queue_in = Queue()
queue_out = Queue()
val_data_process = Process(target=get_validation_process,
                           args=(val_imdb, queue_in, queue_out))
val_data_process.start()
queue_in.put(True)  # start getting the first batch

CKPTS_DIR = cfg.get_ckpts_dir('inception_resnet', imdb.name)
TENSORBOARD_TRAIN_DIR, TENSORBOARD_VAL_DIR = cfg.get_output_tb_dir(
    'inception_resnet', imdb.name)
TENSORBOARD_TRAIN_ADV_DIR = os.path.abspath(
    os.path.join(cfg.ROOT_DIR, 'tensorboard', 'inception_resnet', imdb.name,
                 'train_adv'))
if not os.path.exists(TENSORBOARD_TRAIN_ADV_DIR):
    os.makedirs(TENSORBOARD_TRAIN_ADV_DIR)
TENSORBOARD_VAL_ADV_DIR = os.path.abspath(
    os.path.join(cfg.ROOT_DIR, 'tensorboard', 'inception_resnet', imdb.name,
                 'val_adv'))
if not os.path.exists(TENSORBOARD_VAL_ADV_DIR):
    os.makedirs(TENSORBOARD_VAL_ADV_DIR)

g_inception_resnet = tf.Graph()
with g_inception_resnet.as_default():
    # (example truncated in the source; the inception_resnet model would be
    #  built inside this graph context)
    pass

# The remainder of this example is a separate snippet: a single-image
# YOLO detection demo on a ResNet-50 backbone.
from utils.timer import Timer
from yolo2_nets.net_utils import restore_resnet_tf_variables, show_yolo_detection
from yolo2_nets.tf_resnet import resnet_v1_50
# assumed import (not in the original snippet): provides resnet_arg_scope used below
from tensorflow.contrib.slim.nets import resnet_v1

slim = tf.contrib.slim

# TODO: make the image path to be user input
image_path = '/home/wenxi/Projects/tensorflow_yolo2/experiments/fig1.jpg'

IMAGE_SIZE = cfg.IMAGE_SIZE
S = cfg.S
B = cfg.B
# create database instance
imdb = pascal_voc('trainval')
NUM_CLASS = imdb.num_class
CKPTS_DIR = cfg.get_ckpts_dir('resnet50', imdb.name)

input_data = tf.placeholder(tf.float32, [None, 224, 224, 3])

# read in the test image
image = cv2.imread(image_path)
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = image.astype(np.float32)
image = (image / 255.0) * 2.0 - 1.0  # scale pixels to [-1, 1]
# NOTE: the reshape assumes cfg.IMAGE_SIZE == 224, matching the placeholder above
image = image.reshape((1, 224, 224, 3))

# get the right arg_scope in order to load weights
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
    # net has shape [batch_size, S, S, 2048] when the input is 224 x 224
    net, end_points = resnet_v1_50(input_data, is_training=False)
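
The snippet ends after building the backbone; a hedged sketch of how the demo would likely continue with the helpers imported above. The detection head (`predicts`) and both helper call signatures are assumptions, not from the original:

# Hedged continuation sketch; the calls below are assumed, not documented.
predicts = net  # a YOLO head mapping features to S*S*(B*5 + NUM_CLASS) would go here

sess = tf.Session()
# project helper: restore backbone (+ head) weights from CKPTS_DIR
restore_resnet_tf_variables(sess, imdb, 'resnet50')
output = sess.run(predicts, feed_dict={input_data: image})
# project helper: draw the predicted boxes on the test image
show_yolo_detection(image_path, output, imdb)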