def nsfw_eval_dataset(dataset_dir, weights_path):
    """
    Evaluate the nsfw dataset
    :param dataset_dir: The nsfw dataset dir which contains tensorflow records file
    :param weights_path: The pretrained nsfw model weights file path
    :return:
    """
    assert ops.exists(dataset_dir)

    # set up the nsfw data feed pipeline
    test_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(
        dataset_dir=dataset_dir, flags='test')
    prediction_map = test_dataset.prediction_map
    class_names = ['drawing', 'hentai', 'neutral', 'porn', 'sexy']

    with tf.device('/gpu:1'):
        # set nsfw classification model
        phase = tf.constant('test', dtype=tf.string)

        # set nsfw net
        nsfw_net = nsfw_classification_net.NSFWNet(
            phase=phase, resnet_size=CFG.NET.RESNET_SIZE)

        # fetch the test data
        images, labels = test_dataset.inputs(batch_size=CFG.TEST.BATCH_SIZE,
                                             num_epochs=1)

        logits = nsfw_net.inference(input_tensor=images,
                                    name='nsfw_cls_model',
                                    reuse=False)

        predictions = tf.nn.softmax(logits)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()

    # set tensorflow saver
    saver = tf.train.Saver(variables_to_restore)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    # ground-truth labels over the whole test dataset
    labels_total = []
    # predicted labels over the whole test dataset
    predictions_total = []
    # per-class prediction scores over the whole test dataset
    predictions_prob_total = []

    with sess.as_default():

        saver.restore(sess=sess, save_path=weights_path)

        while True:
            try:
                predictions_vals, labels_vals = sess.run(
                    fetches=[predictions, labels])

                log.info('**************')
                log.info('Test dataset batch size: {:d}'.format(
                    predictions_vals.shape[0]))
                log.info('---- Sample Id ---- Gt label ---- Prediction ----')

                for index, predictions_val in enumerate(predictions_vals):

                    label_gt = prediction_map[labels_vals[index]]

                    prediction_score = dict()

                    for score_index, score in enumerate(predictions_val):
                        prediction_score[prediction_map[score_index]] = format(
                            score, '.5f')

                    log.info('---- {:d} ---- {:s} ---- {}'.format(
                        index, label_gt, prediction_score))

                    # record the per-class prediction probabilities
                    predictions_prob_total.append(predictions_val.tolist())

                # record total label and prediction results
                labels_total.extend(labels_vals.tolist())
                predictions_total.extend(
                    np.argmax(predictions_vals, axis=1).tolist())

            except tf.errors.OutOfRangeError:
                log.info('Finished looping over the test dataset')
                break
            except Exception as err:
                log.error(err)
                break

    # print prediction report
    print('Nsfw classification report (left column: ground-truth labels):')
    print(classification_report(labels_total, predictions_total))

    # calculate confusion matrix
    cnf_matrix = confusion_matrix(labels_total, predictions_total)
    np.set_printoptions(precision=2)
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')

    # calculate evaluation statistics
    calculate_evaluate_statics(labels=labels_total,
                               predictions=predictions_total)

    # plot precision recall curve
    plot_precision_recall_curve(labels=labels_total,
                                predictions_prob=predictions_prob_total,
                                class_nums=5)
    plt.show()

    return
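
The plot_confusion_matrix helper called above is not part of this snippet. A minimal sketch in the spirit of the classic scikit-learn documentation recipe, matching the call signature used here, could look like this (the helper in the original repository may differ):

import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix; optionally row-normalize it first."""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # annotate each cell with its value, switching text color for contrast
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
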
def train_net(dataset_dir, weights_path=None):
    """

    :param dataset_dir:
    :param weights_path:
    :return:
    """

    # set up the nsfw data feed pipeline
    train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    with tf.device('/gpu:1'):
        # set nsfw net
        nsfw_net = nsfw_classification_net.NSFWNet(
            phase=tf.constant('train', dtype=tf.string),
            resnet_size=CFG.NET.RESNET_SIZE)
        nsfw_net_val = nsfw_classification_net.NSFWNet(
            phase=tf.constant('test', dtype=tf.string),
            resnet_size=CFG.NET.RESNET_SIZE)

        # compute train loss
        train_images, train_labels = train_dataset.inputs(
            batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)
        train_loss = nsfw_net.compute_loss(input_tensor=train_images,
                                           labels=train_labels,
                                           name='nsfw_cls_model',
                                           reuse=False)

        train_logits = nsfw_net.inference(input_tensor=train_images,
                                          name='nsfw_cls_model',
                                          reuse=True)

        train_predictions = tf.nn.softmax(train_logits)
        train_top1_error = calculate_top_k_error(train_predictions,
                                                 train_labels, 1)

        # compute val loss
        val_images, val_labels = val_dataset.inputs(
            batch_size=CFG.TRAIN.VAL_BATCH_SIZE, num_epochs=1)
        # val_images = tf.reshape(val_images, example_tensor_shape)
        val_loss = nsfw_net_val.compute_loss(input_tensor=val_images,
                                             labels=val_labels,
                                             name='nsfw_cls_model',
                                             reuse=True)

        val_logits = nsfw_net_val.inference(input_tensor=val_images,
                                            name='nsfw_cls_model',
                                            reuse=True)

        val_predictions = tf.nn.softmax(val_logits)
        val_top1_error = calculate_top_k_error(val_predictions, val_labels, 1)

    # set tensorflow summary
    tboard_save_path = 'tboard/nsfw_cls'
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    train_loss_scalar = tf.summary.scalar(name='train_loss', tensor=train_loss)
    train_top1_err_scalar = tf.summary.scalar(name='train_top1_error',
                                              tensor=train_top1_error)
    val_loss_scalar = tf.summary.scalar(name='val_loss', tensor=val_loss)
    val_top1_err_scalar = tf.summary.scalar(name='val_top1_error',
                                            tensor=val_top1_error)

    train_merge_summary_op = tf.summary.merge(
        [train_loss_scalar, train_top1_err_scalar])

    val_merge_summary_op = tf.summary.merge(
        [val_loss_scalar, val_top1_err_scalar])

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/nsfw_cls'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set optimizer
    with tf.device('/gpu:1'):
        # set learning rate
        global_step = tf.Variable(0, trainable=False)
        decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2]
        decay_values = []
        init_lr = CFG.TRAIN.LEARNING_RATE
        # build the step-wise decayed learning rate values:
        # [lr, lr * rate, lr * rate ** 2, ...]
        for _ in range(len(decay_steps) + 1):
            decay_values.append(init_lr)
            init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE

        learning_rate = tf.train.piecewise_constant(x=global_step,
                                                    boundaries=decay_steps,
                                                    values=decay_values,
                                                    name='learning_rate')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=train_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/nsfw_cls_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        val_cost_time_mean = []

        for epoch in range(train_epochs):

            # training part
            t_start = time.time()

            _, train_loss_value, train_top1_err_value, train_summary, lr = \
                sess.run(fetches=[optimizer,
                                  train_loss,
                                  train_top1_error,
                                  train_merge_summary_op,
                                  learning_rate])

            if math.isnan(train_loss_value):
                log.error('Train loss is nan')
                return

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)

            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            t_start_val = time.time()

            val_loss_value, val_top1_err_value, val_summary = \
                sess.run(fetches=[val_loss,
                                  val_top1_error,
                                  val_merge_summary_op])

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} '
                    'lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch + 1, train_loss_value, train_top1_err_value, lr,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}'
                         ' mean_cost_time= {:5f}s '.format(
                             epoch + 1, val_loss_value, val_top1_err_value,
                             np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
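
Both training loops report a top-1 error computed by calculate_top_k_error, which is not shown in this snippet. A plausible minimal sketch using tf.nn.in_top_k (a hypothetical reconstruction, not necessarily the repository's exact helper):

import tensorflow as tf


def calculate_top_k_error(predictions, labels, k=1):
    """Top-k error: fraction of samples whose ground-truth label is NOT
    among the k highest-scoring classes."""
    batch_size = tf.shape(predictions)[0]
    # in_top_k yields a boolean per sample; cast and count the hits
    in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
    num_correct = tf.reduce_sum(in_top_k)
    return (tf.to_float(batch_size) - num_correct) / tf.to_float(batch_size)
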
def train_net_multi_gpu(dataset_dir, weights_path=None):
    """

    :param dataset_dir:
    :param weights_path:
    :return:
    """
    # set up the nsfw data feed pipeline
    train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    # set nsfw net
    nsfw_net = nsfw_classification_net.NSFWNet(
        phase=tf.constant('train', dtype=tf.string),
        resnet_size=CFG.NET.RESNET_SIZE)
    nsfw_net_val = nsfw_classification_net.NSFWNet(
        phase=tf.constant('test', dtype=tf.string),
        resnet_size=CFG.NET.RESNET_SIZE)

    # fetch train and validation data
    train_images, train_labels = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)
    val_images, val_labels = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)

    # set average container
    tower_grads = []
    train_tower_loss = []
    train_tower_top1_error = []
    val_tower_loss = []
    val_tower_top1_error = []
    batchnorm_updates = None
    train_summary_op_updates = None

    # set learning rate
    global_step = tf.Variable(0, trainable=False)
    decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2]
    decay_values = []
    init_lr = CFG.TRAIN.LEARNING_RATE
    # build the step-wise decayed learning rate values:
    # [lr, lr * rate, lr * rate ** 2, ...]
    for _ in range(len(decay_steps) + 1):
        decay_values.append(init_lr)
        init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE

    learning_rate = tf.train.piecewise_constant(x=global_step,
                                                boundaries=decay_steps,
                                                values=decay_values,
                                                name='learning_rate')

    # set optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9)

    # set distributed train op
    with tf.variable_scope(tf.get_variable_scope()):
        is_network_initialized = False
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:{:d}'.format(i)):
                with tf.name_scope('tower_{:d}'.format(i)) as scope:
                    train_loss, train_top1_error, grads = compute_net_gradients(
                        train_images,
                        train_labels,
                        nsfw_net,
                        optimizer,
                        is_net_first_initialized=is_network_initialized)

                    is_network_initialized = True

                    # Only use the mean and variance from the first GPU tower
                    # to update the batch norm parameters
                    # TODO implement batch normalization for distributed device ([email protected])
                    if i == 0:
                        batchnorm_updates = tf.get_collection(
                            tf.GraphKeys.UPDATE_OPS)
                        train_summary_op_updates = tf.get_collection(
                            tf.GraphKeys.SUMMARIES)

                    tower_grads.append(grads)
                    train_tower_loss.append(train_loss)
                    train_tower_top1_error.append(train_top1_error)
                with tf.name_scope('validation_{:d}'.format(i)) as scope:
                    val_loss, val_top1_error, _ = compute_net_gradients(
                        val_images,
                        val_labels,
                        nsfw_net_val,
                        optimizer,
                        is_net_first_initialized=is_network_initialized)
                    val_tower_loss.append(val_loss)
                    val_tower_top1_error.append(val_top1_error)

    grads = average_gradients(tower_grads)
    avg_train_loss = tf.reduce_mean(train_tower_loss)
    avg_train_top1_error = tf.reduce_mean(train_tower_top1_error)
    avg_val_loss = tf.reduce_mean(val_tower_loss)
    avg_val_top1_error = tf.reduce_mean(val_tower_top1_error)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all the ops needed for training
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # set tensorflow summary
    tboard_save_path = 'tboard/nsfw_cls'
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss',
                                              tensor=avg_train_loss)
    avg_train_top1_err_scalar = tf.summary.scalar(
        name='average_train_top1_error', tensor=avg_train_top1_error)
    avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss',
                                            tensor=avg_val_loss)
    avg_val_top1_err_scalar = tf.summary.scalar(name='average_val_top1_error',
                                                tensor=avg_val_top1_error)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar',
                                             tensor=learning_rate)

    train_merge_summary_op = tf.summary.merge([
        avg_train_loss_scalar, avg_train_top1_err_scalar, learning_rate_scalar
    ] + train_summary_op_updates)

    val_merge_summary_op = tf.summary.merge(
        [avg_val_loss_scalar, avg_val_top1_err_scalar])

    # set tensorflow saver
    saver = tf.train.Saver()
    model_save_dir = 'model/nsfw_cls'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set sess config
    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM},
                                 allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/nsfw_cls_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        val_cost_time_mean = []

        for epoch in range(train_epochs):

            # training part
            t_start = time.time()

            _, train_loss_value, train_top1_err_value, train_summary, lr = \
                sess.run(fetches=[train_op,
                                  avg_train_loss,
                                  avg_train_top1_error,
                                  train_merge_summary_op,
                                  learning_rate])

            if math.isnan(train_loss_value):
                log.error('Train loss is nan')
                return

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)

            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            t_start_val = time.time()

            val_loss_value, val_top1_err_value, val_summary = \
                sess.run(fetches=[avg_val_loss,
                                  avg_val_top1_error,
                                  val_merge_summary_op])

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} '
                    'lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch + 1, train_loss_value, train_top1_err_value, lr,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}'
                         ' mean_cost_time= {:5f}s '.format(
                             epoch + 1, val_loss_value, val_top1_err_value,
                             np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
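
The multi-GPU path relies on two helpers that are not shown: average_gradients and compute_net_gradients. The first sketch below is the standard tower-averaging pattern from the TensorFlow multi-GPU CIFAR-10 tutorial; the second is a hypothetical reconstruction that merely matches the call sites above:

import tensorflow as tf


def average_gradients(tower_grads):
    """Average gradients across towers.

    tower_grads: list (one entry per tower) of lists of
    (gradient, variable) tuples from optimizer.compute_gradients.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # variables are shared across towers, so keep the first tower's
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads


def compute_net_gradients(images, labels, net, optimizer=None,
                          is_net_first_initialized=False):
    """Build loss, top-1 error and (optionally) gradients for one tower."""
    loss = net.compute_loss(input_tensor=images,
                            labels=labels,
                            name='nsfw_cls_model',
                            reuse=is_net_first_initialized)
    logits = net.inference(input_tensor=images,
                           name='nsfw_cls_model',
                           reuse=True)
    top1_error = calculate_top_k_error(tf.nn.softmax(logits), labels, 1)
    grads = optimizer.compute_gradients(loss) if optimizer is not None else None
    return loss, top1_error, grads
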
Example #4
def nsfw_eval_dataset(dataset_dir, weights_path, top_k=1):
    """
    Evaluate the nsfw dataset
    :param dataset_dir: The nsfw dataset dir which contains tensorflow records file
    :param weights_path: The pretrained nsfw model weights file path
    :param top_k: calculate the top k accuracy
    :return:
    """
    # set up the nsfw data feed pipeline
    test_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                         flags='test')
    prediction_map = test_dataset.prediction_map

    with tf.device('/gpu:1'):
        # set nsfw classification model
        phase = tf.constant('test', dtype=tf.string)

        # set nsfw net
        nsfw_net = nsfw_classification_net.NSFWNet(phase=phase)

        # fetch the test data
        images, labels = test_dataset.inputs(batch_size=CFG.TEST.BATCH_SIZE,
                                             num_epochs=1)

        # rescale the images (fetched below for batch-size logging and the
        # optional visualization)
        images_scale = tf.map_fn(fn=scale_image, elems=images, dtype=tf.float32)

        logits = nsfw_net.inference(input_tensor=images,
                                    residual_blocks_nums=CFG.NET.RES_BLOCKS_NUMS,
                                    name='nsfw_cls_model',
                                    reuse=False)

        predictions = tf.nn.softmax(logits)
        top_k_error = calculate_top_k_error(predictions, labels, top_k)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()

    # set tensorflow saver
    saver = tf.train.Saver(variables_to_restore)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    avg_prediction_top_k_accuracy = []

    with sess.as_default():

        saver.restore(sess=sess, save_path=weights_path)

        while True:
            try:
                images_vals, predictions_vals, labels_vals, top_k_error_val = sess.run(
                    fetches=[images_scale,
                             predictions,
                             labels,
                             top_k_error])

                log.info('\n**************')
                log.info('Test dataset batch size: {:d}'.format(images_vals.shape[0]))
                log.info('---- Sample Id ---- Gt label ---- Prediction ----')

                for index, image in enumerate(images_vals):

                    label_gt = prediction_map[labels_vals[index]]

                    prediction_score = dict()

                    for score_index, score in enumerate(predictions_vals[index]):
                        prediction_score[prediction_map[score_index]] = format(score, '.5f')

                    log.info('---- {:d} ---- {:s} ---- {}'.format(index, label_gt, prediction_score))

                    # plt.ion()
                    # plt.figure('source image')
                    # plt.imshow(np.array(image[:, :, (2, 1, 0)], np.uint8))
                    # plt.pause(5.0)
                    # plt.show()
                    # plt.ioff()

                log.info('Top {:d} accuracy of this batch is: {:.5f}'.format(
                    top_k, 1 - top_k_error_val))
                avg_prediction_top_k_accuracy.append(1 - top_k_error_val)

            except tf.errors.OutOfRangeError:
                log.info('Total avg top {:d} accuracy is: {:.5f}'.format(
                    top_k, np.mean(avg_prediction_top_k_accuracy)))
                break
    return
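
A hypothetical command-line entry point for the evaluation function above, assuming the usual argparse pattern (flag names are illustrative, not taken from the original repository):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='Directory containing the nsfw TFRecord files')
    parser.add_argument('--weights_path', type=str, required=True,
                        help='Path to the pretrained model checkpoint')
    parser.add_argument('--top_k', type=int, default=1,
                        help='Report top-k accuracy during evaluation')
    args = parser.parse_args()

    nsfw_eval_dataset(args.dataset_dir, args.weights_path, top_k=args.top_k)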