Example #1
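A minimal sketch of the imports this example relies on; the Data helper class and the hyperparameter constants referenced below are assumed to be defined elsewhere in the module.

import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf  # assumes the TensorFlow 1.x graph API
from tqdm import trange
# Assumed module-level definitions (not shown here): Data, TRAIN_SIZE, VALIDATION_SIZE,
# IMAGE_SIZE, N_CLASSES, LEARNING_RATE, EPOCHS, BATCH_SIZE, DROP_OUT, LOG_DIR
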
def train():
    data = Data()
    data.read_data(filepath='data/train.csv',
                   train_size=TRAIN_SIZE,
                   validation_size=VALIDATION_SIZE,
                   convert_to_one_hot=True)
    #data.train.display_digit()
    sess = tf.InteractiveSession()

    def variable_summaries(var):
        """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
        with tf.name_scope('summaries'):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar('stddev', stddev)
            tf.summary.scalar('max', tf.reduce_max(var))
            tf.summary.scalar('min', tf.reduce_min(var))
            tf.summary.histogram('histogram', var)

    with tf.name_scope('input'):
        input_layer = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE])
        output_layer = tf.placeholder(tf.float32, shape=[None, N_CLASSES])

    with tf.name_scope('reshape_input'):
        image_shaped_input = tf.reshape(input_layer, [-1, 28, 28, 1])
        tf.summary.image('input', image_shaped_input)

    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def convolution_2d(input_tensor,
                       input_dimension,
                       nb_filter,
                       filter_size,
                       name,
                       activation=tf.nn.relu):
        with tf.name_scope(name):
            with tf.name_scope('weights'):
                weights = weight_variable(
                    [filter_size, filter_size, input_dimension, nb_filter])
                variable_summaries(weights)
            with tf.name_scope('biases'):
                biases = bias_variable([nb_filter])
                variable_summaries(biases)
            with tf.name_scope('preactivation'):
                # strided convolution (see conv2d below) plus bias, before the nonlinearity
                preactivate = conv2d(input_tensor, weights) + biases
                tf.summary.histogram('pre-activation', preactivate)
            activations = activation(preactivate, name='activation')
            tf.summary.histogram('activations', activations)
            return activations

    def conv2d(input_tensor, weights):
        # stride 2 in both spatial dimensions, so every convolutional layer also
        # halves the height and width of its input
        return tf.nn.conv2d(input_tensor,
                            weights,
                            strides=[1, 2, 2, 1],
                            padding='SAME')

    def max_pool_2d(input_tensor, kernel_size, name):
        with tf.name_scope(name):
            return tf.nn.max_pool(
                input_tensor,
                ksize=[1, kernel_size, kernel_size, 1],  # pooling window size
                strides=[1, 2, 2, 1],
                padding='SAME')

    def fully_connected(input_tensor, image_size, nb_filter, n_units, name,
                        activation):
        with tf.name_scope(name):
            with tf.name_scope('weights'):
                weights = weight_variable(
                    [image_size * image_size * nb_filter, n_units])
                variable_summaries(weights)
            with tf.name_scope('biases'):
                biases = bias_variable([n_units])
                variable_summaries(biases)
            with tf.name_scope('preactivation'):
                input_tensor_flat = tf.reshape(
                    input_tensor, [-1, image_size * image_size * nb_filter])
                # affine transform on the flattened input (same pattern as the conv layers)
                preactivate = tf.matmul(input_tensor_flat, weights) + biases
                tf.summary.histogram('pre-activation', preactivate)
            if activation == 'NONE':
                return preactivate
            else:
                activations = activation(preactivate, name='activation')
                tf.summary.histogram('activations', activations)
                return activations

    with tf.name_scope('neural_network_architecture'):
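        # Four 3x3 convolutions with stride 2 (each halves the spatial size) and two
        # 2x2 max-pools reduce the 28x28 input to a 1x1x128 tensor before the fully
        # connected layers, which is why image_size=1 is passed to fully_connected.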
        conv_1 = convolution_2d(image_shaped_input,
                                1,
                                nb_filter=16,
                                filter_size=3,
                                activation=tf.nn.relu,
                                name='convolutional_layer_1')
        conv_2 = convolution_2d(conv_1,
                                16,
                                nb_filter=32,
                                filter_size=3,
                                activation=tf.nn.relu,
                                name='convolutional_layer_2')
        pool_1 = max_pool_2d(conv_2, kernel_size=2, name='pool_layer_1')
        conv_3 = convolution_2d(pool_1,
                                32,
                                nb_filter=64,
                                filter_size=3,
                                activation=tf.nn.relu,
                                name='convolutional_layer_3')
        conv_4 = convolution_2d(conv_3,
                                64,
                                nb_filter=128,
                                filter_size=3,
                                activation=tf.nn.relu,
                                name='convolutional_layer_4')
        pool_2 = max_pool_2d(conv_4, kernel_size=2, name='pool_layer_2')
        fc_1 = fully_connected(pool_2,
                               1,
                               nb_filter=128,
                               n_units=2048,
                               activation=tf.nn.relu,
                               name='fully_connected_1')
        fc_2 = fully_connected(fc_1,
                               1,
                               nb_filter=2048,
                               n_units=512,
                               activation=tf.nn.relu,
                               name='fully_connected_2')

        with tf.name_scope('dropout'):
            keep_prob = tf.placeholder(tf.float32)
            tf.summary.scalar('dropout_keep_probability', keep_prob)
            dropped = tf.nn.dropout(fc_2, keep_prob)

        # The last layer returns raw logits (activation='NONE'): the loss below uses
        # tf.nn.softmax_cross_entropy_with_logits, which applies softmax itself, so
        # the output must not be passed through softmax a second time.
        y = fully_connected(dropped,
                            1,
                            nb_filter=512,
                            n_units=10,
                            activation='NONE',
                            name='fully_connected_3')

    with tf.name_scope('loss_function'):
        diff = tf.nn.softmax_cross_entropy_with_logits(labels=output_layer,
                                                       logits=y)
        with tf.name_scope('total'):
            cross_entropy = tf.reduce_mean(diff)
    tf.summary.scalar('cross_entropy', cross_entropy)

    with tf.name_scope('optimizer'):
        train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(
            cross_entropy)

    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y, 1),
                                          tf.argmax(output_layer, 1))
        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(LOG_DIR + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(LOG_DIR + '/test')
    tf.global_variables_initializer().run()

    print("\nTraining the network...")
    t = trange(EPOCHS * data.train.images.shape[0] // BATCH_SIZE)
    for i in t:
        # selecting a batch
        batch_x, batch_y = data.train.batch(BATCH_SIZE)
        # periodic evaluation on the validation set
        if i % 10 == 0:
            summary, acc = sess.run(
                [merged, accuracy],
                feed_dict={
                    input_layer: data.validation.images,
                    output_layer: data.validation.labels,
                    keep_prob: 1.0
                })
            test_writer.add_summary(summary, i)
            print('Validation accuracy at step %s: %s' % (i, acc))
        else:  # Record train set summaries, and train
            if i % 100 == 99:  # Record execution stats
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                summary, _ = sess.run(
                    [merged, train_step],
                    feed_dict={
                        input_layer: batch_x,
                        output_layer: batch_y,
                        keep_prob: DROP_OUT
                    },
                    options=run_options,
                    run_metadata=run_metadata)
                train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
                train_writer.add_summary(summary, i)
                print('Adding run metadata for', i)
            else:  # Record a summary
                summary, _ = sess.run(
                    [merged, train_step],
                    feed_dict={
                        input_layer: batch_x,
                        output_layer: batch_y,
                        keep_prob: DROP_OUT
                    })
                train_writer.add_summary(summary, i)
    train_writer.close()
    test_writer.close()
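
    # After training, display one training digit and plot the feature maps that a
    # chosen convolutional layer produces for it.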

    def getActivations(layer, stimuli):
        units = sess.run(layer,
                         feed_dict={
                             input_layer: np.reshape(stimuli, [1, 784],
                                                     order='F'),
                             keep_prob: 1.0
                         })
        plotNNFilter(units)

    def plotNNFilter(units):
        filters = units.shape[3]
        plt.figure(1, figsize=(20, 20))
        n_columns = 6
        n_rows = math.ceil(filters / n_columns) + 1
        for i in range(filters):
            plt.subplot(n_rows, n_columns, i + 1)
            plt.title('Filter ' + str(i))
            plt.imshow(units[0, :, :, i], interpolation="nearest", cmap="gray")

    imageToUse = data.train.images[0]
    data.train.display_digit()
    plt.imshow(np.reshape(imageToUse, [28, 28]),
               interpolation="nearest",
               cmap="gray")
    plt.show()
    #getActivations(conv_1, imageToUse)
    #getActivations(conv_2, imageToUse)
    #getActivations(conv_3, imageToUse)
    getActivations(conv_4, imageToUse)
    plt.show()