def test_mnist_registry(key: str, pretrained: Union[bool, str], test_input: bool):
    """
    Build the registry model for ``key`` on a 28x28x1 input placeholder and,
    when requested, check that a forward pass on random data is non-zero.

    :param key: registry key passed to ``ModelRegistry.create`` and
        ``ModelRegistry.load_pretrained``
    :param pretrained: when truthy, forwarded to
        ``ModelRegistry.load_pretrained`` after the model has been created
    :param test_input: when True, run the model on a random batch of one
        sample and assert the summed output is non-zero
    """
    sample_shape = (1, 28, 28, 1)

    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(
            tf_compat.float32, [None, 28, 28, 1], name="inputs"
        )
        logits = ModelRegistry.create(key, inputs)

        with tf_compat.Session() as sess:

            def _forward_random():
                # a fresh random sample is drawn for every forward pass
                sample = numpy.random.random(sample_shape)
                return sess.run(logits, feed_dict={inputs: sample})

            if test_input:
                sess.run(tf_compat.global_variables_initializer())
                out = _forward_random()
                assert out.sum() != 0

            if pretrained:
                ModelRegistry.load_pretrained(key, pretrained)

                if test_input:
                    out = _forward_random()
                    assert out.sum() != 0
def _load_model(args, sess, checkpoint_path=None):
    """
    Initialize the variables in ``sess`` and load weights through the
    model registry.

    :param args: parsed arguments; reads ``arch_key``, ``pretrained``,
        ``pretrained_dataset`` and ``checkpoint_path``
    :param sess: the session to initialize and to hand to
        ``ModelRegistry.load_pretrained``
    :param checkpoint_path: optional checkpoint path; when falsy,
        ``args.checkpoint_path`` is used instead
    """
    init_ops = [
        tf_compat.global_variables_initializer(),
        tf_compat.local_variables_initializer(),
    ]
    sess.run(init_ops)

    # an explicitly passed path wins over the one configured on args
    if not checkpoint_path:
        checkpoint_path = args.checkpoint_path

    ModelRegistry.load_pretrained(
        args.arch_key,
        pretrained=args.pretrained,
        pretrained_dataset=args.pretrained_dataset,
        pretrained_path=checkpoint_path,
        sess=sess,
    )

    if checkpoint_path:
        message = "Loaded model weights from checkpoint: {}".format(checkpoint_path)
        LOGGER.info(message)
def _save_checkpoint(args, sess, save_dir, checkpoint_name) -> str:
    """
    Save the session's model with the registry saver for ``args.arch_key``.

    :param args: parsed arguments; reads ``arch_key``
    :param sess: the session handed to the saver
    :param save_dir: base directory for checkpoints
    :param checkpoint_name: sub directory name for this checkpoint
    :return: the save prefix ``<save_dir>/<checkpoint_name>/model`` joined
        with the value returned by the saver
    """
    model_prefix = os.path.join(save_dir, checkpoint_name, "model")
    create_dirs(model_prefix)

    saved_name = ModelRegistry.saver(args.arch_key).save(sess, model_prefix)

    # NOTE(review): if the saver returns a full path prefix rather than a bare
    # file name, this join may duplicate path components -- confirm
    saved_path = os.path.join(model_prefix, saved_name)
    LOGGER.info("Checkpoint saved to {}".format(saved_path))

    return saved_path
def _create_model(args, num_classes, inputs, training=False):
    """
    Create the model registered under ``args.arch_key`` on top of ``inputs``.

    :param args: parsed arguments; reads ``arch_key`` and ``model_kwargs``
        (the latter is expanded as extra keyword arguments)
    :param num_classes: forwarded to ``ModelRegistry.create``
    :param inputs: the input tensor forwarded to ``ModelRegistry.create``
    :param training: forwarded to ``ModelRegistry.create``
    :return: whatever ``ModelRegistry.create`` returns for the model
    """
    arch_key = args.arch_key
    model_outputs = ModelRegistry.create(
        arch_key,
        inputs,
        training=training,
        num_classes=num_classes,
        **args.model_kwargs,
    )
    LOGGER.info("created model {}".format(arch_key))

    return model_outputs
def test_resnets(
    key: str, pretrained: Union[bool, str], test_input: bool, const: Callable
):
    """
    Build a model twice -- once through its stand alone constructor and once
    through the registry -- and, when requested, check that a forward pass on
    random data is non-zero.

    :param key: registry key used for the input shape, model creation and
        ``ModelRegistry.load_pretrained``
    :param pretrained: when truthy, forwarded to
        ``ModelRegistry.load_pretrained`` for the registry created model
    :param test_input: when True, run the models on a random batch of one
        sample and assert the summed output is non-zero
    :param const: stand alone constructor, called as
        ``const(inputs, training=False)``
    """
    input_shape = ModelRegistry.input_shape(key)

    def _inputs_placeholder():
        # must be called inside the graph the placeholder should belong to
        return tf_compat.placeholder(
            tf_compat.float32, [None, *input_shape], name="inputs"
        )

    def _forward_random(session, model_logits, model_inputs):
        # a fresh random sample is drawn for every forward pass
        sample = numpy.random.random((1, *input_shape))
        return session.run(model_logits, feed_dict={model_inputs: sample})

    # test out the stand alone constructor
    with tf_compat.Graph().as_default():
        inputs = _inputs_placeholder()
        logits = const(inputs, training=False)

        if test_input:
            with tf_compat.Session() as sess:
                sess.run(tf_compat.global_variables_initializer())
                out = _forward_random(sess, logits, inputs)
                assert out.sum() != 0

    # test out the registry
    with tf_compat.Graph().as_default():
        inputs = _inputs_placeholder()
        logits = ModelRegistry.create(key, inputs, training=False)

        with tf_compat.Session() as sess:
            if test_input:
                sess.run(tf_compat.global_variables_initializer())
                out = _forward_random(sess, logits, inputs)
                assert out.sum() != 0

            if pretrained:
                ModelRegistry.load_pretrained(key, pretrained)

                if test_input:
                    out = _forward_random(sess, logits, inputs)
                    assert out.sum() != 0
def export(
    args,
    save_dir,
    checkpoint_path=None,
    skip_samples=False,
    num_classes=None,
    opset=None,
):
    """
    Rebuild the model for ``args.arch_key``, load its weights and export it
    through a ``GraphExporter`` as a checkpoint, a pb file and an ONNX file,
    optionally together with a batch of sample inputs and outputs.

    :param args: parsed arguments; reads ``arch_key``, ``checkpoint_path``,
        ``num_samples`` and ``onnx_opset``
    :param save_dir: directory handed to ``GraphExporter``
    :param checkpoint_path: checkpoint to load; falls back to
        ``args.checkpoint_path`` when falsy
    :param skip_samples: when truthy, no dataset is created and no samples
        are exported; ``num_classes`` must then be given
    :param num_classes: number of classes for the model; replaced by the
        value from the validation dataset when samples are exported
    :param opset: ONNX opset; falls back to ``args.onnx_opset`` when falsy
    """
    assert not (skip_samples and not num_classes)

    with_samples = not skip_samples

    # dataset creation
    if with_samples:
        val_dataset, num_classes = _create_dataset(args, train=False)

    with tf_compat.Graph().as_default():
        placeholder_shape = [None, *ModelRegistry.input_shape(args.arch_key)]
        inputs = tf_compat.placeholder(
            tf_compat.float32, placeholder_shape, name="inputs"
        )
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            weights_path = checkpoint_path or args.checkpoint_path
            _load_model(args, sess, checkpoint_path=weights_path)

            exporter = GraphExporter(save_dir)

            if with_samples:
                # Export a batch of samples and expected outputs
                tf_dataset = val_dataset.build(
                    args.num_samples, repeat_count=1, num_parallel_calls=1
                )
                sample_iter = tf_compat.data.make_one_shot_iterator(tf_dataset)
                features, _ = sample_iter.get_next()
                inputs_val = sess.run(features)
                exporter.export_samples([inputs], [inputs_val], [outputs], sess)

            # Export model to tensorflow checkpoint format
            LOGGER.info("exporting tensorflow in {}".format(save_dir))
            exporter.export_checkpoint(sess=sess)

            # Export model to pb format
            LOGGER.info("exporting pb in {}".format(exporter.pb_path))
            exporter.export_pb(outputs=[outputs])

            # Export model to onnx format
            LOGGER.info("exporting onnx in {}".format(exporter.onnx_path))
            onnx_opset = opset or args.onnx_opset
            exporter.export_onnx([inputs], [outputs], opset=onnx_opset)
def pruning_loss_sensitivity(args, save_dir):
    """
    Run a pruning loss sensitivity analysis for the model registered under
    ``args.arch_key`` and save the results as a JSON file and a plot.

    When ``args.approximate`` is truthy the weight magnitude analysis is run;
    otherwise the one shot analysis is run against batches drawn from the
    training dataset.

    :param args: parsed arguments; reads ``arch_key``, ``checkpoint_path``,
        ``approximate``, ``batch_size`` and ``steps_per_measurement``
    :param save_dir: directory the JSON and PNG result files are written to
    """
    input_shape = ModelRegistry.input_shape(args.arch_key)
    train_dataset, num_classes = _create_dataset(
        args, train=True, image_size=input_shape[1]
    )

    with tf_compat.Graph().as_default() as graph:
        # create model graph
        inputs = tf_compat.placeholder(
            tf_compat.float32, [None] + list(input_shape), name="inputs"
        )
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            _load_model(args, sess, checkpoint_path=args.checkpoint_path)

            if args.approximate:
                LOGGER.info("Running weight magnitude loss sensitivity analysis...")
                analysis = pruning_loss_sens_magnitude(graph, sess)
            else:
                op_vars = pruning_loss_sens_op_vars(graph)
                # step count is taken before train_dataset is rebound below
                train_steps = math.ceil(len(train_dataset) / args.batch_size)
                train_dataset = _build_dataset(args, train_dataset, args.batch_size)
                _handle, iterator, dataset_iters = create_split_iterators_handle(
                    [train_dataset]
                )
                dataset_iter = dataset_iters[0]
                _images, labels = iterator.get_next()
                loss = batch_cross_entropy_loss(outputs, labels)
                tensor_names = ["inputs:0", labels.name]
                sess.run(dataset_iter.initializer)

                # Create the get_next tensors once: calling get_next() on every
                # step would keep adding new ops to the graph.
                next_batch = dataset_iter.get_next()

                def feed_dict_creator(step: int) -> Dict[str, tf_compat.Tensor]:
                    assert step < train_steps
                    # Fetch all tensors of the batch in a single run so that
                    # they come from the same dataset element; evaluating each
                    # tensor separately advances the iterator every time and
                    # pairs images with labels of a different batch.
                    batch_data = sess.run(next_batch)

                    return dict(zip(tensor_names, batch_data))

                LOGGER.info("Running one shot loss sensitivity analysis...")
                analysis = pruning_loss_sens_one_shot(
                    op_vars=op_vars,
                    loss_tensor=loss,
                    steps_per_measurement=args.steps_per_measurement,
                    feed_dict_creator=feed_dict_creator,
                    sess=sess,
                )

            # saving and printing results
            LOGGER.info("completed...")
            LOGGER.info("Saving results in {}".format(save_dir))

            results_stem = (
                "ks_approx_sensitivity"
                if args.approximate
                else "ks_one_shot_sensitivity"
            )
            analysis.save_json(os.path.join(save_dir, results_stem + ".json"))
            # save_dir is joined exactly once so the plot lands next to the
            # JSON file even when save_dir is a relative path
            analysis.plot(
                os.path.join(save_dir, results_stem + ".png"),
                plot_integral=True,
            )
            analysis.print_res()