# Example #1
# 0
def _main(name):
    """Check that the checkpoint registered under ``name`` can be restored.

    Builds the starter's network (in training mode) in a fresh graph on a
    float32 placeholder of shape ``(None, 311, 311, 3)``, then restores the
    trainable variables from ``starter.get_checkpoint()`` into a session.

    Args:
        name: key passed to ``get_starter``.

    Returns:
        ``None`` if the starter reports no checkpoint, ``True`` if the
        restore succeeded, and ``False`` if it failed while
        ``FLAGS.continue_on_error`` is set.

    Raises:
        Whatever exception building the graph or restoring raised, when
        ``FLAGS.continue_on_error`` is not set.
    """
    import traceback

    import tensorflow as tf
    starter = get_starter(name)
    continue_on_error = FLAGS.continue_on_error
    if not starter.has_checkpoint:
        return None

    try:
        graph = tf.Graph()
        with graph.as_default():
            image = tf.placeholder(shape=(None, 311, 311, 3),
                                   dtype=tf.float32)
            fn = starter.get_network_fn(is_training=True)
            fn(image)
            # Only TRAINABLE_VARIABLES are restored; non-trainable variables
            # created by the network (if any) are not checked here.
            var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            saver = tf.train.Saver(var_list=var_list)

        with tf.Session(graph=graph) as sess:
            ckpt = starter.get_checkpoint()
            saver.restore(sess, ckpt)
            return True
    except Exception:
        if not continue_on_error:
            raise
        # Best-effort mode: report failure to the caller via ``False``, but
        # keep the reason visible instead of silently discarding it.
        traceback.print_exc()
        return False
# Example #2
# 0
def _main(name):
    """Print the checkpoint for the starter registered under ``name``.

    Prints ``name``, then either the result of ``starter.get_checkpoint()``
    (when the starter has a ``url``) or ``'No url'``, followed by a
    separator line.

    Args:
        name: key passed to ``get_starter``.
    """
    starter = get_starter(name)
    print(name)
    if starter.url is not None:
        print(starter.get_checkpoint())
    else:
        print('No url')
    # Separator after every entry, not only the 'No url' ones, so that
    # consecutive entries are always visually delimited.
    print('---------------------------------------')
# Example #3
# 0
def main(_):
    """Print the names of all tensors stored in a starter's checkpoint.

    The starter is looked up by ``FLAGS.name``. Its checkpoint path is taken
    from ``get_checkpoint()`` and handed to TensorFlow's
    ``print_tensors_in_checkpoint_file`` with only tensor names requested.

    Args:
        _: ignored (positional argument supplied by the app runner).
    """
    from tensorflow.python.tools import inspect_checkpoint
    from slim_start import get_starter

    checkpoint_path = get_starter(FLAGS.name).get_checkpoint()
    inspect_checkpoint.print_tensors_in_checkpoint_file(
        checkpoint_path,
        tensor_name='',
        all_tensors=False,
        all_tensor_names=True,
    )
# Example #4
# 0
def _main(name):
    """Clean the downloaded archive of the starter registered under ``name``.

    Prints ``name`` and a status line, then a separator. ``clean_archive()``
    is only attempted when the starter has a ``url``. Its truthy result is
    reported as ``'Cleaned'`` and a falsy one as ``'No archive present'``.

    Args:
        name: key passed to ``get_starter``.
    """
    starter = get_starter(name)
    print(name)
    if starter.url is None:
        status = 'No url'
    elif starter.clean_archive():
        status = 'Cleaned'
    else:
        status = 'No archive present'
    print(status)
    print('---------------------------------------')
# Example #5
# 0
def main(_):
    """Inspect the network built by the starter named ``FLAGS.name``.

    Builds the unscoped network (``num_classes=None``) on a zero input of
    shape ``(2, 224, 224, 3)`` inside the starter's arg scope. Then it does
    one of two things:

    * if ``FLAGS.endpoints`` is set, prints each endpoint name with its
      static shape;
    * otherwise prints the names of the variables in collection
      ``FLAGS.collection`` filtered by ``FLAGS.scope``.

    Args:
        _: ignored (positional argument supplied by the app runner).
    """
    import tensorflow as tf
    from slim_start import get_starter
    name = FLAGS.name
    starter = get_starter(name)
    image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
    with tf.contrib.slim.arg_scope(starter.get_scope()):
        # Only the endpoints are inspected; the network output is unused.
        _, endpoints = starter.get_unscoped_network_fn(
            num_classes=None)(image)

    if FLAGS.endpoints:
        for key, tensor in endpoints.items():
            print('%s: %s' % (key, str(tensor.shape)))
    else:
        # Named ``variables`` rather than ``vars`` to avoid shadowing the
        # builtin.
        variables = tf.get_collection(FLAGS.collection, scope=FLAGS.scope)
        for var in variables:
            print(var.name)
# Example #6
# 0
        loss = slim.losses.get_total_loss()

        optimizer = tf.train.AdamOptimizer()
        train_op = slim.learning.create_train_op(loss, optimizer)
        kwargs['loss'] = loss
        kwargs['eval_metric_ops'] = dict(accuracy=accuracy)
        kwargs['train_op'] = train_op
    return tf.estimator.EstimatorSpec(**kwargs)


# Model / training configuration. These module-level values are read by
# ``logits_fn``: weight_decay and bn_decay are passed to the network builder,
# and num_classes sets the width of the final dense layer.
name = 'mobilenet_v2'  # key passed to get_starter
weight_decay = 0.0
num_classes = 10
bn_decay = 0.9

# The starter is created, and its checkpoint resolved, at import time.
# NOTE(review): get_checkpoint() may fetch/download the checkpoint, which
# would make importing this module slow or network-dependent; confirm.
starter = get_starter(name)
# NOTE(review): if WarmStartSettings interprets this as a regular expression,
# '/*' means "zero or more slashes" (not a glob), so it effectively matches
# every variable whose name starts with 'MobilenetV2'; confirm that is intended.
vars_to_warm_start = 'MobilenetV2/*'
# Warm-start only the matching base-network variables from the pretrained
# checkpoint; the dense head added in logits_fn is not covered by this pattern
# unless its variable names happen to start with 'MobilenetV2'.
warm_start_settings = tf.estimator.WarmStartSettings(starter.get_checkpoint(),
                                                     vars_to_warm_start)


def logits_fn(features, mode):
    """Compute classification logits for ``features``.

    Runs the starter's scoped base network (``base_only=True``), averages
    its output over axes 1 and 2, and projects to ``num_classes`` logits
    with a dense layer.

    Args:
        features: input tensor fed to the base network.
            NOTE(review): averaging over axes (1, 2) assumes an NHWC feature
            map, i.e. axes 1 and 2 are spatial; confirm.
        mode: estimator mode; the network is built in training mode only
            when it equals ``ModeKeys.TRAIN``.

    Returns:
        Logits tensor produced by ``tf.layers.dense`` with ``num_classes``
        units.
    """
    training = mode == ModeKeys.TRAIN
    network_fn = starter.get_scoped_network_fn(
        is_training=training,
        weight_decay=weight_decay,
        bn_decay=bn_decay,
    )
    feature_map, _ = network_fn(features, base_only=True)
    pooled = tf.reduce_mean(feature_map, axis=(1, 2))
    return tf.layers.dense(pooled, num_classes)