def main():
    """Script entry point: parse CLI flags, build a depth-scaled MobileNet
    v1 model function, and run a single evaluation pass over the dataset."""
    tf.logging.set_verbosity(tf.logging.INFO)
    import argparse

    parser = argparse.ArgumentParser()
    # Float-valued flags and their defaults.
    for flag, default in (('--depth-multiplier', 0.5),
                          ('--scale-depthwise', 1.0),
                          ('--scale-pointwise', 1.0),
                          ('--scale-dense', 1.0),
                          ('--adj-power', None),
                          ('--cutoff', None)):
        parser.add_argument(flag, default=default, type=float)
    # Path/name flags.
    for flag in ('--checkpoint-path', '--train-dir', '--dataset'):
        parser.add_argument(flag, default=None, type=str)
    args = parser.parse_args()

    scale = _make_scale(args.scale_depthwise, args.scale_pointwise,
                        args.scale_dense, args.adj_power, args.cutoff,
                        args.depth_multiplier)

    network_fn = partial(mobilenet_v1,
                         num_classes=1001,  # ImageNet classes + background
                         depth_multiplier=args.depth_multiplier,
                         scope=_make_scope(scale))

    model_fn = make_model_fn(network_fn, None)
    input_fn = make_input_fn(imagenet_data.get_split(args.dataset),
                             inception_preprocessing.preprocess_image,
                             n_epochs=1, image_size=224, batch_size=64)

    estimator = tf.estimator.Estimator(model_fn, model_dir=args.train_dir)
    estimator.evaluate(input_fn, checkpoint_path=args.checkpoint_path)
# Exemplo n.º 2  (snippet-source artifact; commented out so the file parses)
def evaluate(args, model_fn):
    """Run one evaluation pass of *model_fn* over the unshuffled dataset."""
    split = imagenet_data.get_split(args.dataset_path, shuffle=False)
    preprocess = partial(inception_preprocessing.preprocess_image,
                         is_training=False)
    input_fn = make_input_fn(split, preprocess,
                             batch_size=64, image_size=224)

    estimator = tf.estimator.Estimator(model_fn, model_dir=args.train_dir)
    estimator.evaluate(input_fn)
def train(args, model_fn):
    """Train *model_fn*, committing masked weight values at the final step."""
    # NOTE(review): preprocessing uses is_training=False on the training
    # path — confirm train-time augmentation is intentionally disabled.
    preprocess = partial(inception_preprocessing.preprocess_image,
                         is_training=False)
    split = imagenet_data.get_split(args.dataset_path, shuffle=True)
    input_fn = make_input_fn(split, preprocess,
                             batch_size=64, image_size=224)

    estimator = tf.estimator.Estimator(
        model_fn,
        model_dir=args.train_dir,
        params={'warm_start': args.warm_start})

    # Commits masked values for all weight-collection variables at max_steps.
    commit_listener = CommitMaskedValueHook(
        args.max_steps,
        variables_fn=lambda: tf.get_collection(tf.GraphKeys.WEIGHTS))
    estimator.train(input_fn, max_steps=args.max_steps,
                    saving_listeners=[commit_listener])
def evaluate(args, model_fn):
    """Run one full evaluation epoch, then report the model's parameter
    count and its approximate float32 size in MiB.

    Args:
        args: parsed CLI namespace; must provide ``dataset_path`` and
            ``train_dir``.
        model_fn: a ``tf.estimator``-style model function.
    """
    input_fn = make_input_fn(imagenet_data.get_split(args.dataset_path, shuffle=False),
                             partial(inception_preprocessing.preprocess_image,
                                     is_training=False),
                             image_size=224,
                             n_epochs=1)

    estimator = tf.estimator.Estimator(model_fn, model_dir=args.train_dir)
    estimator.evaluate(input_fn)

    # tf.all_variables() is deprecated; tf.global_variables() is the
    # documented drop-in replacement returning the same collection.
    total = sum(np.prod(v.get_shape().as_list())
                for v in tf.global_variables())
    print(total)
    # 4 bytes per float32 parameter, reported in MiB.
    print((total * 4) / (1024 ** 2))
# Exemplo n.º 5  (snippet-source artifact; commented out so the file parses)
def train(args, model_fn):
    """Train with quantization support: initialize quantized variables when
    the session is created and commit quantized values at the final step."""
    split = imagenet_data.get_split(args.dataset_path, shuffle=True)
    # NOTE(review): is_training=False on the training path — confirm
    # train-time augmentation is intentionally disabled.
    preprocess = partial(inception_preprocessing.preprocess_image,
                         is_training=False)
    input_fn = make_input_fn(split, preprocess,
                             batch_size=64, image_size=224)

    estimator = tf.estimator.Estimator(model_fn,
                                       model_dir=args.train_dir,
                                       params={'warm_start': args.warm_start})

    # Runs the quantization init op once, right after session creation.
    init_hook = ExecuteAtSessionCreateHook(op_fn=_quantize_init_op,
                                           name='QuantizeInitHook')
    # Writes back quantized values for all weight-collection variables.
    commit_listener = quantization.CommitQuantizedValueHook(
        args.max_steps,
        variables_fn=lambda: tf.get_collection(tf.GraphKeys.WEIGHTS))

    estimator.train(input_fn,
                    hooks=[init_hook],
                    saving_listeners=[commit_listener])