def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():

        #######################
        # Quantizers          #
        #######################

        if FLAGS.intr_grad_quantizer != '':
            qtype, qargs = utils.split_quantizer_str(FLAGS.intr_grad_quantizer)
            intr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
        else:
            intr_grad_quantizer = None

        if FLAGS.extr_grad_quantizer != '':
            qtype, qargs = utils.split_quantizer_str(FLAGS.extr_grad_quantizer)
            extr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
        else:
            extr_grad_quantizer = None

        intr_q_map = utils.quantizer_map(FLAGS.intr_qmap)
        extr_q_map = utils.quantizer_map(FLAGS.extr_qmap)
        weight_q_map = utils.quantizer_map(FLAGS.weight_qmap)
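        # NOTE (assumption): the gradient quantizer flags are expected to be
        # comma-separated strings such as 'fixed,16,8' -- a quantizer type
        # followed by its arguments -- which utils.split_quantizer_str()
        # splits into a (type, args) pair for utils.quantizer_selector().
        # The *_qmap flags name per-layer quantizer map files read by
        # utils.quantizer_map(); their exact format is defined by the
        # project's utils module, not by this script.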

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            intr_q_map=intr_q_map,
            extr_q_map=extr_q_map,
            weight_q_map=weight_q_map)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)
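            # The prefetch queue decouples input preprocessing from training:
            # the reader/preprocessing threads keep it filled while each
            # clone dequeues ready-made (images, labels) batches in clone_fn.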

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'],
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points
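        # Note: tf.losses.softmax_cross_entropy adds each loss to the
        # tf.GraphKeys.LOSSES collection, which is how the per-clone losses
        # are picked up by the loss summaries and by optimize_clones() below.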

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add summaries for sparsity.
        weights_name_list, weights_list = utils.get_variables_list('weights')
        biases_name_list, biases_list = utils.get_variables_list('biases')
        for weight in weights_list:
            summaries.add(
                tf.summary.scalar('weight-sparsity/' + weight.name,
                                  tf.nn.zero_fraction(weight)))
        for bias in biases_list:
            summaries.add(
                tf.summary.scalar('weight-sparsity/' + bias.name,
                                  tf.nn.zero_fraction(bias)))
        # summaries for overall sparsity
        if weights_list:
            weights_overall_sparsity = [
                tf.reshape(x, [tf.size(x)]) for x in weights_list
            ]
            weights_overall_sparsity = tf.concat(weights_overall_sparsity,
                                                 axis=0)
            summaries.add(
                tf.summary.scalar(
                    'weight-sparsity/weights-overall',
                    tf.nn.zero_fraction(weights_overall_sparsity)))
        if biases_list:
            biases_overall_sparsity = [
                tf.reshape(x, [tf.size(x)]) for x in biases_list
            ]
            biases_overall_sparsity = tf.concat(biases_overall_sparsity,
                                                axis=0)
            summaries.add(
                tf.summary.scalar(
                    'weight-sparsity/biases-overall',
                    tf.nn.zero_fraction(biases_overall_sparsity)))

        # Add layerwise weight heatmaps
        for name, weight in zip(weights_name_list, weights_list):
            if weight.get_shape().ndims == 4:
                image = utils.heatmap_conv(weight, pad=1)
            elif weight.get_shape().ndims == 2:
                image = utils.heatmap_fullyconnect(weight, pad=1)
            else:
                continue
            summaries.add(tf.summary.image(name, image))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate,
                                             intr_grad_quantizer)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() computes the total loss and the gradients
        # across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        # quantize 'clones_gradients'
        if extr_grad_quantizer is not None:
            clones_gradients = [(extr_grad_quantizer.quantize(gv[0]), gv[1])
                                for gv in clones_gradients]

        # Add gradients to summary
        for gv in clones_gradients:
            summaries.add(
                tf.summary.histogram('gradient/%s' % gv[1].op.name, gv[0]))
            summaries.add(
                tf.summary.scalar('gradient-sparsity/%s' % gv[1].op.name,
                                  tf.nn.zero_fraction(gv[0])))

        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
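# Example invocation of the training entry point above (illustrative only:
# the script name, paths, and the quantizer string are placeholders, and the
# quantizer flag format is an assumption about the project's conventions):
#
#   python train.py \
#       --dataset_dir=/data/imagenet \
#       --dataset_name=imagenet \
#       --dataset_split_name=train \
#       --model_name=resnet_v1_50 \
#       --train_dir=/tmp/train_logs \
#       --extr_grad_quantizer='fixed,16,8'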
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            #######################
            # Quantizers          #
            #######################

            if FLAGS.intr_grad_quantizer != '':
                qtype, qargs = utils.split_quantizer_str(
                    FLAGS.intr_grad_quantizer)
                intr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
            else:
                intr_grad_quantizer = None

            if FLAGS.extr_grad_quantizer != '':
                qtype, qargs = utils.split_quantizer_str(
                    FLAGS.extr_grad_quantizer)
                extr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
            else:
                extr_grad_quantizer = None

            intr_q_map = utils.quantizer_map(FLAGS.intr_qmap)
            extr_q_map = utils.quantizer_map(FLAGS.extr_qmap)
            weight_q_map = utils.quantizer_map(FLAGS.weight_qmap)
            print("Intr QMap:%s" % intr_q_map)

            # Create global_step
            global_step = tf.train.create_global_step()

            ######################
            # Select the dataset #
            ######################
            dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                  FLAGS.dataset_split_name,
                                                  FLAGS.dataset_dir)

            ######################
            # Select the network #
            ######################
            network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=True,
                intr_q_map=intr_q_map,
                extr_q_map=extr_q_map,
                weight_q_map=weight_q_map)

            ##############################################################
            # Create a dataset provider that loads data from the dataset #
            ##############################################################
            images = tf.placeholder(tf.float32,
                                    shape=[FLAGS.batch_size, 28, 28, 1])
            labels = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, 10])
            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            ####################
            # Define the model #
            ####################
            logits, end_points = network_fn(images)

            # Specify the loss function #
            loss = tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            regularizer_losses = tf.add_n(
                tf.losses.get_regularization_losses())
            total_loss = loss + regularizer_losses

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by network_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

            #############
            # Summaries #
            #############

            # Add summaries for end_points.
            for end_point in end_points:
                x = end_points[end_point]
                summaries.add(
                    tf.summary.histogram('activations/' + end_point, x))
                summaries.add(
                    tf.summary.scalar('sparse_activations/' + end_point,
                                      tf.nn.zero_fraction(x)))

            # Add summaries for losses.
            for loss in tf.get_collection(tf.GraphKeys.LOSSES):
                summaries.add(
                    tf.summary.scalar('losses/%s' % loss.op.name, loss))
            summaries.add(tf.summary.scalar('losses_total', total_loss))
            summaries.add(
                tf.summary.scalar('losses_regularization', regularizer_losses))
            summaries.add(tf.summary.scalar('losses_classification', loss))

            #########################################
            # Configure the optimization procedure. #
            #########################################
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate,
                                             intr_grad_quantizer)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            # Variables to train.
            variables_to_train = _get_variables_to_train()

            # Create gradient updates.
            # quantize 'clones_gradients'
            clones_gradients = optimizer.compute_gradients(total_loss)
            if extr_grad_quantizer is not None:
                clones_gradients = [(extr_grad_quantizer.quantize(gv[0]),
                                     gv[1]) for gv in clones_gradients]

            # Add gradients to summary
            for gv in clones_gradients:
                summaries.add(
                    tf.summary.histogram('gradient/%s' % gv[1].op.name, gv[0]))
                summaries.add(
                    tf.summary.scalar('gradient-sparsity/%s' % gv[1].op.name,
                                      tf.nn.zero_fraction(gv[0])))

            grad_updates = optimizer.apply_gradients(clones_gradients,
                                                     global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            train_tensor = control_flow_ops.with_dependencies([update_op],
                                                              total_loss,
                                                              name='train_op')

            # Ensemble related ops
            variables_to_ensemble = {
                v.name: tf.Variable(tf.zeros_like(v))
                for v in variables_to_train
            }
            ensemble_counter = tf.Variable(0.)
            variable_ensemble_ops = [
              tf.assign(variables_to_ensemble[v.name],
                       (variables_to_ensemble[v.name]*ensemble_counter + v)/(ensemble_counter + 1.)) \
              for v in variables_to_train
            ]
            ensemble_counter_update_op = tf.assign(ensemble_counter,
                                                   ensemble_counter + 1)
            ensemble_replace_ops = [
                tf.assign(v, variables_to_ensemble[v.name])
                for v in variables_to_train
            ]
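            # The assign ops above keep a running average of the trained
            # variables: with k snapshots counted in ensemble_counter, each
            # update computes avg <- (avg * k + v) / (k + 1), so avg holds
            # the mean of all ensembled snapshots. ensemble_replace_ops later
            # copies these averages back into the live variables for the
            # final evaluation pass.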

            ##############################
            #  Evaluation pass  (Start)  #
            ##############################

            # Define the metrics:

            eval_pred = tf.squeeze(tf.argmax(logits, 1))
            eval_gtrs = tf.squeeze(tf.argmax(labels, 1))

            acc_value, acc_update = tf.metrics.accuracy(eval_pred, eval_gtrs)
            val_summary_lst = []
            val_summary_lst.append(
                tf.summary.scalar('val_acc', acc_value, collections=[]))
            val_summary_lst.append(
                tf.summary.scalar('val_err', 1 - acc_value, collections=[]))
            val_summary_lst.append(
                tf.summary.scalar('val_err_perc',
                                  100 * (1 - acc_value),
                                  collections=[]))
            val_summary = tf.summary.merge(val_summary_lst)

            num_batches = math.ceil(dataset.num_samples /
                                    (float(FLAGS.batch_size)))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries))

            train_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

            ###########################
            # Kicks off the training. #
            ###########################
            sess.run(tf.global_variables_initializer())
            from tensorflow.examples.tutorials.mnist import input_data
            mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

            total_epoches = FLAGS.num_epoches
            for e in range(total_epoches):
                # Training pass
                sess.run(tf.local_variables_initializer())
                num_training_batches = 60000 // FLAGS.batch_size
                for i in range(num_training_batches):
                    batch_xs, batch_ys = mnist.train.next_batch(
                        FLAGS.batch_size)
                    if i % FLAGS.log_every_n_steps == 0 and i > 0:
                        summary_value, loss_value, acc = sess.run(
                            [summary_op, total_loss, acc_value],
                            feed_dict={
                                images:
                                np.reshape(batch_xs,
                                           (FLAGS.batch_size, 28, 28, 1)),
                                labels:
                                batch_ys
                            })
                        train_writer.add_summary(summary_value,
                                                 i + e * num_training_batches)

                        print("[%d/%d] loss %.5f err %.3f%%"\
                             %(i, num_training_batches, loss_value, (1. - acc)*100))
                        sess.run(tf.local_variables_initializer())

                    sess.run(
                        [update_op, acc_update],
                        feed_dict={
                            images:
                            np.reshape(batch_xs,
                                       (FLAGS.batch_size, 28, 28, 1)),
                            labels:
                            batch_ys
                        })

                # Validation pass
                sess.run(tf.local_variables_initializer())
                for i in range(10000 // FLAGS.batch_size):
                    batch_xs, batch_ys = mnist.validation.next_batch(
                        FLAGS.batch_size)
                    sess.run(
                        [acc_update],
                        feed_dict={
                            images:
                            np.reshape(batch_xs,
                                       (FLAGS.batch_size, 28, 28, 1)),
                            labels:
                            batch_ys
                        })
                val_acc, val_summary_value = sess.run(
                    [acc_value, val_summary],
                    feed_dict={
                        images: np.reshape(batch_xs,
                                           (FLAGS.batch_size, 28, 28, 1)),
                        labels: batch_ys,
                    })
                print("Epoch[%d/%d] : ValErr:%.3f%%" % (e, total_epoches,
                                                        (1 - val_acc) * 100))
                train_writer.add_summary(val_summary_value, e)

                # Ensemble pass
                # if (e+1) % 5 == 0:
                if (e + 1) % 20 == 0:
                    sess.run(variable_ensemble_ops)
                    sess.run(ensemble_counter_update_op)
                    print("Ensembled epoch %d weights" % e)

            # Validation Pass for Weight Ensembled Networks
            sess.run(tf.local_variables_initializer())
            sess.run(ensemble_replace_ops)
            for i in range(10000 // FLAGS.batch_size):
                batch_xs, batch_ys = mnist.validation.next_batch(
                    FLAGS.batch_size)
                sess.run(
                    [acc_update],
                    feed_dict={
                        images: np.reshape(batch_xs,
                                           (FLAGS.batch_size, 28, 28, 1)),
                        labels: batch_ys
                    })
            val_acc, val_summary_value = sess.run(
                [acc_value, val_summary],
                feed_dict={
                    images: np.reshape(batch_xs,
                                       (FLAGS.batch_size, 28, 28, 1)),
                    labels: batch_ys,
                })
            print("Ensembled network ValErro:%.3f%%" % ((1 - val_acc) * 100))
            train_writer.add_summary(val_summary_value, total_epoches)
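# A minimal sketch of the quantizer interface that the two training loops
# above rely on: utils.quantizer_selector() is assumed to return an object
# exposing a quantize(tensor) method (see the gradient quantization steps).
# The fixed-point rounding below is only an illustration of such an object,
# not the project's actual implementation.
class _FixedPointQuantizerSketch(object):
    """Illustrative quantizer with the quantize(tensor) interface used above."""

    def __init__(self, int_bits, frac_bits):
        # Total word length is int_bits + frac_bits, two's complement.
        self._scale = 2.0 ** frac_bits
        self._max = 2.0 ** (int_bits + frac_bits - 1) - 1

    def quantize(self, tensor):
        # Round to the nearest representable fixed-point value and saturate.
        scaled = tf.round(tensor * self._scale)
        clipped = tf.clip_by_value(scaled, -self._max - 1, self._max)
        return clipped / self._scale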
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)  #can be WARN or INFO
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ########################
        # Determine Quantizers #
        ########################
        intr_q_map = utils.quantizer_map(FLAGS.intr_qmap)
        extr_q_map = utils.quantizer_map(FLAGS.extr_qmap)
        weight_q_map = utils.quantizer_map(FLAGS.weight_qmap)

        ####################
        # Select the model #
        ####################
        if 'resnet' in FLAGS.model_name:
            labels_offset = 1
        else:
            labels_offset = 0
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - labels_offset),
            is_training=False,
            intr_q_map=intr_q_map,
            extr_q_map=extr_q_map,
            weight_q_map=weight_q_map)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=8 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size * 4)
        [image, label] = provider.get(['image', 'label'])
        label -= labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)
        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        ####################
        # Define the model #
        ####################
        start_time_build = time.time()

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        logits, endpoints = network_fn(images)
        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)
        used_gpus = 1

        #tf.logging.info('Number of parameters per layer and endpoint:')
        #for var in endpoints:
        #    tf.logging.info('%s: %d'%(var,utils.count_trainable_params(var)))
        #print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            #'Recall_5': slim.metrics.streaming_recall_at_k(logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches and FLAGS.max_num_batches > 0:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            #num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
            num_batches = math.ceil(dataset.num_samples /
                                    (float(FLAGS.batch_size) * used_gpus))

        # get checkpoint
        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path
        if checkpoint_path is None:
            raise ValueError('No Checkpoint found!')
        tf.logging.info('Evaluating %s' % checkpoint_path)

        # get the quantized weight tensors for sparsity estimation
        weights_name_list, weights_list = utils.get_variables_list('weights')
        biases_name_list, biases_list = utils.get_variables_list('biases')

        # count number of elements in each layer
        weights_list_param_count = [
            utils.get_nb_params_shape(x.get_shape()) for x in weights_list
        ]
        weights_list_param_count = dict(
            zip(weights_name_list, weights_list_param_count))
        biases_list_param_count = [
            utils.get_nb_params_shape(x.get_shape()) for x in biases_list
        ]
        biases_list_param_count = dict(
            zip(biases_name_list, biases_list_param_count))

        # get zero fraction for each layer
        weights_layerwise_sparsity_op = [
            tf.nn.zero_fraction(x) for x in weights_list
        ]
        biases_layerwise_sparsity_op = [
            tf.nn.zero_fraction(x) for x in biases_list
        ]

        # add overall weights sparsity to summary
        weights_overall_sparsity_op = [
            tf.reshape(x, [tf.size(x)]) for x in weights_list
        ]
        weights_overall_sparsity_op = tf.concat(weights_overall_sparsity_op,
                                                axis=0)
        summary_name = 'eval/weight_overall_sparsity'
        weights_overall_sparsity_op = tf.nn.zero_fraction(
            weights_overall_sparsity_op)
        op = tf.summary.scalar(summary_name,
                               weights_overall_sparsity_op,
                               collections=[])
        op = tf.Print(op, [weights_overall_sparsity_op], summary_name)
        tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # add overall bias sparsity to summary
        biases_overall_sparsity_op = [
            tf.reshape(x, [tf.size(x)]) for x in biases_list
        ]
        biases_overall_sparsity_op = tf.concat(biases_overall_sparsity_op,
                                               axis=0)
        summary_name = 'eval/biases_overall_sparsity'
        biases_overall_sparsity_op = tf.nn.zero_fraction(
            biases_overall_sparsity_op)
        op = tf.summary.scalar(summary_name,
                               biases_overall_sparsity_op,
                               collections=[])
        op = tf.Print(op, [biases_overall_sparsity_op], summary_name)
        tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # Run Session
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        #config.gpu_options.allocator_type = 'BFC'

        # Final ops, used for statistics
        final_op = (list(names_to_values.values()),
                    weights_layerwise_sparsity_op,
                    biases_layerwise_sparsity_op, weights_overall_sparsity_op,
                    biases_overall_sparsity_op)
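        # evaluate_once() returns the evaluated final_op in the same
        # structure, so run_values below unpacks as: metric values, layerwise
        # weight sparsity, layerwise bias sparsity, overall weight sparsity,
        # overall bias sparsity.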
        print('Running %s for %d iterations' % (FLAGS.model_name, num_batches))
        start_time_simu = time.time()
        run_values = slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            final_op=final_op,
            variables_to_restore=variables_to_restore,
            session_config=config)
        runtime = time.time() - start_time_simu
        buildtime = start_time_simu - start_time_build
        accuracy = run_values[0][0]
        weight_sparsity_values = run_values[1]
        bias_sparsity_values = run_values[2]
        weight_sparsity = run_values[3]
        bias_sparsity = run_values[4]

        # compute sparsity
        print('Calculating sparsity...')
        weight_sparsity_layerwise = {}
        for name, value in zip(weights_name_list, weight_sparsity_values):
            weight_sparsity_layerwise[name] = float(value)
        bias_sparsity_layerwise = {}
        for name, value in zip(biases_name_list, bias_sparsity_values):
            bias_sparsity_layerwise[name] = float(value)

        # print statistics
        print('\nStatistics:')
        print('Accuracy: %.2f%%' % (accuracy * 100))
        print('Buildtime: %f sec' % buildtime)
        print('Runtime: %f sec' % runtime)
        print('Weight sparsity: %f%%' % (weight_sparsity * 100))
        print('Layerwise weight sparsity:')
        for key in weight_sparsity_layerwise.keys():
            print("     %s: %.2f%%" %
                  (key, weight_sparsity_layerwise[key] * 100))
        print('Bias sparsity: %f%%' % (bias_sparsity * 100))
        print('Layerwise bias sparsity:')
        for key in bias_sparsity_layerwise.keys():
            print("     %s: %.2f%%" %
                  (key, bias_sparsity_layerwise[key] * 100))
        print('Comment: %s' % (FLAGS.comment))

        # tf.train.export_meta_graph(filename=FLAGS.checkpoint_path+'/model.meta')

        # write data to .json file
        if FLAGS.output_file is not None:
            if not os.path.exists(FLAGS.output_file):
                json_data = []
            else:
                with open(FLAGS.output_file) as hfile:
                    json_data = json.load(hfile)
            new_data = {}
            new_data["accuracy"] = accuracy.tolist()
            new_data["net"] = FLAGS.model_name
            new_data["samples"] = (num_batches * FLAGS.batch_size * used_gpus)
            new_data["weight_sparsity"] = weight_sparsity.tolist()
            new_data["weight_sparsity_layerwise"] = weight_sparsity_layerwise
            new_data["weight_count_layerwise"] = weights_list_param_count
            new_data["bias_sparsity"] = bias_sparsity.tolist()
            new_data["bias_sparsity_layerwise"] = bias_sparsity_layerwise
            new_data["bias_count_layerwise"] = biases_list_param_count
            new_data["runtime"] = runtime
            new_data["comment"] = FLAGS.comment

            json_data.append(new_data)
            with open(FLAGS.output_file, 'w') as hfile:
                json.dump(json_data, hfile)

        print('Done.')
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO) #can be WARN or INFO
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ########################
    # Determine Quantizers #
    ########################
    intr_q_map = utils.quantizer_map(FLAGS.intr_qmap)
    extr_q_map = utils.quantizer_map(FLAGS.extr_qmap)
    weight_q_map = utils.quantizer_map(FLAGS.weight_qmap)
    
    ####################
    # Select the model #
    ####################
    if 'resnet' in FLAGS.model_name:
        labels_offset = 1
    else:
        labels_offset = 0
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - labels_offset),
        is_training=False,
        intr_q_map=intr_q_map, extr_q_map=extr_q_map,
        weight_q_map=weight_q_map)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=8 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size*4)
    [image, label] = provider.get(['image', 'label'])
    label -= labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)
    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    ####################
    # Define the model #
    ####################
    start_time_build = time.time()
 
    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)
    
    logits, endpoints = network_fn(images)
    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)
    used_gpus = 1
    
    #tf.logging.info('Number of parameters per layer and endpoint:')
    #for var in endpoints:
    #    tf.logging.info('%s: %d'%(var,utils.count_trainable_params('InceptionV3/'+var)))
    
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        #'Recall_5': slim.metrics.streaming_recall_at_k(logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1    
    if FLAGS.max_num_batches and FLAGS.max_num_batches > 0:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      #num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
      num_batches = math.ceil(dataset.num_samples / (float(FLAGS.batch_size)*used_gpus) )

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path
    if checkpoint_path is None:
        raise ValueError('No Checkpoint found!')
    tf.logging.info('Evaluating %s' % checkpoint_path)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    #config.gpu_options.allocator_type = 'BFC'
    # Run Session

    print('Running %s for %d iterations' % (FLAGS.model_name, num_batches))

    start_time_simu = time.time()
    metric_values = slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()), 
        final_op=list(names_to_values.values()),
        variables_to_restore=variables_to_restore,
        session_config=config)
    runtime = time.time() - start_time_simu
    buildtime = start_time_simu - start_time_build
    print('Buildtime: %f sec' % buildtime)
    print('Runtime: %f sec' % runtime)

    # tf.train.export_meta_graph(filename=FLAGS.checkpoint_path+'/model.meta')

    # write data to .json file
    if FLAGS.output_file is not None:
      print('Writing results to file %s'%(FLAGS.output_file))
      with open(FLAGS.output_file,'a') as hfile:
        hfile.write( "{\n")
        hfile.write( '  "accuracy":%f,\n'%(metric_values[0]) )
        hfile.write( '  "net":"%s",\n'%(FLAGS.model_name) )
        hfile.write( '  "samples":%d,\n'%(num_batches*FLAGS.batch_size*used_gpus) )
        hfile.write( '  "comment":%s,\n'%(FLAGS.comment) )
#        hfile.write( '  "intr_q_w":%d,\n'%(intr_quant_width) )
#        hfile.write( '  "intr_q_f":%d,\n'%(intr_quant_prec) )
#        hfile.write( '  "intr_layers":"%s",\n'%(FLAGS.intr_quantize_layers) )
#        hfile.write( '  "intr_rounding":"%s",\n'%(intr_rounding) )
#        hfile.write( '  "extr_q_w":%d,\n'%(extr_quant_width) )
#        hfile.write( '  "extr_q_f":%d,\n'%(extr_quant_prec) )
#        hfile.write( '  "extr_layers":"%s",\n'%(FLAGS.extr_quantize_layers) )
#        hfile.write( '  "extr_rounding":"%s",\n'%(extr_rounding) )
        hfile.write( '  "runtime":%f\n'%(runtime) )
        hfile.write( "}\n")