Example #1
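# Assumes module-level imports/definitions: tensorflow as tf, FLAGS,
# input_data, convnet_model, placeholder_inputs, and
# classification_performance_evaluator.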
def test():
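    """Restore the latest checkpoint and evaluate the convnet on the test
    TFRecords, printing per-step accuracy and a final per-class report."""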
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images_placeholder, labels_placeholder, dropout_placeholder = placeholder_inputs(
            FLAGS.batch_size)
        images, labels = input_data.inputs(
            filename='./tmp/tfrecords/test.tfrecords',
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=5,
            imshape=[128, 128, 3])

        logits = convnet_model.inference(images_placeholder,
                                         dropout_placeholder)
        acc, predictions_with_labels = convnet_model.testing(
            logits, labels_placeholder)
        saver = tf.train.Saver()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(
            checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Restored!')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # The order of category names here does not matter
        names_categories = ['adafor', 'brkt2', 'brng', 'fl3', 'flng']
        evaluator = classification_performance_evaluator(names_categories)
        while not coord.should_stop():

            for step in range(FLAGS.max_steps_test):
                te_images, te_labels = sess.run([images, labels])
                print(te_images.shape)
                print(te_labels.shape)
                te_feed = {
                    images_placeholder: te_images,
                    labels_placeholder: te_labels,
                    dropout_placeholder: 1.0
                }
                te_acc, te_predictions_with_labels = sess.run(
                    [acc, predictions_with_labels], feed_dict=te_feed)
                print('Step ' + str(step) + ' Testing Accuracy: ' +
                      str(te_acc))
                predictions = te_predictions_with_labels[0]
                _labels = te_predictions_with_labels[1]
                print(_labels)
                evaluator.update(_labels, predictions)
            coord.request_stop()

        coord.join(threads)
        sess.close()
        evaluator.print_performance()
Example #2
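# Assumes module-level imports/definitions: tensorflow as tf, numpy as np,
# os, errno, matlab.engine, and project helpers (imagenet, GoogLeNet,
# preprocessing_factory, input_data, cluster_conv, FLAGS).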
def run_gpu_eval(use_compression=False,
                 use_quantization=False,
                 compute_energy=False,
                 use_pretrained_model=True,
                 epoch_num=0):
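    """Evaluate GoogLeNet (inception_v1) on the ImageNet validation split
    across FLAGS.gpu_num GPUs, optionally clustering (compressing) and/or
    fixed-point quantizing the weights first, and optionally recording
    feature-map sparsity statistics for energy estimation."""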
    from functools import reduce
    module_name = 'inception_v1'
    checkpoint_dir = 'checkpoint/{}_{}_{}'.format(module_name, epoch_num,
                                                  FLAGS.alpha)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            dataset = imagenet.get_split('validation', FLAGS.dataset_dir)
            istraining_placeholder = tf.placeholder(tf.bool)
            network_fn = GoogLeNet.GoogLeNet(
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            logits_lst = []
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.batch_size * FLAGS.gpu_num,
                       network_fn.default_image_size,
                       network_fn.default_image_size, 3))
            labels_placeholder = tf.placeholder(
                tf.int64,
                shape=(FLAGS.batch_size * FLAGS.gpu_num, dataset.num_classes))
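            # Build one inference tower per GPU; reuse_variables() shares the
            # weights across towers after the first one.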
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            logits, end_points, end_points_Ofmap, end_points_Ifmap = network_fn(
                                images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])
                            logits_lst.append(logits)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=False)

            logits_op = tf.concat(logits_lst, 0)
            right_count_top1_op = tf.reduce_sum(
                tf.cast(
                    tf.equal(tf.argmax(tf.nn.softmax(logits_op), axis=1),
                             tf.argmax(labels_placeholder, axis=1)), tf.int32))
            right_count_topk_op = tf.reduce_sum(
                tf.cast(
                    tf.nn.in_top_k(tf.nn.softmax(logits_op),
                                   tf.argmax(labels_placeholder, axis=1), 5),
                    tf.int32))

            images_op, labels_op = input_data.inputs(
                dataset=dataset,
                image_preprocessing_fn=image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=1,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])

            if use_pretrained_model:
                varlist = tf.trainable_variables()
                varlist += bn_moving_vars
                print(varlist)
                saver = tf.train.Saver(varlist)
                # saver = tf.train.Saver(vardict)
                if os.path.isfile(FLAGS.checkpoint_path):
                    saver.restore(sess, FLAGS.checkpoint_path)
                    print(
                        '#############################Session restored from pretrained model at {}!###############################'
                        .format(FLAGS.checkpoint_path))
                else:
                    ckpt = tf.train.get_checkpoint_state(
                        checkpoint_dir=FLAGS.checkpoint_path)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver = tf.train.Saver(varlist)
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            'Session restored from pretrained degradation model at {}!'
                            .format(ckpt.model_checkpoint_path))
            else:
                varlist = tf.trainable_variables()
                varlist += bn_moving_vars
                saver = tf.train.Saver(varlist)
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        '#############################Session restored from trained model at {}!###############################'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            checkpoint_dir)

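            # A MATLAB engine session is used below (get_fi) to quantize
            # weights to fixed point.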
            mat_eng = matlab.engine.start_matlab()
            seed = 500
            alpha = FLAGS.alpha
            memory = 0
            for v in tf.trainable_variables() + bn_moving_vars:
                if 'weights' in v.name:
                    memory += np.prod(sess.run(v).shape)
                    print("weights.name: {}".format(v.name))
                    print("weights.shape: {}".format(sess.run(v).shape))
                    if use_compression:
                        weights = np.transpose(sess.run(v), (3, 2, 1, 0))
                        shape = weights.shape
                        n, c, w = shape[0], shape[1], shape[2]
                        k = int(alpha * n * c * w)
                        weight_clustered, mse = cluster_conv(weights, k, seed)
                        weight_clustered = np.transpose(
                            weight_clustered, (3, 2, 1, 0))
                        sess.run(v.assign(weight_clustered))
                        print("weight_clustered shape: {}".format(
                            weight_clustered.shape))
                        print("mse: {}".format(mse))
                        seed += 1
                    if use_quantization:
                        weights = np.transpose(sess.run(v), (3, 2, 1, 0))
                        shape = weights.shape
                        weight_quantized = mat_eng.get_fi(
                            matlab.double(weights.tolist()), FLAGS.bitwidth,
                            FLAGS.bitwidth -
                            FLAGS.bitwidth_minus_fraction_length)
                        weight_quantized = np.asarray(
                            weight_quantized).reshape(shape).astype('float32')
                        weight_quantized = np.transpose(
                            weight_quantized, (3, 2, 1, 0))
                        sess.run(v.assign(weight_quantized))
                        print("weight_quantized shape: {}".format(
                            weight_quantized.shape))
                    print('=====================================')

                if any(x in v.name for x in ['beta']):
                    memory += np.prod(sess.run(v).shape)
                    print("beta.name: {}".format(v.name))
                    print("beta.shape: {}".format(sess.run(v).shape))
                    if use_quantization:
                        weights = sess.run(v)
                        shape = weights.shape
                        weight_quantized = mat_eng.get_fi(
                            matlab.double(weights.tolist()), FLAGS.bn_bitwidth,
                            FLAGS.bn_bitwidth -
                            FLAGS.bitwidth_minus_fraction_length)
                        weight_quantized = np.asarray(
                            weight_quantized).reshape(shape).astype('float32')
                        sess.run(v.assign(weight_quantized))
                        print("beta_quantized shape: {}".format(
                            weight_quantized.shape))
                    print('+++++++++++++++++++++++++++++++++++++')

            checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=0)
            print(
                "############################################### MEMORY IS {} ###############################################"
                .format(memory))

            if compute_energy:
                weights_dict = {}
                for v in tf.trainable_variables():
                    if 'weights' in v.name:
                        vname = "_".join(v.name.split('/')[1:-1])
                        print("v.name: {}".format(vname))
                        print("v.shape: {}".format(sess.run(v).shape))
                        #weights = np.transpose(sess.run(v), (3, 2, 1, 0))
                        weights = sess.run(v)
                        print("v.nzeros: {}".format(
                            np.count_nonzero(weights == 0)))
                        weights_dict[vname] = [
                            reduce(lambda x, y: x * y, weights.shape) *
                            (1 - FLAGS.alpha), weights.shape
                        ]
                        print('=====================================')

            total_v = 0.0
            test_correct_num_top1 = 0.0
            test_correct_num_topk = 0.0

            from tqdm import tqdm

            pbar = tqdm(total=dataset.num_samples //
                        (FLAGS.gpu_num * FLAGS.batch_size))
            i = 1
            model_params_dict = {}
            try:
                while not coord.should_stop():
                    pbar.update(1)
                    images, labels = sess.run([images_op, labels_op])

                    right_count_top1, right_count_topk = sess.run(
                        [right_count_top1_op, right_count_topk_op],
                        feed_dict={
                            images_placeholder: images,
                            labels_placeholder: labels,
                            istraining_placeholder: False
                        })

                    end_points_Ofmap_dict, end_points_Ifmap_dict = sess.run(
                        [end_points_Ofmap, end_points_Ifmap],
                        feed_dict={
                            images_placeholder: images,
                            labels_placeholder: labels,
                            istraining_placeholder: False
                        })

                    test_correct_num_top1 += right_count_top1
                    test_correct_num_topk += right_count_topk
                    total_v += labels.shape[0]

                    if compute_energy:
                        keys = list(end_points_Ifmap_dict.keys())
                        if i == 1:
                            for k in keys:
                                model_params_dict[k] = {}
                                model_params_dict[k][
                                    "IfMap_Shape"] = end_points_Ifmap_dict[
                                        k].shape
                                model_params_dict[k][
                                    "IfMap_nZeros"] = np.count_nonzero(
                                        end_points_Ifmap_dict[k] == 0)

                                model_params_dict[k][
                                    "Filter_Shape"] = weights_dict[k][1]
                                model_params_dict[k]["Filter_nZeros"] = int(
                                    weights_dict[k][0])

                                model_params_dict[k][
                                    "OfMap_Shape"] = end_points_Ofmap_dict[
                                        k].shape
                                model_params_dict[k][
                                    "OfMap_nZeros"] = np.count_nonzero(
                                        end_points_Ofmap_dict[k] == 0)
                                print("Layer Name: {}".format(k))
                                print("IfMap Shape: {}".format(
                                    end_points_Ifmap_dict[k].shape))
                                print("IfMap nZeros: {:.4e}".format(
                                    np.count_nonzero(
                                        end_points_Ifmap_dict[k] == 0)))
                                print("IfMap nZeros Avg: {:.4e}".format(
                                    model_params_dict[k]["IfMap_nZeros"]))
                                print("Filter Shape: {}".format(
                                    weights_dict[k][1]))
                                print("Filter nZeros: {:.4e}".format(
                                    int(weights_dict[k][0])))
                                print("OfMap Shape: {}".format(
                                    end_points_Ofmap_dict[k].shape))
                                print("OfMap nZeros: {:.4e}".format(
                                    np.count_nonzero(
                                        end_points_Ofmap_dict[k] == 0)))
                                print("OfMap nZeros Avg: {:.4e}".format(
                                    model_params_dict[k]["OfMap_nZeros"]))
                                print(
                                    '=========================================================================='
                                )
                        else:
                            for k in keys:
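                                # Incremental running mean of zero counts:
                                # mean_i = mean_{i-1} * (i-1)/i + count_i / i.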
                                model_params_dict[k]["IfMap_nZeros"] = (
                                    model_params_dict[k]["IfMap_nZeros"] +
                                    np.count_nonzero(
                                        end_points_Ifmap_dict[k] == 0) /
                                    (i - 1)) * (i - 1) / i
                                model_params_dict[k]["OfMap_nZeros"] = (
                                    model_params_dict[k]["OfMap_nZeros"] +
                                    np.count_nonzero(
                                        end_points_Ofmap_dict[k] == 0) /
                                    (i - 1)) * (i - 1) / i
                        i += 1
            except tf.errors.OutOfRangeError:
                print('Done testing on all the examples')
            finally:
                coord.request_stop()
                if compute_energy:
                    import pickle
                    with open('model_params_dict.pkl', 'wb') as f:
                        pickle.dump(model_params_dict, f,
                                    pickle.HIGHEST_PROTOCOL)
                    with open('GoogLeNet_Pruned_{}.txt'.format(FLAGS.alpha),
                              'w') as wf:
                        for k in keys:
                            wf.write("Layer Name: {}\n".format(k))

                            wf.write("IfMap Shape: {}\n".format(
                                model_params_dict[k]["IfMap_Shape"]))
                            wf.write("IfMap nZeros: {:.4e}\n".format(
                                model_params_dict[k]["IfMap_nZeros"]))

                            wf.write("Filter Shape: {}\n".format(
                                model_params_dict[k]["Filter_Shape"]))
                            wf.write("Filter nZeros: {:.4e}\n".format(
                                model_params_dict[k]["Filter_nZeros"]))

                            wf.write("OfMap Shape: {}\n".format(
                                model_params_dict[k]["OfMap_Shape"]))
                            wf.write("OfMap nZeros: {:.4e}\n".format(
                                model_params_dict[k]["OfMap_nZeros"]))
                            wf.write(
                                '==========================================================================\n'
                            )
            coord.join(threads)
            print('Test acc top1:', test_correct_num_top1 / total_v,
                  'Test_correct_num top1:', test_correct_num_top1, 'Total_v:',
                  total_v)
            print('Test acc topk:', test_correct_num_topk / total_v,
                  'Test_correct_num topk:', test_correct_num_topk, 'Total_v:',
                  total_v)

            is_compression = lambda flag: "Compression_" if flag else "NoCompression_"
            is_quantization = lambda flag: "Quantization_" if flag else "NoQuantization_"
            with open(
                    '{}_{}_{}_{}_evaluation.txt'.format(
                        is_compression(use_compression),
                        is_quantization(use_quantization), epoch_num,
                        FLAGS.alpha), 'w') as wf:
                wf.write(
                    'test acc top1:{}\ttest_correct_num top1:{}\ttotal_v:{}\n'.
                    format(test_correct_num_top1 / total_v,
                           test_correct_num_top1, total_v))
                wf.write(
                    'test acc topk:{}\ttest_correct_num topk:{}\ttotal_v:{}\n'.
                    format(test_correct_num_topk / total_v,
                           test_correct_num_topk, total_v))

    print("done")
Example #3
def run_gpu_train(use_pretrained_model, epoch_num):
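    """Fine-tune GoogLeNet (inception_v1) for one epoch of the ImageNet train
    split across FLAGS.gpu_num GPUs, restoring either a pretrained model or
    the previous epoch's checkpoint, with periodic train/validation accuracy
    reports and checkpointing."""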
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    module_name = 'inception_v1'
    checkpoint_dir = 'checkpoint/{}_{}_{}'.format(module_name, epoch_num - 1,
                                                  FLAGS.alpha)

    saved_checkpoint_dir = 'checkpoint/{}_{}_{}'.format(
        module_name, epoch_num, FLAGS.alpha)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(saved_checkpoint_dir):
        os.makedirs(saved_checkpoint_dir)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            dataset = imagenet.get_split('train', FLAGS.dataset_dir)
            dataset_val = imagenet.get_split('validation', FLAGS.dataset_dir)
            global_step = slim.create_global_step()
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            istraining_placeholder = tf.placeholder(tf.bool)
            network_fn = GoogLeNet.GoogLeNet(
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            tower_grads = []
            logits_lst = []
            losses_lst = []
            opt = _configure_optimizer(learning_rate)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.batch_size * FLAGS.gpu_num,
                       network_fn.default_image_size,
                       network_fn.default_image_size, 3))
            labels_placeholder = tf.placeholder(
                tf.int64,
                shape=(FLAGS.batch_size * FLAGS.gpu_num, dataset.num_classes))
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            logits, _, _, _ = network_fn(
                                images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])
                            logits_lst.append(logits)
                            loss = tower_loss_xentropy_dense(
                                logits, labels_placeholder[gpu_index *
                                                           FLAGS.batch_size:
                                                           (gpu_index + 1) *
                                                           FLAGS.batch_size])
                            losses_lst.append(loss)
                            # varlist = [v for v in tf.trainable_variables() if any(x in v.name for x in ["logits"])]
                            varlist = tf.trainable_variables()
                            #print([v.name for v in varlist])
                            grads = opt.compute_gradients(loss, varlist)
                            tower_grads.append(grads)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=True)
            val_image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                is_training=False)

            loss_op = tf.reduce_mean(losses_lst, name='softmax')
            logits_op = tf.concat(logits_lst, 0)
            acc_op = accuracy(logits_op, labels_placeholder)
            grads = average_gradients(tower_grads)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print(update_ops)
            with tf.control_dependencies([tf.group(*update_ops)]):
                apply_gradient_op = opt.apply_gradients(
                    grads, global_step=global_step)

            images_op, labels_op = input_data.inputs(
                dataset=dataset,
                image_preprocessing_fn=image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=1,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)
            val_images_op, val_labels_op = input_data.inputs(
                dataset=dataset_val,
                image_preprocessing_fn=val_image_preprocessing_fn,
                network_fn=network_fn,
                num_epochs=None,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])

            if use_pretrained_model:
                varlist = tf.trainable_variables()
                varlist += bn_moving_vars
                print(varlist)
                # vardict = {v.name[:-2].replace('MobileNet', 'MobilenetV1'): v for v in varlist}
                saver = tf.train.Saver(varlist)
                # saver = tf.train.Saver(vardict)
                if os.path.isfile(FLAGS.checkpoint_path):
                    saver.restore(sess, FLAGS.checkpoint_path)
                    print(
                        '#############################Session restored from pretrained model at {}!###############################'
                        .format(FLAGS.checkpoint_path))
                else:
                    ckpt = tf.train.get_checkpoint_state(
                        checkpoint_dir=FLAGS.checkpoint_path)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver = tf.train.Saver(varlist)
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            'Session restored from pretrained degradation model at {}!'
                            .format(ckpt.model_checkpoint_path))
            else:
                varlist = tf.trainable_variables()
                varlist += bn_moving_vars
                saver = tf.train.Saver(varlist)
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        '#############################Session restored from trained model at {}!###############################'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            checkpoint_dir)

            saver = tf.train.Saver(tf.trainable_variables() + bn_moving_vars)
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    images, labels = sess.run([images_op, labels_op])
                    _, loss_value = sess.run(
                        [apply_gradient_op, loss_op],
                        feed_dict={
                            images_placeholder: images,
                            labels_placeholder: labels,
                            istraining_placeholder: True
                        })
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'
                    duration = time.time() - start_time
                    print('Step: {:4d} time: {:.4f} loss: {:.8f}'.format(
                        step, duration, loss_value))

                    if step % FLAGS.val_step == 0:
                        start_time = time.time()
                        images, labels = sess.run([images_op, labels_op])
                        acc, loss_value = sess.run(
                            [acc_op, loss_op],
                            feed_dict={
                                images_placeholder: images,
                                labels_placeholder: labels,
                                istraining_placeholder: False
                            })
                        print(
                            "Step: {:4d} time: {:.4f}, training accuracy: {:.5f}, loss: {:.8f}"
                            .format(step,
                                    time.time() - start_time, acc, loss_value))

                        start_time = time.time()
                        images, labels = sess.run(
                            [val_images_op, val_labels_op])
                        acc, loss_value = sess.run(
                            [acc_op, loss_op],
                            feed_dict={
                                images_placeholder: images,
                                labels_placeholder: labels,
                                istraining_placeholder: False
                            })
                        print(
                            "Step: {:4d} time: {:.4f}, validation accuracy: {:.5f}, loss: {:.8f}"
                            .format(step,
                                    time.time() - start_time, acc, loss_value))

                    # Save a checkpoint and evaluate the model periodically.
                    if step % FLAGS.save_step == 0 or (step +
                                                       1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(saved_checkpoint_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training on all the examples')
            finally:
                coord.request_stop()
            coord.join(threads)
            checkpoint_path = os.path.join(saved_checkpoint_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

    print("done")
Example #4
def run_testing_multi_models(is_training=False):
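    """Evaluate four budget models (resnet_v1_50, resnet_v2_50, mobilenet_v1,
    mobilenet_v1_075) and their ensemble on TFRecord inputs, writing
    per-attribute average-precision scores for each model."""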
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, None, None,
                       FLAGS.nchannel))
            labels_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, FLAGS.num_classes))
            isTraining_placeholder = tf.placeholder(tf.bool)

            from collections import defaultdict
            logits_budget_images_lst_dct = defaultdict(list)
            loss_budget_images_lst_dct = defaultdict(list)

            logits_lst = []
            losses_lst = []
            model_dict = {}
            model_name_lst = [
                'resnet_v1_50', 'resnet_v2_50', 'mobilenet_v1',
                'mobilenet_v1_075'
            ]
            for model_name in model_name_lst:
                model_dict[model_name] = nets_factory.get_network_fn(
                    model_name,
                    num_classes=FLAGS.num_classes,
                    weight_decay=FLAGS.weight_decay,
                    is_training=True)
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            X = images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size]
                            loss_budget = 0.0
                            logits_budget = tf.zeros(
                                [FLAGS.batch_size, FLAGS.num_classes])

                            for model_name in model_name_lst:
                                print(model_name)
                                print(tf.trainable_variables())
                                logits, _ = model_dict[model_name](X)
                                logits_budget += logits
                                loss = tf.reduce_mean(
                                    tf.nn.sigmoid_cross_entropy_with_logits(
                                        logits=logits,
                                        labels=labels_placeholder[
                                            gpu_index *
                                            FLAGS.batch_size:(gpu_index + 1) *
                                            FLAGS.batch_size]))
                                loss_budget += loss
                                logits_budget_images_lst_dct[
                                    model_name].append(logits)
                                loss_budget_images_lst_dct[model_name].append(
                                    loss)
                            logits_budget = tf.divide(logits_budget, 4.0,
                                                      'LogitsBudgetMean')
                            logits_lst.append(logits_budget)
                            losses_lst.append(loss_budget)
                            # varlist = [v for v in tf.trainable_variables() if any(x in v.name for x in ["logits"])]
                            varlist = tf.trainable_variables()
                            print([v.name for v in varlist])
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()
            loss_op = tf.reduce_mean(losses_lst)
            logits_op = tf.concat(logits_lst, 0)

            logits_op_lst = []
            for model_name in model_name_lst:
                logits_op_lst.append(
                    tf.concat(logits_budget_images_lst_dct[model_name],
                              axis=0))

            train_files = [
                os.path.join(FLAGS.train_images_files_dir, f)
                for f in os.listdir(FLAGS.train_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            val_files = [
                os.path.join(FLAGS.val_images_files_dir, f)
                for f in os.listdir(FLAGS.val_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            test_files = [
                os.path.join(FLAGS.test_images_files_dir, f)
                for f in os.listdir(FLAGS.test_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            print(
                '#############################Reading from files###############################'
            )
            print(train_files)
            print(val_files)
            print(test_files)

            if is_training:
                images_op, labels_op = input_data.inputs(
                    filenames=train_files,
                    batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                    num_epochs=1,
                    num_threads=FLAGS.num_threads,
                    num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                    shuffle=False)
            else:
                images_op, labels_op = input_data.inputs(
                    filenames=test_files,
                    batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                    num_epochs=1,
                    num_threads=FLAGS.num_threads,
                    num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                    shuffle=False)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print(
                '----------------------------Trainable Variables-----------------------------------------'
            )
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            saver = tf.train.Saver(tf.trainable_variables() + bn_moving_vars)
            ckpt = tf.train.get_checkpoint_state(
                checkpoint_dir=FLAGS.checkpoint_dir)
            print(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Session restored from pretrained budget model at {}!'.
                      format(ckpt.model_checkpoint_path))
            else:
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT),
                                        FLAGS.checkpoint_dir)

            model_name_lst += ['ensemble']
            loss_budget_lst = []
            pred_probs_lst_lst = [[] for _ in range(len(model_name_lst))]
            gt_lst = []
            try:
                while not coord.should_stop():
                    images, labels = sess.run([images_op, labels_op])
                    # write_video(videos, labels)
                    gt_lst.append(labels)
                    value_lst = sess.run([loss_op, logits_op] + logits_op_lst,
                                         feed_dict={
                                             images_placeholder: images,
                                             labels_placeholder: labels,
                                             isTraining_placeholder: True
                                         })
                    print(labels.shape)
                    loss_budget_lst.append(value_lst[0])
                    for i in range(len(model_name_lst)):
                        pred_probs_lst_lst[i].append(value_lst[i + 1])

            except tf.errors.OutOfRangeError:
                print('Done testing on all the examples')
            finally:
                coord.request_stop()

            gt_mat = np.concatenate(gt_lst, axis=0)
            n_examples, n_labels = gt_mat.shape
            for i in range(len(model_name_lst)):
                save_dir = os.path.join(FLAGS.checkpoint_dir, 'evaluation')
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                isTraining = lambda flag: "training" if flag else "validation"
                with open(
                        os.path.join(
                            save_dir, '{}_class_scores_{}.txt'.format(
                                model_name_lst[i], isTraining(is_training))),
                        'w') as wf:
                    pred_probs_mat = np.concatenate(pred_probs_lst_lst[i],
                                                    axis=0)
                    wf.write('# Examples = {}\n'.format(n_examples))
                    wf.write('# Labels = {}\n'.format(n_labels))
                    wf.write('Average Loss = {}\n'.format(
                        np.mean(loss_budget_lst)))
                    wf.write("Macro MAP = {:.2f}\n".format(
                        100 * average_precision_score(
                            gt_mat, pred_probs_mat, average='macro')))
                    cmap_stats = average_precision_score(gt_mat,
                                                         pred_probs_mat,
                                                         average=None)
                    attr_id_to_name, attr_id_to_idx = load_attributes()
                    idx_to_attr_id = {v: k for k, v in attr_id_to_idx.items()}
                    wf.write('\t'.join([
                        'attribute_id', 'attribute_name', 'num_occurrences',
                        'ap'
                    ]) + '\n')
                    for idx in range(n_labels):
                        attr_id = idx_to_attr_id[idx]
                        attr_name = attr_id_to_name[attr_id]
                        attr_occurrences = np.sum(gt_mat, axis=0)[idx]
                        ap = cmap_stats[idx]
                        wf.write('{}\t{}\t{}\t{}\n'.format(
                            attr_id, attr_name, attr_occurrences, ap * 100.0))

            coord.join(threads)
            sess.close()

    print("done")
Example #5
def run_training():
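    """Train a budget model (FLAGS.model_name) with multi-GPU towers and
    gradient accumulation over FLAGS.n_minibatches, restoring a pretrained
    or previously trained checkpoint and saving periodically."""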
    # Create model directory
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    use_pretrained_model = True

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, None, None,
                       FLAGS.nchannel))
            labels_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, FLAGS.num_classes))
            istraining_placeholder = tf.placeholder(tf.bool)
            tower_grads = []
            logits_lst = []
            losses_lst = []
            learning_rate = tf.train.exponential_decay(
                0.001,  # Base learning rate.
                global_step,  # Current index into the dataset.
                5000,  # Decay step.
                0.96,  # Decay rate.
                staircase=True)
            # Use simple momentum for the optimization.
            #opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
            opt = tf.train.AdamOptimizer(1e-3)
            network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=FLAGS.num_classes,
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            X = images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size]
                            logits, _ = network_fn(X)
                            logits_lst.append(logits)
                            loss = tf.reduce_mean(
                                tf.nn.sigmoid_cross_entropy_with_logits(
                                    logits=logits,
                                    labels=labels_placeholder[
                                        gpu_index *
                                        FLAGS.batch_size:(gpu_index + 1) *
                                        FLAGS.batch_size]))
                            losses_lst.append(loss)
                            #varlist = [v for v in tf.trainable_variables() if any(x in v.name for x in ["logits"])]
                            varlist = tf.trainable_variables()
                            print([v.name for v in varlist])
                            grads = opt.compute_gradients(loss, varlist)
                            tower_grads.append(grads)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()
            loss_op = tf.reduce_mean(losses_lst, name='softmax')
            logits_op = tf.concat(logits_lst, 0)
            grads = average_gradients(tower_grads)

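            # Gradient accumulation: keep one zero-initialized accumulator per
            # trainable variable, add each minibatch's gradients scaled by
            # 1/n_minibatches, then apply the averaged gradients in one update.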
            with tf.device('/cpu:%d' % 0):
                tvs = varlist
                accum_vars = [
                    tf.Variable(tf.zeros_like(tv.initialized_value()),
                                trainable=False) for tv in tvs
                ]
                zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print(update_ops)
            with tf.control_dependencies([tf.group(*update_ops)]):
                accum_ops = [
                    accum_vars[i].assign_add(gv[0] / FLAGS.n_minibatches)
                    for i, gv in enumerate(grads)
                ]

            apply_gradient_op = opt.apply_gradients(
                [(accum_vars[i].value(), gv[1]) for i, gv in enumerate(grads)],
                global_step=global_step)

            train_files = [
                os.path.join(FLAGS.train_images_files_dir, f)
                for f in os.listdir(FLAGS.train_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            val_files = [
                os.path.join(FLAGS.val_images_files_dir, f)
                for f in os.listdir(FLAGS.val_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            test_files = [
                os.path.join(FLAGS.test_images_files_dir, f)
                for f in os.listdir(FLAGS.test_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            print(
                '#############################Reading from files###############################'
            )
            print(train_files)
            print(val_files)

            tr_images_op, tr_labels_op = input_data.inputs(
                filenames=train_files,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                shuffle=True,
                distort=True,
            )
            val_images_op, val_labels_op = input_data.inputs(
                filenames=val_files,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                shuffle=True,
                distort=False,
            )
            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])
            # Create a saver for writing training checkpoints.

            if use_pretrained_model:
                varlist = [
                    v for v in tf.trainable_variables()
                    if not any(x in v.name for x in ["logits"])
                ]
                #varlist += bn_moving_vars
                #vardict = {v.name[:-2].replace('MobileNet', 'MobilenetV1'): v for v in varlist}
                saver = tf.train.Saver(varlist)
                #saver = tf.train.Saver(vardict)
                if os.path.isfile(FLAGS.checkpoint_path):
                    saver.restore(sess, FLAGS.checkpoint_path)
                    print(
                        '#############################Session restored from pretrained model at {}!###############################'
                        .format(FLAGS.checkpoint_path))
                else:
                    ckpt = tf.train.get_checkpoint_state(
                        checkpoint_dir=FLAGS.checkpoint_path)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver = tf.train.Saver(varlist)
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            'Session restored from pretrained degradation model at {}!'
                            .format(ckpt.model_checkpoint_path))
                    else:
                        raise FileNotFoundError(errno.ENOENT,
                                                os.strerror(errno.ENOENT),
                                                FLAGS.checkpoint_path)
            else:
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver = tf.train.Saver(tf.trainable_variables())
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        'Session restored from pretrained degradation model at {}!'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            FLAGS.checkpoint_dir)

            # Create a saver covering trainable variables and BN moving stats.
            saver = tf.train.Saver(tf.trainable_variables() + bn_moving_vars)
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                loss_value_lst = []
                sess.run(zero_ops)
                for _ in itertools.repeat(None, FLAGS.n_minibatches):
                    tr_videos, tr_labels = sess.run(
                        [tr_images_op, tr_labels_op])
                    _, loss_value = sess.run(
                        [accum_ops, loss_op],
                        feed_dict={
                            images_placeholder: tr_videos,
                            labels_placeholder: tr_labels,
                            istraining_placeholder: True
                        })
                    loss_value_lst.append(loss_value)
                sess.run(apply_gradient_op)
                assert not np.isnan(
                    np.mean(loss_value_lst)), 'Model diverged with loss = NaN'
                duration = time.time() - start_time
                print('Step: {:4d} time: {:.4f} loss: {:.8f}'.format(
                    step, duration, np.mean(loss_value_lst)))
                if step % FLAGS.val_step == 0:
                    start_time = time.time()
                    tr_videos, tr_labels = sess.run(
                        [tr_images_op, tr_labels_op])
                    loss_value = sess.run(loss_op,
                                          feed_dict={
                                              images_placeholder: tr_videos,
                                              labels_placeholder: tr_labels,
                                              istraining_placeholder: True
                                          })
                    print("Step: {:4d} time: {:.4f}, training loss: {:.8f}".
                          format(step,
                                 time.time() - start_time, loss_value))

                    start_time = time.time()
                    val_videos, val_labels = sess.run(
                        [val_images_op, val_labels_op])
                    loss_value = sess.run(loss_op,
                                          feed_dict={
                                              images_placeholder: val_videos,
                                              labels_placeholder: val_labels,
                                              istraining_placeholder: True
                                          })
                    print("Step: {:4d} time: {:.4f}, validation loss: {:.8f}".
                          format(step,
                                 time.time() - start_time, loss_value))

                # Save a checkpoint and evaluate the model periodically.
                if step % FLAGS.save_step == 0 or (step +
                                                   1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            coord.request_stop()
            coord.join(threads)

    print("done")
Example #6
def run_testing():
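    """Evaluate FLAGS.model_name on the test TFRecords and report
    per-attribute average precision."""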
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            images_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, None, None,
                       FLAGS.nchannel))
            labels_placeholder = tf.placeholder(
                tf.float32,
                shape=(FLAGS.gpu_num * FLAGS.batch_size, FLAGS.num_classes))
            istraining_placeholder = tf.placeholder(tf.bool)
            logits_lst = []
            network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=FLAGS.num_classes,
                weight_decay=FLAGS.weight_decay,
                is_training=istraining_placeholder)
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            X = images_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size]
                            logits, _ = network_fn(X)
                            logits_lst.append(logits)
                            tf.get_variable_scope().reuse_variables()
            logits_op = tf.concat(logits_lst, 0)

            train_files = [
                os.path.join(FLAGS.train_images_files_dir, f)
                for f in os.listdir(FLAGS.train_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            val_files = [
                os.path.join(FLAGS.val_images_files_dir, f)
                for f in os.listdir(FLAGS.val_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            test_files = [
                os.path.join(FLAGS.test_images_files_dir, f)
                for f in os.listdir(FLAGS.test_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            print(
                '#############################Reading from files###############################'
            )
            print(train_files)
            print(val_files)

            images_op, labels_op = input_data.inputs(
                filenames=test_files,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                num_epochs=1,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                shuffle=True)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print(
                '----------------------------Trainable Variables-----------------------------------------'
            )
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
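            # Batch-norm moving means/variances live in tf.global_variables()
            # but not in tf.trainable_variables(), so they must be handed to
            # the Saver explicitly; otherwise the restored model would run
            # inference with uninitialized running statistics.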
            saver = tf.train.Saver(tf.trainable_variables() + bn_moving_vars)
            ckpt = tf.train.get_checkpoint_state(
                checkpoint_dir=FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Session restored from pretrained budget model at {}!'.
                      format(ckpt.model_checkpoint_path))
            else:
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT),
                                        FLAGS.checkpoint_dir)
            pred_probs_lst = []
            gt_lst = []
            try:
                while not coord.should_stop():
                    images, labels = sess.run([images_op, labels_op])
                    # write_video(videos, labels)
                    gt_lst.append(labels)
                    feed = {
                        images_placeholder: images,
                        labels_placeholder: labels,
                        istraining_placeholder: False  # inference mode: use the restored moving statistics
                    }

                    logits = sess.run(logits_op, feed_dict=feed)
                    pred_probs_lst.append(logits)
                    #print(logits)
                    # print(tf.argmax(softmax_logits, 1).eval(session=sess))
                    # print(logits.eval(feed_dict=feed, session=sess))
                    # print(labels)
            except tf.errors.OutOfRangeError:
                print('Done testing on all the examples')
            finally:
                coord.request_stop()

            pred_probs_mat = np.concatenate(pred_probs_lst, axis=0)
            gt_mat = np.concatenate(gt_lst, axis=0)
            n_examples, n_labels = gt_mat.shape
            print('# Examples = ', n_examples)
            print('# Labels = ', n_labels)
            print('Macro MAP = {:.2f}'.format(100 * average_precision_score(
                gt_mat, pred_probs_mat, average='macro')))
            cmap_stats = average_precision_score(gt_mat,
                                                 pred_probs_mat,
                                                 average=None)
            attr_id_to_name, attr_id_to_idx = load_attributes()
            idx_to_attr_id = {v: k for k, v in attr_id_to_idx.items()}
            with open('class_scores.txt', 'w') as wf:
                wf.write('\t'.join([
                    'attribute_id', 'attribute_name', 'num_occurrences', 'ap'
                ]) + '\n')
                for idx in range(n_labels):
                    attr_id = idx_to_attr_id[idx]
                    attr_name = attr_id_to_name[attr_id]
                    attr_occurrences = np.sum(gt_mat, axis=0)[idx]
                    ap = cmap_stats[idx]
                    wf.write('{}\t{}\t{}\t{}\n'.format(attr_id, attr_name,
                                                       attr_occurrences,
                                                       ap * 100.0))

            coord.join(threads)
            sess.close()

    print("done")
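
For reference, average_precision_score above is scikit-learn's; a tiny self-contained example of the macro and per-class variants (toy arrays, not model output):

import numpy as np
from sklearn.metrics import average_precision_score

# Multi-label ground truth (3 examples, 2 labels) and predicted scores.
gt = np.array([[1, 0], [0, 1], [1, 1]])
scores = np.array([[0.9, 0.2], [0.1, 0.8], [0.7, 0.6]])

print(average_precision_score(gt, scores, average='macro'))  # single number
print(average_precision_score(gt, scores, average=None))     # one AP per label
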
Example #7
def train(re_train=True):
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

    # Get images and labels for CIFAR-10.
    # images, labels = my_input.inputs()
    images, labels = input_data.distorted_inputs()
    val_images, val_labels = input_data.inputs(False)

    # Build a Graph that computes the logits predictions from the inference model.
    logits = my_cifar.inference(images_placeholder)

    # Calculate loss.
    loss = my_cifar.loss(logits, labels_placeholder)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = my_cifar.training(loss, global_step)

    # Calculate accuracy #
    acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

    # Create a saver.
    saver = tf.train.Saver()

    tf.scalar_summary('Acc', acc)
    # tf.scalar_summary('Val Acc', acc_val)
    tf.scalar_summary('Loss', loss)
    tf.image_summary('Images', tf.reshape(images, shape=[-1, 32, 32, 3]), max_images=10)
    tf.image_summary('Val Images', tf.reshape(val_images, shape=[-1, 32, 32, 3]), max_images=10)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    # NUM_CORES = 2  # Choose how many cores to use.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement, ))
    # inter_op_parallelism_threads=NUM_CORES,
    # intra_op_parallelism_threads=NUM_CORES))
    sess.run(init)

    # Write all terminal output results here
    val_f = open("tmp/val.txt", "ab")

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    if re_train:

      # Export graph to import it later in c++
      # tf.train.write_graph(sess.graph_def, FLAGS.model_dir, 'train.pbtxt') # TODO: uncomment to get graph and use in c++

      continue_from_pre = False

      if continue_from_pre:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
          print(ckpt.model_checkpoint_path)
          saver.restore(sess, ckpt.model_checkpoint_path)
          print('Session Restored!')

      try:
        while not coord.should_stop():

          for step in xrange(FLAGS.max_steps):

            images_r, labels_r = sess.run([images, labels])
            images_val_r, labels_val_r = sess.run([val_images, val_labels])
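            # Note: a validation batch is dequeued on every training step even
            # though it is only consumed every save_step steps; moving this
            # fetch under that condition would save input bandwidth.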

            train_feed = {images_placeholder: images_r,
                          labels_placeholder: labels_r}

            val_feed = {images_placeholder: images_val_r,
                        labels_placeholder: labels_val_r}

            start_time = time.time()

            _, loss_value = sess.run([train_op, loss], feed_dict=train_feed)
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % display_step == 0:
              num_examples_per_step = FLAGS.batch_size
              examples_per_sec = num_examples_per_step / duration
              sec_per_batch = float(duration)

              format_str = ('%s: step %d, loss = %.6f (%.1f examples/sec; %.3f '
                            'sec/batch)')
              print_str_loss = format_str % (datetime.now(), step, loss_value,
                                             examples_per_sec, sec_per_batch)
              print (print_str_loss)
              val_f.write(print_str_loss + NEW_LINE)
              summary_str = sess.run([summary_op], feed_dict=train_feed)
              summary_writer.add_summary(summary_str[0], step)

            if step % val_step == 0:
              acc_value, num_correct = sess.run([acc, n_correct], feed_dict=train_feed)

              format_str = '%s: step %d,  train acc = %.2f, n_correct= %d'
              print_str_train = format_str % (datetime.now(), step, acc_value, num_correct)
              val_f.write(print_str_train + NEW_LINE)
              print(print_str_train)

            # Save the model checkpoint periodically.
            if step % save_step == 0 or (step + 1) == FLAGS.max_steps:
              val_acc_r, val_n_correct_r = sess.run([acc, n_correct], feed_dict=val_feed)

              frmt_str = ' step %d, Val Acc = %.2f, num correct = %d'
              print_str_val = frmt_str % (step, val_acc_r, val_n_correct_r)
              val_f.write(print_str_val + NEW_LINE)
              print(print_str_val)

              checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
              saver.save(sess, checkpoint_path, global_step=step)


      except tf.errors.OutOfRangeError:
        print ('Done training -- epoch limit reached')

      finally:
        # When done, ask the threads to stop.
        val_f.write(NEW_LINE +
                    NEW_LINE +
                    '############################ FINISHED ############################' +
                    NEW_LINE)
        val_f.close()
        coord.request_stop()

      # Wait for threads to finish.
      coord.join(threads)
      sess.close()

    else:

      ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Restored!')

      for i in range(100):
        images_val_r, labels_val_r = sess.run([val_images, val_labels])
        val_feed = {images_placeholder: images_val_r,
                    labels_placeholder: labels_val_r}

        print('Calculating Acc: ')

        acc_r = sess.run(acc, feed_dict=val_feed)
        print(acc_r)

    coord.join(threads)
    sess.close()
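
The try/except/finally shape used in this example is the canonical TF 1.x queue-runner loop. A minimal runnable skeleton, using a toy two-element input queue in place of a real pipeline:

import tensorflow as tf

# Yields each row once (num_epochs=1), then raises OutOfRangeError.
queue = tf.train.input_producer(tf.constant([[1.0], [2.0]]), num_epochs=1)
batch_op = queue.dequeue()

with tf.Session() as sess:
    # num_epochs is tracked in a local variable, so initialize those too.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(batch_op))
    except tf.errors.OutOfRangeError:
        print('Done -- epoch limit reached')
    finally:
        coord.request_stop()   # ask the input threads to shut down
    coord.join(threads)        # wait until they actually have
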
Example #8
def train(continue_from_pre=True):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        images_placeholder, labels_placeholder, dropout_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        # Get images and labels by shuffled batch
        tr_images, tr_labels = input_data.inputs(
            filename='./tmp/tfrecords/train.tfrecords',
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=5,
            imshape=[128, 128, 3],
            use_distortion=False)
        val_images, val_labels = input_data.inputs(
            filename='./tmp/tfrecords/validation.tfrecords',
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=5,
            imshape=[128, 128, 3],
            use_distortion=False)
        # Build a Graph that computes the logits predictions from the inference model.
        logits = convnet_model.inference(images_placeholder,
                                         dropout_placeholder)
        # Calculate loss.
        loss = convnet_model.loss(logits, labels_placeholder)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = convnet_model.training(loss, global_step)

        # Calculate accuracy #
        acc, n_correct = convnet_model.evaluation(logits, labels_placeholder)

        # Create a saver.
        saver = tf.train.Saver()

        tf.scalar_summary('Training Acc', acc)
        tf.scalar_summary('Training Loss', loss)
        tf.image_summary('Training Images',
                         tf.reshape(tr_images, shape=[-1, 128, 128, 3]),
                         max_images=20)
        tf.image_summary('Validation Images',
                         tf.reshape(val_images, shape=[-1, 128, 128, 3]),
                         max_images=20)
        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()
        # Start running operations on the Graph.
        # NUM_CORES = 2  # Choose how many cores to use.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement, ))
        # inter_op_parallelism_threads=NUM_CORES,
        # intra_op_parallelism_threads=NUM_CORES))
        sess.run(init)

        # Write all terminal output results here
        output_f = open("tmp/output.txt", "ab")
        tr_f = open("tmp/training_accuracy.txt", "ab")
        val_f = open("tmp/validation_accuracy.txt", "ab")
        # Start the queue runners.

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph_def=sess.graph_def)

        if continue_from_pre:
            ckpt = tf.train.get_checkpoint_state(
                checkpoint_dir=FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Session Restored!')
        try:
            while not coord.should_stop():
                for step in xrange(FLAGS.max_steps_train):
                    tr_images_r, tr_labels_r = sess.run([tr_images, tr_labels])
                    print(tr_images_r.shape)
                    print(tr_labels_r.shape)
                    val_images_r, val_labels_r = sess.run(
                        [val_images, val_labels])

                    # Feed operation for training and validation
                    tr_feed = {
                        images_placeholder: tr_images_r,
                        labels_placeholder: tr_labels_r,
                        dropout_placeholder: 0.5
                    }
                    val_feed = {
                        images_placeholder: val_images_r,
                        labels_placeholder: val_labels_r,
                        dropout_placeholder: 1.0
                    }
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=tr_feed)
                    duration = time.time() - start_time
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'
                    if step % display_step == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = float(duration)
                        format_str = (
                            '%s: step %d, loss = %.6f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        print_str_loss = format_str % (datetime.now(
                        ), step, loss_value, examples_per_sec, sec_per_batch)
                        print(print_str_loss)
                        output_f.write(print_str_loss + '\n')
                        summary_str = sess.run([summary_op], feed_dict=tr_feed)
                        summary_writer.add_summary(summary_str[0], step)

                    if step % validation_step == 0:
                        tr_acc, tr_n_correct = sess.run([acc, n_correct],
                                                        feed_dict=tr_feed)
                        format_str = '%s: step %d,  training accuracy = %.2f, n_correct= %d'
                        print_str = format_str % (datetime.now(), step, tr_acc,
                                                  tr_n_correct)
                        output_f.write(print_str + '\n')
                        tr_f.write(str(tr_acc) + '\n')
                        print(print_str)
                        val_acc, val_n_correct = sess.run([acc, n_correct],
                                                          feed_dict=val_feed)
                        format_str = '%s: step %d,  validation accuracy = %.2f, n_correct= %d'
                        print_str = format_str % (datetime.now(), step,
                                                  val_acc, val_n_correct)
                        output_f.write(print_str + '\n')
                        val_f.write(str(val_acc) + '\n')
                        print(print_str)

                    if step % save_step == 0 or (step +
                                                 1) == FLAGS.max_steps_train:
                        checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            output_f.write(
                '********************Finish Training********************')
            tr_f.write(
                '********************Finish Training********************')
            val_f.write(
                '********************Finish Training********************')
            coord.request_stop()

        coord.join(threads)
        sess.close()
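
The dropout_placeholder trick in this example (0.5 while training, 1.0 for validation) is the usual TF 1.x way to toggle dropout without rebuilding the graph. In miniature:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 4))
keep_prob = tf.placeholder(tf.float32)   # 0.5 when training, 1.0 when evaluating
h = tf.nn.dropout(x, keep_prob)          # scales kept units by 1/keep_prob

with tf.Session() as sess:
    batch = np.ones((2, 4), dtype=np.float32)
    print(sess.run(h, {x: batch, keep_prob: 0.5}))  # roughly half the units zeroed
    print(sess.run(h, {x: batch, keep_prob: 1.0}))  # identity at eval time
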
Example #9
def train():
	with tf.Graph ().as_default ():
		phase_train = tf.placeholder (tf.bool, name='phase_train')
		global_step = tf.Variable (0, trainable=False, name='global_step')

		# Inputs
		train_image_batch, train_label_batch = input_data.distorted_inputs ()
		val_image_batch, val_label_batch = input_data.inputs (True)
		image_batch, label_batch = control_flow_ops.cond (phase_train,
			lambda: (train_image_batch, train_label_batch),
			lambda: (val_image_batch, val_label_batch))
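		# Caveat: tf.cond only guards ops created *inside* its branch lambdas.
		# Both batch tensors here are created outside it, so the training and
		# validation queues are dequeued on every step regardless of
		# phase_train; the condition only selects which pair is returned.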

		# Model
		logits = m.inference (image_batch, phase_train)

		# Loss
		loss, cross_entropy_mean = m.loss (logits, label_batch)

		# Training
		train_op = m.train(loss, global_step)

		# Saver
		saver = tf.train.Saver(tf.all_variables())

		# Session
		sess = tf.Session (config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))

		# Summary
		summary_op = tf.merge_all_summaries()
		summary_writer = tf.train.SummaryWriter (FLAGS.train_dir, graph=sess.graph)

		# Init
		init_op = tf.initialize_all_variables()
		print ('Initializing...')
		sess.run (init_op, {phase_train.name: True})

		# Start the queue runners
		tf.train.start_queue_runners (sess=sess)

		# Training loop
		print ('Training...')

		for step in xrange(FLAGS.max_steps):
			fetches = [train_op, loss, cross_entropy_mean]
			if step > 0 and step % 100 == 0:
				fetches += [summary_op]

			start_time = time.time ()
			sess_outputs = sess.run (fetches, {phase_train.name: True})
			duration = time.time () - start_time

			loss_value, cross_entropy_value = sess_outputs[1:3]

			if step % 10 == 0:
				num_examples_per_step = FLAGS.batch_size
				examples_per_sec = num_examples_per_step / duration
				sec_per_batch = float(duration)

				format_str = ('%s: step %d, loss = %.2f (%.4f) (%.1f examples/sec; %.3f sec/batch)')
				print (format_str % (datetime.now(), step, loss_value, cross_entropy_value, examples_per_sec, sec_per_batch))
			
			# Summary
			if step > 0 and step % 100 == 0:
				summary_str = sess_outputs[3]
				summary_writer.add_summary (summary_str, step)

			# Validation
			if step > 0 and step % 1000 == 0:
				n_val_samples = 10000
				val_batch_size = FLAGS.batch_size
				n_val_batch = int (n_val_samples / val_batch_size)
				val_logits = np.zeros ((n_val_samples, 2), dtype=np.float32)
				val_labels = np.zeros ((n_val_samples), dtype=np.int64)
				val_losses = []

				for i in xrange (n_val_batch):
					session_outputs = sess.run ([logits, label_batch, loss], {phase_train.name: False})
					val_logits[i*val_batch_size:(i+1)*val_batch_size, :] = session_outputs[0]
					val_labels[i*val_batch_size:(i+1)*val_batch_size] = session_outputs[1]
					val_losses.append (session_outputs[2])

				pred_labels = np.argmax (val_logits, axis=1)
				val_accuracy = np.count_nonzero (pred_labels == val_labels) / float (n_val_batch * val_batch_size)  # float division; int/int truncates to 0 on Python 2
				val_loss = float (np.mean (np.asarray (val_losses)))
				print ('Validation accuracy = %f' % val_accuracy)
				print ('Validation loss = %f' % val_loss)
				val_summary = tf.Summary ()
				val_summary.value.add (tag='val_accuracy', simple_value=val_accuracy)
				val_summary.value.add (tag='val_loss', simple_value=val_loss)
				summary_writer.add_summary (val_summary, step)


			# Save variables
			if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
				checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
				saver.save(sess, checkpoint_path, global_step=step)
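
One detail worth flagging in the validation block above: computing accuracy as a count divided by a product invites Python 2 integer division. A mean over the boolean match vector sidesteps that entirely:

import numpy as np

val_logits = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # toy scores
val_labels = np.array([0, 1, 1])

pred_labels = np.argmax(val_logits, axis=1)
val_accuracy = np.mean(pred_labels == val_labels)  # a float on Python 2 and 3
print(val_accuracy)
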
Example #10
def run_training():
    # Get the sets of images and labels for training and validation, and
    # tell TensorFlow that the model will be built into the default Graph.

    # Create model directory
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    use_pretrained_model = True

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            videos_placeholder, labels_placeholder, dropout_placeholder = placeholder_inputs(
            )
            tower_grads1 = []
            tower_grads2 = []
            logits = []
            losses = []
            opt_stable = tf.train.AdamOptimizer(1e-4)
            opt_finetuning = tf.train.AdamOptimizer(1e-5)
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            logit = c3d_model.inference_c3d(
                                videos_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size],
                                dropout_placeholder, FLAGS.batch_size)
                            loss = tower_loss_xentropy(
                                scope, logit,
                                labels_placeholder[gpu_index *
                                                   FLAGS.batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.batch_size])
                            losses.append(loss)
                            varlist1 = [
                                v for v in tf.trainable_variables()
                                if not any(x in v.name for x in ["out", "d2"])
                            ]
                            varlist2 = [
                                v for v in tf.trainable_variables()
                                if any(x in v.name for x in ["out", "d2"])
                            ]

                            print(
                                '######################varlist1######################'
                            )
                            print([v.name for v in varlist1])
                            print(
                                '######################varlist2######################'
                            )
                            print([v.name for v in varlist2])
                            #grads1 = opt_stable.compute_gradients(loss, varlist1)
                            grads2 = opt_finetuning.compute_gradients(
                                loss, varlist2)
                            #tower_grads1.append(grads1)
                            tower_grads2.append(grads2)
                            logits.append(logit)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()
            logits = tf.concat(logits, 0)
            loss_op = tf.reduce_mean(losses, name='softmax')
            accuracy = c3d_model.accuracy(logits, labels_placeholder)
            tf.summary.scalar('accuracy', accuracy)

            #grads1 = average_gradients(tower_grads1)
            grads2 = average_gradients(tower_grads2)

            #apply_gradient_op1 = opt_stable.apply_gradients(grads1, global_step=global_step)
            apply_gradient_op2 = opt_finetuning.apply_gradients(
                grads2, global_step=global_step)
            #train_op = tf.group(apply_gradient_op1, apply_gradient_op2, variables_averages_op)
            train_op = tf.group(apply_gradient_op2)
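            # Two optimizers allow per-group learning rates when fine-tuning;
            # here the opt_stable/varlist1 path is commented out, so only the
            # "out"/"d2" variables (varlist2) are updated, at 1e-5, and the
            # remaining variables stay frozen at their pretrained values.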

            train_files = [
                os.path.join(FLAGS.training_data, f)
                for f in os.listdir(FLAGS.training_data)
                if f.endswith('.tfrecords')
            ]
            val_files = [
                os.path.join(FLAGS.validation_data, f)
                for f in os.listdir(FLAGS.validation_data)
                if f.endswith('.tfrecords')
            ]
            print(train_files)
            print(val_files)
            tr_videos_op, tr_labels_op = input_data.inputs(
                filenames=train_files,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch)
            val_videos_op, val_labels_op = input_data.inputs(
                filenames=val_files,
                batch_size=FLAGS.batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch)
            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Model restoring.
            if use_pretrained_model:
                if os.path.isfile(FLAGS.pretrained_model):
                    varlist = [
                        v for v in tf.trainable_variables()
                        if not any(x in v.name.split('/')[1]
                                   for x in ["out", "d2"])
                    ]
                    vardict = {
                        v.name[:-2].replace('C3DNet', 'var_name'): v
                        for v in varlist
                    }
                    for key, value in vardict.items():
                        print(key)
                    saver = tf.train.Saver(vardict)
                    saver.restore(sess, FLAGS.pretrained_model)
                    print(
                        'Session restored from pretrained model at {}!'.format(
                            FLAGS.pretrained_model))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            FLAGS.pretrained_model)
            else:
                saver = tf.train.Saver()
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Session restored from trained model at {}!'.format(
                        ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            FLAGS.checkpoint_dir)

            # Create summary writer
            merge_op = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + 'train',
                                                 sess.graph)
            test_writer = tf.summary.FileWriter(FLAGS.log_dir + 'test',
                                                sess.graph)
            saver = tf.train.Saver(tf.trainable_variables())
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                tr_videos, tr_labels = sess.run([tr_videos_op, tr_labels_op])
                _, loss_value = sess.run(
                    [train_op, loss_op],
                    feed_dict={
                        videos_placeholder: tr_videos,
                        labels_placeholder: tr_labels,
                        dropout_placeholder: 0.5
                    })
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'
                duration = time.time() - start_time
                print('Step: {:4d} time: {:.4f} loss: {:.8f}'.format(
                    step, duration, loss_value))
                if step % FLAGS.val_step == 0:
                    start_time = time.time()
                    tr_videos, tr_labels = sess.run(
                        [tr_videos_op, tr_labels_op])
                    summary, acc, loss_value = sess.run(
                        [merge_op, accuracy, loss_op],
                        feed_dict={
                            videos_placeholder: tr_videos,
                            labels_placeholder: tr_labels,
                            dropout_placeholder: 1.0
                        })
                    print(
                        "Step: {:4d} time: {:.4f}, training accuracy: {:.5f}, loss: {:.8f}"
                        .format(step,
                                time.time() - start_time, acc, loss_value))
                    train_writer.add_summary(summary, step)

                    start_time = time.time()
                    val_videos, val_labels = sess.run(
                        [val_videos_op, val_labels_op])
                    summary, acc, loss_value = sess.run(
                        [merge_op, accuracy, loss_op],
                        feed_dict={
                            videos_placeholder: val_videos,
                            labels_placeholder: val_labels,
                            dropout_placeholder: 1.0
                        })
                    print(
                        "Step: {:4d} time: {:.4f}, validation accuracy: {:.5f}, loss: {:.8f}"
                        .format(step,
                                time.time() - start_time, acc, loss_value))
                    test_writer.add_summary(summary, step)
                # Save a checkpoint periodically.
                if step % FLAGS.save_step == 0 or (step +
                                                   1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            coord.request_stop()
            coord.join(threads)

    print("done")
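
average_gradients is called but not shown in this snippet. A typical implementation, essentially the helper from TensorFlow's classic multi-GPU CIFAR-10 tutorial, offered here as an assumption about what the missing function does:

import tensorflow as tf

def average_gradients(tower_grads):
    """tower_grads: list over GPUs of lists of (gradient, variable) pairs."""
    average_grads = []
    for grad_and_vars in zip(*tower_grads):  # iterate variable by variable
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)  # mean across towers
        # Variables are shared across towers, so the first tower's handle works.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
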
Example #11
def run_testing():
    videos_placeholder, labels_placeholder, dropout_placeholder = placeholder_inputs(
    )
    logits = []
    with tf.variable_scope(tf.get_variable_scope()) as scope:
        for gpu_index in range(0, FLAGS.gpu_num):
            with tf.device('/gpu:%d' % gpu_index):
                print('/gpu:%d' % gpu_index)
                with tf.name_scope('%s_%d' % ('gpu', gpu_index)) as scope:
                    logit = c3d_model.inference_c3d(
                        videos_placeholder[gpu_index *
                                           FLAGS.batch_size:(gpu_index + 1) *
                                           FLAGS.batch_size],
                        dropout_placeholder, FLAGS.batch_size)
                    logits.append(logit)
                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()
    logits = tf.concat(logits, 0)
    right_count = tf.reduce_sum(
        tf.cast(
            tf.equal(tf.argmax(tf.nn.softmax(logits), axis=1),
                     labels_placeholder), tf.int32))
    softmax_logits_op = tf.nn.softmax(logits)

    train_files = [
        os.path.join(FLAGS.training_data, f)
        for f in os.listdir(FLAGS.training_data) if f.endswith('.tfrecords')
    ]
    val_files = [
        os.path.join(FLAGS.validation_data, f)
        for f in os.listdir(FLAGS.validation_data) if f.endswith('.tfrecords')
    ]
    videos_op, labels_op = input_data.inputs(
        filenames=train_files,
        batch_size=FLAGS.batch_size * FLAGS.gpu_num,
        num_epochs=1,
        num_threads=FLAGS.num_threads,
        num_examples_per_epoch=FLAGS.num_examples_per_epoch,
        shuffle=False)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    init_op = tf.group(tf.local_variables_initializer(),
                       tf.global_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    saver = tf.train.Saver(tf.trainable_variables())
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Session Restored from {}!'.format(ckpt.model_checkpoint_path))
    total_v = 0.0
    test_correct_num = 0.0
    try:
        while not coord.should_stop():
            videos, labels = sess.run([videos_op, labels_op])
            #write_video(videos, labels)
            feed = {
                videos_placeholder: videos,
                labels_placeholder: labels,
                dropout_placeholder: 1.0
            }
            right, softmax_logits = sess.run([right_count, softmax_logits_op],
                                             feed_dict=feed)
            test_correct_num += right
            total_v += labels.shape[0]
            print(softmax_logits.shape)
            print(np.argmax(softmax_logits, 1))  # np.argmax: avoids adding new tf.argmax ops to the graph every iteration
            print(labels)
    except tf.errors.OutOfRangeError:
        print('Done testing on all the examples')
    finally:
        coord.request_stop()
    print('test acc:', test_correct_num / total_v, 'test_correct_num:',
          test_correct_num, 'total_v:', total_v)
    coord.join(threads)
    sess.close()