Exemplo n.º 1
0
def test_mnist_registry(key: str, pretrained: Union[bool, str],
                        test_input: bool):
    """Build the registry model for ``key`` and optionally smoke-test it.

    A fresh graph is created with a ``[None, 28, 28, 1]`` float placeholder
    named ``"inputs"`` and the model is constructed through
    ``ModelRegistry.create``.

    :param key: registry key of the model to create
    :param pretrained: falsy to skip weight loading; otherwise forwarded to
        ``ModelRegistry.load_pretrained``
    :param test_input: when True, run a random single-sample batch through
        the model and assert that the output does not sum to zero
    """
    sample_shape = (1, 28, 28, 1)

    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(
            tf_compat.float32, [None, 28, 28, 1], name="inputs")
        logits = ModelRegistry.create(key, inputs)

        with tf_compat.Session() as sess:

            def _assert_nonzero_output():
                # One forward pass on random data; the output must not be
                # all zeros.
                result = sess.run(
                    logits,
                    feed_dict={inputs: numpy.random.random(sample_shape)})
                assert result.sum() != 0

            if test_input:
                # Check the freshly initialized (untrained) model first.
                sess.run(tf_compat.global_variables_initializer())
                _assert_nonzero_output()

            if not pretrained:
                return

            # No session is passed explicitly here.
            # NOTE(review): presumably load_pretrained falls back to the
            # default session opened above -- confirm.
            ModelRegistry.load_pretrained(key, pretrained)
            if test_input:
                _assert_nonzero_output()
def _load_model(args, sess, checkpoint_path=None):
    """Initialize variables in ``sess`` and load weights via the registry.

    :param args: object providing ``arch_key``, ``pretrained``,
        ``pretrained_dataset`` and ``checkpoint_path`` attributes
    :param sess: session in which the initializers are run and which is
        forwarded to ``ModelRegistry.load_pretrained``
    :param checkpoint_path: optional path that takes precedence over
        ``args.checkpoint_path`` when truthy
    """
    initializers = [
        tf_compat.global_variables_initializer(),
        tf_compat.local_variables_initializer(),
    ]
    sess.run(initializers)

    # An explicit argument wins; otherwise fall back to the value on args.
    weights_path = (
        checkpoint_path if checkpoint_path else args.checkpoint_path)
    ModelRegistry.load_pretrained(
        args.arch_key,
        pretrained=args.pretrained,
        pretrained_dataset=args.pretrained_dataset,
        pretrained_path=weights_path,
        sess=sess,
    )

    if not weights_path:
        return
    LOGGER.info(
        "Loaded model weights from checkpoint: {}".format(weights_path))
def _save_checkpoint(args, sess, save_dir, checkpoint_name) -> str:
    """Save the session through the registry's saver and return the path.

    :param args: object providing the ``arch_key`` used to obtain the saver
    :param sess: session handed to the saver
    :param save_dir: base directory for checkpoints
    :param checkpoint_name: sub-directory name for this checkpoint
    :return: the save prefix joined with the value returned by the saver
    """
    model_prefix = os.path.join(save_dir, checkpoint_name, "model")
    create_dirs(model_prefix)
    saved_name = ModelRegistry.saver(args.arch_key).save(sess, model_prefix)
    # NOTE(review): if the saver returns a full path prefix (as a TensorFlow
    # Saver does), this join only yields a sensible path when that prefix is
    # absolute -- confirm behaviour for relative save_dir values.
    saved_path = os.path.join(model_prefix, saved_name)
    LOGGER.info("Checkpoint saved to {}".format(saved_path))
    return saved_path
def _create_model(args, num_classes, inputs, training=False):
    """Create the model selected by ``args.arch_key`` through the registry.

    :param args: object providing ``arch_key`` and a ``model_kwargs`` mapping
        whose items are forwarded as extra keyword arguments
    :param num_classes: forwarded to the registry as ``num_classes``
    :param inputs: input tensor the model is built on
    :param training: forwarded to the registry as ``training``
    :return: whatever ``ModelRegistry.create`` returns for the model
    """
    arch_key = args.arch_key
    extra_kwargs = args.model_kwargs
    model_outputs = ModelRegistry.create(
        arch_key,
        inputs,
        training=training,
        num_classes=num_classes,
        **extra_kwargs,
    )
    LOGGER.info("created model {}".format(arch_key))
    return model_outputs
Exemplo n.º 5
0
def test_resnets(key: str, pretrained: Union[bool, str], test_input: bool,
                 const: Callable):
    """Smoke-test a model via its constructor and via the registry.

    The input shape is looked up with ``ModelRegistry.input_shape(key)``.
    The model is first built by calling ``const`` directly, then built again
    in a new graph through ``ModelRegistry.create``.

    :param key: registry key of the model
    :param pretrained: falsy to skip weight loading; otherwise forwarded to
        ``ModelRegistry.load_pretrained``
    :param test_input: when True, run a random single-sample batch through
        the model and assert that the output does not sum to zero
    :param const: stand-alone constructor, called as
        ``const(inputs, training=False)``
    """
    input_shape = ModelRegistry.input_shape(key)
    placeholder_shape = [None, *input_shape]
    sample_shape = (1, *input_shape)

    def _assert_nonzero_output(sess, inputs, logits):
        # One forward pass on random data; the output must not be all zeros.
        result = sess.run(
            logits, feed_dict={inputs: numpy.random.random(sample_shape)})
        assert result.sum() != 0

    # Stand-alone constructor.
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(
            tf_compat.float32, placeholder_shape, name="inputs")
        logits = const(inputs, training=False)

        if test_input:
            with tf_compat.Session() as sess:
                sess.run(tf_compat.global_variables_initializer())
                _assert_nonzero_output(sess, inputs, logits)

    # Registry-created model.
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(
            tf_compat.float32, placeholder_shape, name="inputs")
        logits = ModelRegistry.create(key, inputs, training=False)

        with tf_compat.Session() as sess:
            if test_input:
                sess.run(tf_compat.global_variables_initializer())
                _assert_nonzero_output(sess, inputs, logits)

            if not pretrained:
                return

            # No session is passed explicitly here.
            # NOTE(review): presumably load_pretrained falls back to the
            # default session opened above -- confirm.
            ModelRegistry.load_pretrained(key, pretrained)
            if test_input:
                _assert_nonzero_output(sess, inputs, logits)
def export(
    args,
    save_dir,
    checkpoint_path=None,
    skip_samples=False,
    num_classes=None,
    opset=None,
):
    """Build the model, load its weights and export it in several formats.

    The model is exported as a TensorFlow checkpoint, a pb file and an ONNX
    file through a ``GraphExporter`` rooted at ``save_dir``. Unless
    ``skip_samples`` is set, one batch of validation inputs and the matching
    model outputs is exported as well.

    :param args: object providing ``arch_key``, ``checkpoint_path``,
        ``num_samples`` and ``onnx_opset`` (plus whatever the dataset and
        model helpers read from it)
    :param save_dir: directory handed to ``GraphExporter``
    :param checkpoint_path: optional path overriding ``args.checkpoint_path``
    :param skip_samples: when True, no dataset is created and no samples are
        exported; ``num_classes`` must then be supplied
    :param num_classes: number of classes; replaced by the dataset's value
        when samples are exported
    :param opset: ONNX opset; falls back to ``args.onnx_opset`` when falsy
    """
    assert not skip_samples or num_classes

    with_samples = not skip_samples
    if with_samples:
        # The validation dataset also determines the number of classes.
        val_dataset, num_classes = _create_dataset(args, train=False)

    with tf_compat.Graph().as_default():
        input_shape = ModelRegistry.input_shape(args.arch_key)
        inputs = tf_compat.placeholder(
            tf_compat.float32, [None, *input_shape], name="inputs")
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            weights_path = checkpoint_path or args.checkpoint_path
            _load_model(args, sess, checkpoint_path=weights_path)

            exporter = GraphExporter(save_dir)

            if with_samples:
                # Pull a single batch of args.num_samples features and export
                # it together with the model outputs.
                sample_ds = val_dataset.build(
                    args.num_samples, repeat_count=1, num_parallel_calls=1)
                sample_iter = tf_compat.data.make_one_shot_iterator(sample_ds)
                features, _ = sample_iter.get_next()
                inputs_val = sess.run(features)
                exporter.export_samples(
                    [inputs], [inputs_val], [outputs], sess)

            # TensorFlow checkpoint format
            LOGGER.info("exporting tensorflow in {}".format(save_dir))
            exporter.export_checkpoint(sess=sess)

            # pb format
            LOGGER.info("exporting pb in {}".format(exporter.pb_path))
            exporter.export_pb(outputs=[outputs])

    # ONNX format; runs after the graph and session contexts have exited.
    onnx_opset = opset or args.onnx_opset
    LOGGER.info("exporting onnx in {}".format(exporter.onnx_path))
    exporter.export_onnx([inputs], [outputs], opset=onnx_opset)
def pruning_loss_sensitivity(args, save_dir):
    """Run a pruning loss sensitivity analysis and save/plot/print results.

    When ``args.approximate`` is truthy the weight-magnitude analysis
    (``pruning_loss_sens_magnitude``) is used; otherwise the one-shot
    analysis (``pruning_loss_sens_one_shot``) is run on batches drawn from
    the training dataset.

    :param args: object providing ``arch_key``, ``checkpoint_path``,
        ``approximate``, ``batch_size`` and ``steps_per_measurement`` (plus
        whatever the dataset and model helpers read from it)
    :param save_dir: directory in which the JSON results and the PNG plot
        are written
    """
    input_shape = ModelRegistry.input_shape(args.arch_key)
    train_dataset, num_classes = _create_dataset(args,
                                                 train=True,
                                                 image_size=input_shape[1])
    with tf_compat.Graph().as_default() as graph:
        # create model graph
        inputs = tf_compat.placeholder(tf_compat.float32,
                                       [None] + list(input_shape),
                                       name="inputs")
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            _load_model(args, sess, checkpoint_path=args.checkpoint_path)
            if args.approximate:
                LOGGER.info(
                    "Running weight magnitude loss sensitivity analysis...")
                analysis = pruning_loss_sens_magnitude(graph, sess)
            else:
                op_vars = pruning_loss_sens_op_vars(graph)
                train_steps = math.ceil(len(train_dataset) / args.batch_size)
                train_dataset = _build_dataset(args, train_dataset,
                                               args.batch_size)
                _, iterator, dataset_iters = create_split_iterators_handle(
                    [train_dataset])
                dataset_iter = dataset_iters[0]
                # Only the labels tensor of the handle-based iterator is
                # needed: it feeds the loss and names the feed target.
                _, labels = iterator.get_next()
                loss = batch_cross_entropy_loss(outputs, labels)
                tensor_names = ["inputs:0", labels.name]
                # Create the get_next op once. Calling get_next() on every
                # step would keep adding new ops to the graph.
                next_batch = list(dataset_iter.get_next())
                sess.run(dataset_iter.initializer)

                def feed_dict_creator(
                        step: int) -> Dict[str, tf_compat.Tensor]:
                    assert step < train_steps
                    # Fetch all components in a single run so they come from
                    # the same batch; evaluating them one by one advances the
                    # iterator between the components.
                    batch_data = sess.run(next_batch)
                    return dict(zip(tensor_names, batch_data))

                LOGGER.info("Running one shot loss sensitivity analysis...")
                analysis = pruning_loss_sens_one_shot(
                    op_vars=op_vars,
                    loss_tensor=loss,
                    steps_per_measurement=args.steps_per_measurement,
                    feed_dict_creator=feed_dict_creator,
                    sess=sess,
                )
    # saving and printing results
    LOGGER.info("completed...")
    LOGGER.info("Saving results in {}".format(save_dir))
    result_stem = ("ks_approx_sensitivity"
                   if args.approximate else "ks_one_shot_sensitivity")
    analysis.save_json(os.path.join(save_dir, result_stem + ".json"))
    # save_dir is joined exactly once so the plot lands next to the JSON
    # file even when save_dir is a relative path.
    analysis.plot(
        os.path.join(save_dir, result_stem + ".png"),
        plot_integral=True,
    )
    analysis.print_res()