Пример #1
0
def getImageBatchAndOneHotLabels(dataset_dir, dataset_name, num_readers,
                                 num_preprocessing_threads, batch_size):
    '''
    Build a batched, preprocessed input pipeline for one dataset split.

    :param dataset_dir: directory where the tfrecord files are stored
    :param dataset_name: name of the dataset e.g. train / validation
    :param num_readers: number of readers handed to the DatasetDataProvider
    :param num_preprocessing_threads: number of threads given to
        tf.train.batch
    :param batch_size: examples per batch; also sizes the provider queue
        (capacity 2x, minimum 1x) and the batching queue (capacity 5x)
    :return: tuple (dataset, images, labels); labels are one-hot encoded
        over dataset.num_classes
    '''
    split = imagenet.get_split(dataset_name, dataset_dir)

    # The provider ops are pinned to the CPU.
    with tf.device('/device:CPU:0'):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            split,
            num_readers=num_readers,
            common_queue_capacity=2 * batch_size,
            common_queue_min=batch_size)
        raw_image, raw_label = provider.get(['image', 'label'])

    # Preprocess to the square input size declared by alexnet_v2.
    side = alexnet.alexnet_v2.default_image_size
    processed_image = alexnet_preprocessing.preprocess_image(
        raw_image, side, side)

    # Assemble batches, then expand the integer labels to one-hot vectors.
    image_batch, label_batch = tf.train.batch(
        [processed_image, raw_label],
        batch_size=batch_size,
        num_threads=num_preprocessing_threads,
        capacity=5 * batch_size)
    one_hot_labels = slim.one_hot_encoding(label_batch, split.num_classes)
    return split, image_batch, one_hot_labels
Пример #2
0
def get_alpha_from_compression_rate(compression_rate):
    '''
    Search for the ratio ``alpha`` that yields a target compression rate.

    The network graph is built once so that its ``weights`` variables exist;
    their shapes are collected and ``alpha`` is bisected until the estimated
    compression rate is within tolerance of ``compression_rate``.

    :param compression_rate: target ratio of original size to compressed size
    :return: the ``alpha`` found by the search (it is also printed)
    :raises ValueError: if no ``weights`` variables were found in the graph
    '''
    import math

    def bin_search_k(shape_lst, cr_target, tol=0.02, max_iters=64):
        '''
        Bisect alpha in [0.001, 1.0] until |cr_target - cr| < tol.

        For every shape, n = shape[3], c = shape[2], w = shape[1]
        (presumably an HWIO convolution kernel -- TODO confirm). The
        original size is counted as n*c*w*w*32 and the compressed size as
        n*c*w*int(log2(k) + 1) + k*w*32 with k = alpha*n*c*w.
        '''
        if not shape_lst:
            # Without shapes the compressed size is 0 and the ratio is
            # undefined; fail with a clear message instead of dividing by 0.
            raise ValueError(
                'no weight shapes collected; cannot search for alpha')

        high = 1.0
        low = 0.001
        alpha = (high + low) / 2
        # A bounded loop: the interval collapses to float precision well
        # before max_iters, so an unreachable target cannot spin forever.
        for _ in range(max_iters):
            nparams_orig = 0
            nparams_comp = 0
            for shape in shape_lst:
                n, c, w = shape[3], shape[2], shape[1]
                nparams_orig += n * c * w * w * 32
                # int() may truncate to 0 for small layers and log(0) is
                # undefined, so keep at least one cluster.
                k = max(1, int(alpha * n * c * w))
                nparams_comp += (n * c * w * int(math.log(k, 2) + 1) +
                                 k * w * 32)
            cr = nparams_orig / nparams_comp
            delta = cr_target - cr
            print('delta: {}'.format(delta))
            print('alpha: {}'.format(alpha))
            print('###################')
            if math.fabs(delta) < tol:
                break
            # A larger alpha means more clusters and a lower compression
            # rate, so move the bound that brings cr toward the target.
            if delta < 0.0:
                low = alpha
            else:
                high = alpha
            alpha = (high + low) / 2
        else:
            print('bin_search_k: no convergence after {} iterations'.format(
                max_iters))
        return alpha

    model_name = 'inception_v1'

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            dataset = imagenet.get_split('validation', FLAGS.dataset_dir)
            istraining_placeholder = tf.placeholder(tf.bool)
            # NOTE(review): the other call sites in this file build GoogLeNet
            # without a positional model name and use FLAGS.labels_offset
            # instead of the literal 1 -- confirm which signature is current.
            network_fn = GoogLeNet.GoogLeNet(
                model_name,
                num_classes=(dataset.num_classes - 1),
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.batch_size * FLAGS.gpu_num,
                       network_fn.default_image_size,
                       network_fn.default_image_size, 3))
            with tf.variable_scope(tf.get_variable_scope()):
                # Only the variable shapes are needed, so one tower suffices.
                for gpu_index in range(0, 1):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' % ('gpu', gpu_index)):
                            # The network must actually be called so that its
                            # variables are created; the outputs are unused.
                            network_fn(
                                images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)

            shape_lst = []

            for v in tf.trainable_variables():
                is_weights = 'weights' in v.name
                is_beta = 'beta' in v.name
                if not (is_weights or is_beta):
                    continue
                # Evaluate the variable once and reuse its shape.
                shape = sess.run(v).shape
                if is_weights:
                    print("v.name: {}".format(v.name))
                    print("v.shape: {}".format(shape))
                    print('=====================================')
                    shape_lst.append(shape)

                if is_beta:
                    print("v.name: {}".format(v.name))
                    print("v.shape: {}".format(shape))
                    print('+++++++++++++++++++++++++++++++++++++')
            alpha = bin_search_k(shape_lst, compression_rate)

            print("compression rate: {:.4f}, alpha: {:.4f}".format(
                compression_rate, alpha))

    return alpha
Пример #3
0
def run_gpu_eval(use_compression=False,
                 use_quantization=False,
                 compute_energy=False,
                 use_pretrained_model=True,
                 epoch_num=0):
    '''
    Evaluate the network on the validation split and report top-1/top-5.

    Weights are restored from a checkpoint, optionally clustered and/or
    quantized in place, saved to ``checkpoint/inception_v1_<epoch>_<alpha>``
    and then evaluated. Results are printed and written to
    ``<Compression>_<Quantization>_<epoch>_<alpha>_evaluation.txt``.

    :param use_compression: cluster every ``weights`` variable with
        ``cluster_conv`` using k = int(FLAGS.alpha * n * c * w)
    :param use_quantization: quantize ``weights`` and ``beta`` variables
        through the MATLAB ``get_fi`` function
    :param compute_energy: also collect per-layer feature-map shapes and
        zero counts, written to ``model_params_dict.pkl`` and
        ``GoogLeNet_Pruned_<alpha>.txt``
    :param use_pretrained_model: restore from FLAGS.checkpoint_path when
        True, otherwise from the per-epoch checkpoint directory
    :param epoch_num: epoch index used in the checkpoint directory name and
        in the evaluation file name
    :raises FileNotFoundError: if ``use_pretrained_model`` is False and the
        per-epoch checkpoint directory holds no checkpoint
    '''
    module_name = 'inception_v1'
    checkpoint_dir = 'checkpoint/{}_{}_{}'.format(module_name, epoch_num,
                                                  FLAGS.alpha)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            dataset = imagenet.get_split('validation', FLAGS.dataset_dir)
            istraining_placeholder = tf.placeholder(tf.bool)
            network_fn = GoogLeNet.GoogLeNet(
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            logits_lst = []
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.batch_size * FLAGS.gpu_num,
                       network_fn.default_image_size,
                       network_fn.default_image_size, 3))
            labels_placeholder = tf.placeholder(
                tf.int64,
                shape=(FLAGS.batch_size * FLAGS.gpu_num, dataset.num_classes))
            with tf.variable_scope(tf.get_variable_scope()):
                # One tower per GPU, each fed its own slice of the batch.
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' % ('gpu', gpu_index)):
                            logits, end_points, end_points_Ofmap, end_points_Ifmap = network_fn(
                                images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])
                            logits_lst.append(logits)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=False)

            logits_op = tf.concat(logits_lst, 0)
            right_count_top1_op = tf.reduce_sum(
                tf.cast(
                    tf.equal(tf.argmax(tf.nn.softmax(logits_op), axis=1),
                             tf.argmax(labels_placeholder, axis=1)), tf.int32))
            right_count_topk_op = tf.reduce_sum(
                tf.cast(
                    tf.nn.in_top_k(tf.nn.softmax(logits_op),
                                   tf.argmax(labels_placeholder, axis=1), 5),
                    tf.int32))

            images_op, labels_op = input_data.inputs(
                dataset=dataset,
                image_preprocessing_fn=image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=1,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])

            # The same variable set is restored and saved in every branch,
            # so a single Saver is enough.
            varlist = tf.trainable_variables() + bn_moving_vars
            saver = tf.train.Saver(varlist)

            if use_pretrained_model:
                print(varlist)
                if os.path.isfile(FLAGS.checkpoint_path):
                    saver.restore(sess, FLAGS.checkpoint_path)
                    print(
                        '#############################Session restored from pretrained model at {}!###############################'
                        .format(FLAGS.checkpoint_path))
                else:
                    ckpt = tf.train.get_checkpoint_state(
                        checkpoint_dir=FLAGS.checkpoint_path)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            'Session restored from pretrained degradation model at {}!'
                            .format(ckpt.model_checkpoint_path))
                    else:
                        # Kept best-effort, but no longer silent: the model
                        # is about to be evaluated with initial weights.
                        print(
                            'WARNING: no checkpoint found at {}; evaluating '
                            'freshly initialized variables'.format(
                                FLAGS.checkpoint_path))
            else:
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        '#############################Session restored from trained model at {}!###############################'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            checkpoint_dir)

            # The MATLAB engine is only needed for quantization; start it on
            # demand and always shut it down again.
            mat_eng = matlab.engine.start_matlab() if use_quantization else None
            seed = 500
            alpha = FLAGS.alpha
            memory = 0
            try:
                for v in tf.trainable_variables() + bn_moving_vars:
                    if 'weights' in v.name:
                        value = sess.run(v)
                        memory += np.prod(value.shape)
                        print("weights.name: {}".format(v.name))
                        print("weights.shape: {}".format(value.shape))
                        if use_compression:
                            # Axes are reversed before clustering and restored
                            # before the result is assigned back.
                            weights = np.transpose(value, (3, 2, 1, 0))
                            shape = weights.shape
                            n, c, w = shape[0], shape[1], shape[2]
                            k = int(alpha * n * c * w)
                            weight_clustered, mse = cluster_conv(
                                weights, k, seed)
                            weight_clustered = np.transpose(
                                weight_clustered, (3, 2, 1, 0))
                            sess.run(v.assign(weight_clustered))
                            print("weight_clustered shape: {}".format(
                                weight_clustered.shape))
                            print("mse: {}".format(mse))
                            seed += 1
                        if use_quantization:
                            # Re-read the variable: compression may have just
                            # replaced its value.
                            weights = np.transpose(sess.run(v), (3, 2, 1, 0))
                            shape = weights.shape
                            weight_quantized = mat_eng.get_fi(
                                matlab.double(weights.tolist()),
                                FLAGS.bitwidth, FLAGS.bitwidth -
                                FLAGS.bitwidth_minus_fraction_length)
                            weight_quantized = np.asarray(
                                weight_quantized).reshape(shape).astype(
                                    'float32')
                            weight_quantized = np.transpose(
                                weight_quantized, (3, 2, 1, 0))
                            sess.run(v.assign(weight_quantized))
                            print("weight_quantized shape: {}".format(
                                weight_quantized.shape))
                        print('=====================================')

                    if 'beta' in v.name:
                        value = sess.run(v)
                        memory += np.prod(value.shape)
                        print("beta.name: {}".format(v.name))
                        print("beta.shape: {}".format(value.shape))
                        if use_quantization:
                            weights = value
                            shape = weights.shape
                            weight_quantized = mat_eng.get_fi(
                                matlab.double(weights.tolist()),
                                FLAGS.bn_bitwidth, FLAGS.bn_bitwidth -
                                FLAGS.bitwidth_minus_fraction_length)
                            weight_quantized = np.asarray(
                                weight_quantized).reshape(shape).astype(
                                    'float32')
                            sess.run(v.assign(weight_quantized))
                            print("beta_quantized shape: {}".format(
                                weight_quantized.shape))
                        print('+++++++++++++++++++++++++++++++++++++')
            finally:
                if mat_eng is not None:
                    mat_eng.quit()

            # Saver.save needs the parent directory to exist; it may not when
            # the weights came from FLAGS.checkpoint_path.
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=0)
            print(
                "############################################### MEMORY IS {} ###############################################"
                .format(memory))

            weights_dict = {}
            if compute_energy:
                for v in tf.trainable_variables():
                    if 'weights' in v.name:
                        vname = "_".join(v.name.split('/')[1:-1])
                        weights = sess.run(v)
                        print("v.name: {}".format(vname))
                        print("v.shape: {}".format(weights.shape))
                        print("v.nzeros: {}".format(
                            np.count_nonzero(weights == 0)))
                        # The zero count stored for the energy model is an
                        # estimate derived from alpha, not the measured count.
                        weights_dict[vname] = [
                            weights.size * (1 - FLAGS.alpha), weights.shape
                        ]
                        print('=====================================')

            total_v = 0.0
            test_correct_num_top1 = 0.0
            test_correct_num_topk = 0.0

            from tqdm import tqdm

            pbar = tqdm(total=dataset.num_samples //
                        (FLAGS.gpu_num * FLAGS.batch_size), )
            i = 1
            model_params_dict = {}
            # Bound up front so the finally block is safe even when the very
            # first batch fails.
            keys = []
            # A single forward pass per batch; the feature maps are fetched
            # only when they are actually consumed.
            fetches = [right_count_top1_op, right_count_topk_op]
            if compute_energy:
                fetches += [end_points_Ofmap, end_points_Ifmap]
            try:
                while not coord.should_stop():
                    pbar.update(1)
                    images, labels = sess.run([images_op, labels_op])

                    results = sess.run(fetches,
                                       feed_dict={
                                           images_placeholder: images,
                                           labels_placeholder: labels,
                                           istraining_placeholder: False
                                       })
                    right_count_top1, right_count_topk = results[0], results[1]

                    test_correct_num_top1 += right_count_top1
                    test_correct_num_topk += right_count_topk
                    total_v += labels.shape[0]

                    if compute_energy:
                        end_points_Ofmap_dict = results[2]
                        end_points_Ifmap_dict = results[3]
                        keys = list(end_points_Ifmap_dict.keys())
                        for k in keys:
                            ifmap = end_points_Ifmap_dict[k]
                            ofmap = end_points_Ofmap_dict[k]
                            ifmap_zeros = np.count_nonzero(ifmap == 0)
                            ofmap_zeros = np.count_nonzero(ofmap == 0)
                            if i == 1:
                                # First batch: record shapes and seed the
                                # running averages.
                                model_params_dict[k] = {}
                                model_params_dict[k]["IfMap_Shape"] = ifmap.shape
                                model_params_dict[k]["IfMap_nZeros"] = ifmap_zeros

                                model_params_dict[k][
                                    "Filter_Shape"] = weights_dict[k][1]
                                model_params_dict[k]["Filter_nZeros"] = int(
                                    weights_dict[k][0])

                                model_params_dict[k]["OfMap_Shape"] = ofmap.shape
                                model_params_dict[k]["OfMap_nZeros"] = ofmap_zeros
                                print("Layer Name: {}".format(k))
                                print("IfMap Shape: {}".format(ifmap.shape))
                                print("IfMap nZeros: {:.4e}".format(
                                    ifmap_zeros))
                                print("IfMap nZeros Avg: {:.4e}".format(
                                    model_params_dict[k]["IfMap_nZeros"]))
                                print("Filter Shape: {}".format(
                                    weights_dict[k][1]))
                                print("Filter nZeros: {:.4e}".format(
                                    int(weights_dict[k][0])))
                                print("OfMap Shape: {}".format(ofmap.shape))
                                print("OfMap nZeros: {:.4e}".format(
                                    ofmap_zeros))
                                print("OfMap nZeros Avg: {:.4e}".format(
                                    model_params_dict[k]["OfMap_nZeros"]))
                                print(
                                    '=========================================================================='
                                )
                            else:
                                # Running mean over batches:
                                # (avg + x / (i - 1)) * (i - 1) / i
                                #   == (avg * (i - 1) + x) / i
                                model_params_dict[k]["IfMap_nZeros"] = (
                                    model_params_dict[k]["IfMap_nZeros"] +
                                    ifmap_zeros / (i - 1)) * (i - 1) / i
                                model_params_dict[k]["OfMap_nZeros"] = (
                                    model_params_dict[k]["OfMap_nZeros"] +
                                    ofmap_zeros / (i - 1)) * (i - 1) / i
                        i += 1
            except tf.errors.OutOfRangeError:
                print('Done testing on all the examples')
            finally:
                coord.request_stop()
                pbar.close()
                if compute_energy:
                    import pickle
                    with open('model_params_dict.pkl', 'wb') as f:
                        pickle.dump(model_params_dict, f,
                                    pickle.HIGHEST_PROTOCOL)
                    with open('GoogLeNet_Pruned_{}.txt'.format(FLAGS.alpha),
                              'w') as wf:
                        for k in keys:
                            wf.write("Layer Name: {}\n".format(k))

                            wf.write("IfMap Shape: {}\n".format(
                                model_params_dict[k]["IfMap_Shape"]))
                            wf.write("IfMap nZeros: {:.4e}\n".format(
                                model_params_dict[k]["IfMap_nZeros"]))

                            wf.write("Filter Shape: {}\n".format(
                                model_params_dict[k]["Filter_Shape"]))
                            wf.write("Filter nZeros: {:.4e}\n".format(
                                model_params_dict[k]["Filter_nZeros"]))

                            wf.write("OfMap Shape: {}\n".format(
                                model_params_dict[k]["OfMap_Shape"]))
                            wf.write("OfMap nZeros: {:.4e}\n".format(
                                model_params_dict[k]["OfMap_nZeros"]))
                            wf.write(
                                '==========================================================================\n'
                            )
            coord.join(threads)

            # With no evaluated examples the accuracy is undefined; report
            # NaN rather than failing with a ZeroDivisionError.
            if total_v:
                top1_acc = test_correct_num_top1 / total_v
                topk_acc = test_correct_num_topk / total_v
            else:
                top1_acc = topk_acc = float('nan')

            print('Test acc top1:', top1_acc,
                  'Test_correct_num top1:', test_correct_num_top1, 'Total_v:',
                  total_v)
            print('Test acc topk:', topk_acc,
                  'Test_correct_num topk:', test_correct_num_topk, 'Total_v:',
                  total_v)

            compression_tag = ("Compression_"
                               if use_compression else "NoCompression_")
            quantization_tag = ("Quantization_"
                                if use_quantization else "NoQuantization_")
            with open(
                    '{}_{}_{}_{}_evaluation.txt'.format(
                        compression_tag, quantization_tag, epoch_num,
                        FLAGS.alpha), 'w') as wf:
                wf.write(
                    'test acc top1:{}\ttest_correct_num top1:{}\ttotal_v:{}\n'.
                    format(top1_acc, test_correct_num_top1, total_v))
                wf.write(
                    'test acc topk:{}\ttest_correct_num topk:{}\ttotal_v:{}\n'.
                    format(topk_acc, test_correct_num_topk, total_v))

    print("done")
Пример #4
0
def run_gpu_train(use_pretrained_model, epoch_num):
    '''
    Train the network for one pass over the training split.

    Weights are restored either from FLAGS.checkpoint_path or from the
    checkpoint directory of the previous epoch, trained with one tower per
    GPU, and saved to ``checkpoint/inception_v1_<epoch_num>_<alpha>``.

    :param use_pretrained_model: restore from FLAGS.checkpoint_path when
        True, otherwise from the directory of epoch ``epoch_num - 1``
    :param epoch_num: index of the epoch being trained; names the output
        checkpoint directory
    :raises ValueError: if FLAGS.dataset_dir is not set
    :raises FileNotFoundError: if ``use_pretrained_model`` is False and the
        previous epoch's directory holds no checkpoint
    :raises AssertionError: if the training loss becomes NaN
    '''
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    module_name = 'inception_v1'
    checkpoint_dir = 'checkpoint/{}_{}_{}'.format(module_name, epoch_num - 1,
                                                  FLAGS.alpha)

    saved_checkpoint_dir = 'checkpoint/{}_{}_{}'.format(
        module_name, epoch_num, FLAGS.alpha)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(saved_checkpoint_dir, exist_ok=True)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            dataset = imagenet.get_split('train', FLAGS.dataset_dir)
            dataset_val = imagenet.get_split('validation', FLAGS.dataset_dir)
            global_step = slim.create_global_step()
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            istraining_placeholder = tf.placeholder(tf.bool)
            network_fn = GoogLeNet.GoogLeNet(
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            tower_grads = []
            logits_lst = []
            losses_lst = []
            opt = _configure_optimizer(learning_rate)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.batch_size * FLAGS.gpu_num,
                       network_fn.default_image_size,
                       network_fn.default_image_size, 3))
            labels_placeholder = tf.placeholder(
                tf.int64,
                shape=(FLAGS.batch_size * FLAGS.gpu_num, dataset.num_classes))
            with tf.variable_scope(tf.get_variable_scope()):
                # One tower per GPU; each gets its own slice of the batch and
                # contributes its own loss and gradients.
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' % ('gpu', gpu_index)):
                            logits, _, _, _ = network_fn(
                                images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])
                            logits_lst.append(logits)
                            loss = tower_loss_xentropy_dense(
                                logits, labels_placeholder[gpu_index *
                                                           FLAGS.batch_size:
                                                           (gpu_index + 1) *
                                                           FLAGS.batch_size])
                            losses_lst.append(loss)
                            varlist = tf.trainable_variables()
                            grads = opt.compute_gradients(loss, varlist)
                            tower_grads.append(grads)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=True)
            val_image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=False)

            loss_op = tf.reduce_mean(losses_lst, name='softmax')
            logits_op = tf.concat(logits_lst, 0)
            acc_op = accuracy(logits_op, labels_placeholder)
            grads = average_gradients(tower_grads)

            # Ops in UPDATE_OPS must run together with the gradient step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print(update_ops)
            with tf.control_dependencies([tf.group(*update_ops)]):
                apply_gradient_op = opt.apply_gradients(
                    grads, global_step=global_step)

            images_op, labels_op = input_data.inputs(
                dataset=dataset,
                image_preprocessing_fn=image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=1,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)
            val_images_op, val_labels_op = input_data.inputs(
                dataset=dataset_val,
                image_preprocessing_fn=val_image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=None,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])

            # The same variable set is restored and saved everywhere below,
            # so a single Saver is enough.
            varlist = tf.trainable_variables() + bn_moving_vars
            saver = tf.train.Saver(varlist)

            if use_pretrained_model:
                print(varlist)
                if os.path.isfile(FLAGS.checkpoint_path):
                    saver.restore(sess, FLAGS.checkpoint_path)
                    print(
                        '#############################Session restored from pretrained model at {}!###############################'
                        .format(FLAGS.checkpoint_path))
                else:
                    ckpt = tf.train.get_checkpoint_state(
                        checkpoint_dir=FLAGS.checkpoint_path)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            'Session restored from pretrained degradation model at {}!'
                            .format(ckpt.model_checkpoint_path))
                    else:
                        # Kept best-effort, but no longer silent: training
                        # starts from freshly initialized variables.
                        print(
                            'WARNING: no checkpoint found at {}; training '
                            'from freshly initialized variables'.format(
                                FLAGS.checkpoint_path))
            else:
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        '#############################Session restored from trained model at {}!###############################'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            checkpoint_dir)

            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    images, labels = sess.run([images_op, labels_op])
                    _, loss_value = sess.run(
                        [apply_gradient_op, loss_op],
                        feed_dict={
                            images_placeholder: images,
                            labels_placeholder: labels,
                            istraining_placeholder: True
                        })
                    # Explicit raise instead of `assert` so the check still
                    # runs under `python -O`; the exception type is kept.
                    if np.isnan(loss_value):
                        raise AssertionError('Model diverged with loss = NaN')
                    duration = time.time() - start_time
                    print('Step: {:4d} time: {:.4f} loss: {:.8f}'.format(
                        step, duration, loss_value))

                    if step % FLAGS.val_step == 0:
                        # NOTE(review): this draws a fresh training batch that
                        # is only evaluated, never trained on -- confirm that
                        # this is intended with num_epochs=1.
                        start_time = time.time()
                        images, labels = sess.run([images_op, labels_op])
                        acc, loss_value = sess.run(
                            [acc_op, loss_op],
                            feed_dict={
                                images_placeholder: images,
                                labels_placeholder: labels,
                                istraining_placeholder: False
                            })
                        print(
                            "Step: {:4d} time: {:.4f}, training accuracy: {:.5f}, loss: {:.8f}"
                            .format(step,
                                    time.time() - start_time, acc, loss_value))

                        start_time = time.time()
                        images, labels = sess.run(
                            [val_images_op, val_labels_op])
                        acc, loss_value = sess.run(
                            [acc_op, loss_op],
                            feed_dict={
                                images_placeholder: images,
                                labels_placeholder: labels,
                                istraining_placeholder: False
                            })
                        print(
                            "Step: {:4d} time: {:.4f}, validation accuracy: {:.5f}, loss: {:.8f}"
                            .format(step,
                                    time.time() - start_time, acc, loss_value))

                    # Save a checkpoint and evaluate the model periodically.
                    if step % FLAGS.save_step == 0 or (step +
                                                       1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(saved_checkpoint_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training on all the examples')
            finally:
                coord.request_stop()
            coord.join(threads)
            # Final checkpoint once the input queue is exhausted.
            checkpoint_path = os.path.join(saved_checkpoint_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

    print("done")
Пример #5
0
def get_k_list(alpha):
    """Derive a per-variable ``k`` value for every GoogLeNet ``weights`` variable.

    The GoogLeNet towers are built in a throw-away graph only so that the
    trainable variables exist. Shapes are then read from each variable's
    static shape, so no session is created, nothing is initialised and no
    weight tensor is copied back from a device.

    For every trainable variable whose name contains ``'weights'`` the shape
    is indexed as ``n, c, w = shape[3], shape[2], shape[1]`` (presumably
    HWIO-style 4-D convolution kernels -- TODO confirm) and
    ``k = int(alpha * n * c * w)`` is recorded.

    Args:
        alpha: Scaling factor applied to ``n * c * w`` when computing ``k``.

    Returns:
        A tuple ``(k_list, memory_all)`` where ``k_list`` maps each matching
        variable name to its ``k`` and ``memory_all`` is the sum of
        ``n * c * w * w * 32`` over those variables (logged as bits).
    """
    with tf.Graph().as_default():
        dataset = imagenet.get_split('validation', FLAGS.dataset_dir)
        istraining_placeholder = tf.placeholder(tf.bool)
        network_fn = GoogLeNet.GoogLeNet(
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=istraining_placeholder)
        images_placeholder = tf.placeholder(
            tf.float32,
            shape=(FLAGS.batch_size * FLAGS.gpu_num,
                   network_fn.default_image_size,
                   network_fn.default_image_size, 3))

        # Build one tower per GPU; only the side effect of creating the
        # model variables matters here, the tower outputs are discarded.
        with tf.variable_scope(tf.get_variable_scope()):
            for gpu_index in range(FLAGS.gpu_num):
                with tf.device('/gpu:%d' % gpu_index):
                    print('/gpu:%d' % gpu_index)
                    with tf.name_scope('%s_%d' % ('gpu', gpu_index)):
                        first = gpu_index * FLAGS.batch_size
                        last = (gpu_index + 1) * FLAGS.batch_size
                        network_fn(images_placeholder[first:last])
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

        k_list = {}
        memory_all = 0

        for v in tf.trainable_variables():
            # The static shape is identical to the shape of the evaluated
            # value; a tuple keeps the log output in the numpy-shape format.
            shape = tuple(v.get_shape().as_list())

            if 'weights' in v.name:
                print("v.name: {}".format(v.name))
                print("v.shape: {}".format(shape))
                print('=====================================')
                n, c, w = shape[3], shape[2], shape[1]
                memory_all += n * c * w * w * 32
                k_list[v.name] = int(alpha * n * c * w)

            if 'beta' in v.name:
                print("v.name: {}".format(v.name))
                print("v.shape: {}".format(shape))
                print('+++++++++++++++++++++++++++++++++++++')

        print("memory_all(bits): {:.1f}".format(memory_all))

        return k_list, memory_all
Пример #6
0
    graph = load_graph(args.frozen_model_filename)

    # We can verify that we can access the list of operations in the graph
    for op in graph.get_operations():
        print(op.name)

    # We access the input and output nodes in the graph
    images_placeholder = graph.get_tensor_by_name('prefix/Images_Placeholder:0')
    labels_placeholder = graph.get_tensor_by_name('prefix/Labels_Placeholder:0')

    predicted_labels = graph.get_tensor_by_name('prefix/Predicted_Labels:0')

    right_count_top1_op = graph.get_tensor_by_name('prefix/Right_Count_Top1:0')
    right_count_topk_op = graph.get_tensor_by_name('prefix/Right_Count_Topk:0')

    dataset = imagenet.get_split(
        'validation', '../datasets/imagenet-data/tfrecords')

    with tf.Session(graph=graph) as sess:
        images_op, labels_op = inputs(dataset=dataset, image_preprocessing_fn=preprocessing_factory.get_preprocessing(is_training=False),
                                                     num_epochs=1, batch_size=48)

        init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        total_v = 0.0
        test_correct_num_top1 = 0.0
        test_correct_num_topk = 0.0

        from tqdm import tqdm
Пример #7
0
def evaluate():
    """Evaluate VGG-16 on the ImageNet validation split, once or periodically.

    Builds the validation input pipeline and the ``vgg.inference`` graph,
    creates a saver restricted to the listed VGG-16 ``weights``, ``biases``
    and ``mask`` variables, and then calls ``eval_once`` in a loop. The loop
    ends after the first round when ``FLAGS.run_once`` is set; otherwise it
    sleeps ``FLAGS.eval_interval_secs`` between rounds and never returns.
    """
    batch_size = 32

    with tf.Graph().as_default() as g:
        dataset = imagenet.get_split(
            'validation', '/data/ramyadML/TF-slim-data/imageNet/processed')

        # Creates a TF-Slim DataProvider which reads the dataset in the
        # background.
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=4,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size,
            shuffle=True)

        # Evaluation must use the deterministic (non-augmenting) variant of
        # the preprocessing; the training variant would score the model on
        # randomly distorted images.
        preprocessing_name = 'vgg_16'
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        [image, label] = provider.get(['image', 'label'])
        image = image_preprocessing_fn(image, 224, 224)
        # Shift labels down by one -- presumably to drop the background class
        # of the slim ImageNet split; TODO confirm against vgg.inference.
        label -= 1

        # Batch up the evaluation data.
        images, labels = tf.train.batch([image, label],
                                        batch_size=batch_size,
                                        num_threads=4,
                                        capacity=5 * batch_size)

        print(images.shape)

        images = tf.cast(images, tf.float32)
        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = vgg.inference(images)

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)
        # streaming_accuracy yields a (value, update_op) pair; it is handed
        # to eval_once as-is under the historical name ``top_k_op``.
        top_k_op = slim.metrics.streaming_accuracy(predictions, labels)

        # Variables to restore from the checkpoint.
        list_var_names = [
            'vgg_16/conv1/conv1_1/biases', 'vgg_16/conv1/conv1_1/weights',
            'vgg_16/conv1/conv1_2/biases', 'vgg_16/conv1/conv1_2/weights',
            'vgg_16/conv2/conv2_1/biases', 'vgg_16/conv2/conv2_1/weights',
            'vgg_16/conv2/conv2_2/biases', 'vgg_16/conv2/conv2_2/weights',
            'vgg_16/conv3/conv3_1/biases', 'vgg_16/conv3/conv3_1/weights',
            'vgg_16/conv3/conv3_2/biases', 'vgg_16/conv3/conv3_2/weights',
            'vgg_16/conv3/conv3_3/biases', 'vgg_16/conv3/conv3_3/weights',
            'vgg_16/conv4/conv4_1/biases', 'vgg_16/conv4/conv4_1/weights',
            'vgg_16/conv4/conv4_2/biases', 'vgg_16/conv4/conv4_2/weights',
            'vgg_16/conv4/conv4_3/biases', 'vgg_16/conv4/conv4_3/weights',
            'vgg_16/conv5/conv5_1/biases', 'vgg_16/conv5/conv5_1/weights',
            'vgg_16/conv5/conv5_2/biases', 'vgg_16/conv5/conv5_2/weights',
            'vgg_16/conv5/conv5_3/biases', 'vgg_16/conv5/conv5_3/weights',
            'vgg_16/fc6/biases', 'vgg_16/fc6/weights', 'vgg_16/fc7/biases',
            'vgg_16/fc7/weights', 'vgg_16/fc8/biases', 'vgg_16/fc8/weights',
            'vgg_16/conv1/conv1_1/mask', 'vgg_16/conv1/conv1_2/mask',
            'vgg_16/conv2/conv2_1/mask', 'vgg_16/conv2/conv2_2/mask',
            'vgg_16/conv3/conv3_1/mask', 'vgg_16/conv3/conv3_2/mask',
            'vgg_16/conv3/conv3_3/mask', 'vgg_16/conv4/conv4_1/mask',
            'vgg_16/conv4/conv4_2/mask', 'vgg_16/conv4/conv4_3/mask',
            'vgg_16/conv5/conv5_1/mask', 'vgg_16/conv5/conv5_2/mask',
            'vgg_16/conv5/conv5_3/mask', 'vgg_16/fc6/mask', 'vgg_16/fc7/mask',
            'vgg_16/fc8/mask',
        ]

        # Each name is used as a scope filter on the global-variables
        # collection; the matches are accumulated into one flat list.
        var_list_to_restore = []
        for name in list_var_names:
            var_list_to_restore.extend(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name))
        saver = tf.train.Saver(var_list_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
Пример #8
0
def main(config, RANDOM_SEED, LOG_DIR, TRAIN_NUM, BATCH_SIZE, LEARNING_RATE,
         DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, BETA, K, D, SAVE_PERIOD,
         SUMMARY_PERIOD, **kwargs):
    """Train a VQ-VAE on ImageNet, writing checkpoints and summaries to LOG_DIR.

    Args:
        config: Training configuration object; only ``config.as_matrix()`` is
            used here, to log the configuration as a text summary.
        RANDOM_SEED: Seed applied to both NumPy and TensorFlow.
        LOG_DIR: Directory for checkpoints and summary events.
        TRAIN_NUM: Number of training iterations to run.
        BATCH_SIZE: Training batch size passed to ``_build_batch``.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: Arguments of
            the exponential learning-rate decay schedule.
        BETA: Weight of the commitment term (passed to ``VQVAE`` and used to
            scale the logged ``commit`` summaries).
        K, D: Passed through to ``VQVAE``.
        SAVE_PERIOD: A checkpoint is saved whenever the global step is a
            multiple of this value.
        SUMMARY_PERIOD: Training summaries are written whenever the global
            step is a multiple of this value; validation ("extended")
            summaries every ``2 * SUMMARY_PERIOD`` steps.
        **kwargs: Ignored; lets callers pass a superset of options.

    Raises:
        Any exception raised inside the training loop is handed to the
        coordinator and re-raised by ``coord.join`` after the final save.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    valid_dataset = imagenet.get_split('validation', 'datasets/ILSVRC2012')
    train_ims, _ = _build_batch(train_dataset, BATCH_SIZE, 4)
    valid_ims, _ = _build_batch(valid_dataset, 4, 1)

    # >>>>>>> MODEL
    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        # Capture an (empty) scope object so the train and validation
        # networks can both be handed the same parameter scope.
        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(learning_rate, global_step, BETA, train_ims, K, D,
                    _imagenet_arch, params, True)

    with tf.variable_scope('valid'):
        params.reuse_variables()
        valid_net = VQVAE(None, None, BETA, valid_ims, K, D, _imagenet_arch,
                          params, False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('recon', net.recon)
        tf.summary.scalar('vq', net.vq)
        tf.summary.scalar('commit', BETA * net.commit)
        tf.summary.scalar('nll', tf.reduce_mean(net.nll))
        tf.summary.image('origin', train_ims, max_outputs=4)
        tf.summary.image('recon', net.p_x_z, max_outputs=4)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss', valid_net.loss),
            tf.summary.scalar('valid_recon', valid_net.recon),
            tf.summary.scalar('valid_vq', valid_net.vq),
            tf.summary.scalar('valid_commit', BETA * valid_net.commit),
            tf.summary.scalar('valid_nll', tf.reduce_mean(valid_net.nll)),
            tf.summary.image('valid_origin', valid_ims, max_outputs=4),
            tf.summary.image('valid_recon', valid_net.p_x_z, max_outputs=4),
        ])
    # <<<<<<<<<<

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Separate name so the ``config`` argument is not shadowed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    # Start queueing. These are bound before the ``try`` so the handlers
    # below can always refer to them.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
            it, loss, _ = sess.run([global_step, net.loss, net.train_op])

            if it % SAVE_PERIOD == 0:
                net.save(sess, LOG_DIR, step=it)

            if it % SUMMARY_PERIOD == 0:
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, it)

            if it % (SUMMARY_PERIOD * 2) == 0:  # Extended Summary
                summary = sess.run(extended_summary_op)
                summary_writer.add_summary(summary, it)

    except Exception as e:
        # Report the failure to the coordinator; join() re-raises it below.
        coord.request_stop(e)
    finally:
        try:
            net.save(sess, LOG_DIR)

            coord.request_stop()
            coord.join(threads)
        finally:
            # Flush pending events and release the session's resources even
            # when join() re-raises.
            summary_writer.close()
            sess.close()
Пример #9
0
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a class-conditional PixelCNN prior over the codes of a VQ-VAE.

    The VQ-VAE is restored from ``MODEL`` and only used to encode training
    images (``vq_net.k``) and to decode sampled codes; the PixelCNN is the
    network being trained. Checkpoints and summaries are written to a
    ``pixelcnn`` directory next to ``MODEL``.

    Args:
        config: Training configuration object; only ``config.as_matrix()`` is
            used here, to log the configuration as a text summary.
        RANDOM_SEED: Seed applied to both NumPy and TensorFlow.
        MODEL: Path of the VQ-VAE checkpoint passed to ``vq_net.load``.
        TRAIN_NUM: Number of training iterations to run.
        BATCH_SIZE: Training batch size passed to ``_build_batch``.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: Arguments of
            the exponential learning-rate decay schedule.
        GRAD_CLIP, NUM_LAYERS, NUM_FEATURE_MAPS: Passed through to
            ``PixelCNN``.
        K, D, BETA: Passed through to ``VQVAE`` (``K`` and ``D`` also to
            ``PixelCNN``).
        SUMMARY_PERIOD: Loss summaries are written whenever the global step
            is a multiple of this value.
        SAVE_PERIOD: A checkpoint and a batch of sampled images are written
            whenever the global step is a multiple of this value.
        **kwargs: Ignored; lets callers pass a superset of options.

    Raises:
        Any exception raised inside the training loop is handed to the
        coordinator and re-raised by ``coord.join`` after the final save.
    """
    # Number of conditioning classes; must agree between the PixelCNN
    # definition and the labels sampled for the preview images.
    num_classes = 1000

    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')
    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params,
                       False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D,
                       num_classes, NUM_LAYERS, NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        # Fed with decoded samples; created after merge_all() so it is not
        # part of the periodic summary op.
        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Separate name so the ``config`` argument is not shadowed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
            # Encode one batch with the VQ-VAE, then train the prior on it.
            batch_xs, batch_ys = sess.run([vq_net.k, labels])
            batch_feed = {net.X: batch_xs, net.h: batch_ys}
            it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                                   feed_dict=batch_feed)

            if it % SAVE_PERIOD == 0:
                net.save(sess, LOG_DIR, step=it)
                sampled_zs, log_probs = net.sample_from_prior(
                    sess, np.random.randint(0, num_classes, size=(10, )), 2)
                sampled_ims = sess.run(vq_net.gen,
                                       feed_dict={vq_net.latent: sampled_zs})
                summary_writer.add_summary(
                    sess.run(sample_summary_op,
                             feed_dict={sample_images: sampled_ims}), it)

            if it % SUMMARY_PERIOD == 0:
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op, feed_dict=batch_feed)
                summary_writer.add_summary(summary, it)

    except Exception as e:
        # Report the failure to the coordinator; join() re-raises it below.
        coord.request_stop(e)
    finally:
        try:
            net.save(sess, LOG_DIR)

            coord.request_stop()
            coord.join(threads)
        finally:
            # Flush pending events and release the session's resources even
            # when join() re-raises.
            summary_writer.close()
            sess.close()
    return images, images_raw, labels


flowers_data_dir = "/home/robin/Dataset/flowers"
mnist_data_dir = "/home/robin/Dataset/mnist"
cifar10_data_dir = "/home/robin/Dataset/cifar10"
imagenet_data_dir = "/home/robin/Dataset/imaget/output_tfrecord"

if __name__ == "__main__":
    with tf.Graph().as_default():
        #         dataset = flowers.get_split('train', flowers_data_dir) #load_batch(dataset,height=224, width=224)
        #         dataset = mnist.get_split("train",mnist_data_dir) #load_batch(dataset,height=28, width=28)
        #         dataset = cifar10.get_split("train",cifar10_data_dir) #load_batch(dataset,height=32, width=32)
        dataset = imagenet.get_split(
            'validation',
            imagenet_data_dir)  #load_batch(dataset,height=224, width=224)

        batch_image, batch_raw_image, batch_labels = load_batch(
            dataset, height=224, width=224, is_training=True)
        one_hot_batch_labels = slim.one_hot_encoding(batch_labels,
                                                     dataset.num_classes)
        print(batch_image, batch_raw_image, batch_labels)

        data_provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, common_queue_capacity=32, common_queue_min=1)
        image, label = data_provider.get(['image', 'label'])
        one_hot_labels = slim.one_hot_encoding(label, dataset.num_classes)

        with tf.Session() as sess:
            with slim.queues.QueueRunners(sess):
Пример #11
0
def train():
  """Fine-tune VGG-16 on ImageNet while updating pruning masks.

  Builds the ImageNet training input pipeline, the ``vgg`` inference, loss
  and train ops, and the pruning mask-update op, then runs training inside a
  ``MonitoredTrainingSession`` until ``FLAGS.max_steps`` is reached. The
  listed VGG-16 weights and biases are restored from
  ``trained_weights/vgg_16.ckpt`` before the loop starts.
  """
  # Single source of truth for the batch size: used for queue sizing,
  # batching and the throughput figure printed by the logger hook.
  batch_size = 32

  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    dataset = imagenet.get_split('train', '/data/ramyadML/TF-slim-data/imageNet/processed')

    # Creates a TF-Slim DataProvider which reads the dataset in the background
    # during both training and testing.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=4,
        common_queue_capacity=20 * batch_size,
        common_queue_min=10 * batch_size,
        shuffle=True)

    preprocessing_name = 'vgg_16'
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    [image, label] = provider.get(['image', 'label'])
    image = image_preprocessing_fn(image, 224, 224)
    # Shift labels down by one -- presumably to drop the background class of
    # the slim ImageNet split; TODO confirm against vgg.inference.
    label -= 1

    # batch up some training data
    images, labels = tf.train.batch([image, label],
                                    batch_size=batch_size,
                                    num_threads=4,
                                    capacity=5 * batch_size)

    print(images.shape)

    images = tf.cast(images, tf.float32)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = vgg.inference(images)

    print("logits shape:", logits.shape)
    print("label shape", labels.shape)
    # Calculate loss.
    loss = vgg.loss(logits, labels)

    # Variables restored from the pre-trained checkpoint.
    list_var_names = [
        'vgg_16/conv1/conv1_1/biases',
        'vgg_16/conv1/conv1_1/weights',
        'vgg_16/conv1/conv1_2/biases',
        'vgg_16/conv1/conv1_2/weights',
        'vgg_16/conv2/conv2_1/biases',
        'vgg_16/conv2/conv2_1/weights',
        'vgg_16/conv2/conv2_2/biases',
        'vgg_16/conv2/conv2_2/weights',
        'vgg_16/conv3/conv3_1/biases',
        'vgg_16/conv3/conv3_1/weights',
        'vgg_16/conv3/conv3_2/biases',
        'vgg_16/conv3/conv3_2/weights',
        'vgg_16/conv3/conv3_3/biases',
        'vgg_16/conv3/conv3_3/weights',
        'vgg_16/conv4/conv4_1/biases',
        'vgg_16/conv4/conv4_1/weights',
        'vgg_16/conv4/conv4_2/biases',
        'vgg_16/conv4/conv4_2/weights',
        'vgg_16/conv4/conv4_3/biases',
        'vgg_16/conv4/conv4_3/weights',
        'vgg_16/conv5/conv5_1/biases',
        'vgg_16/conv5/conv5_1/weights',
        'vgg_16/conv5/conv5_2/biases',
        'vgg_16/conv5/conv5_2/weights',
        'vgg_16/conv5/conv5_3/biases',
        'vgg_16/conv5/conv5_3/weights',
        'vgg_16/fc6/biases',
        'vgg_16/fc6/weights',
        'vgg_16/fc7/biases',
        'vgg_16/fc7/weights',
        'vgg_16/fc8/biases',
        'vgg_16/fc8/weights',
    ]

    # Each name is used as a scope filter on the global-variables collection;
    # this runs before the train op is built, so only variables that exist at
    # this point can match.
    var_list_to_restore = []
    for name in list_var_names:
      var_list_to_restore.extend(
          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name))

    saver = tf.train.Saver(var_list_to_restore)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = vgg.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime.

      NOTE: the hook fires for every ``mon_sess.run`` call, including the
      mask update, so ``_step`` counts run calls rather than training steps.
      """

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          # One run consumes exactly one batch from the input queue.
          num_examples_per_step = batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:

      # NOTE(review): this restore runs on every start, after the monitored
      # session has set itself up from FLAGS.train_dir, so the listed
      # variables always begin from the pre-trained checkpoint -- confirm
      # that this is intended when resuming an interrupted run.
      saver.restore(mon_sess, "trained_weights/vgg_16.ckpt")
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)