コード例 #1
0
def train_cnn():
    """Train the TextCNN model, optionally resuming from a saved run.

    All configuration comes from the module-level ``FLAGS``. The procedure:

    1. Load and pad the training and validation sets via ``data_helpers``.
    2. Build a ``TextCNN`` (given the pretrained word2vec matrix) and an Adam
       optimizer, with histogram/sparsity summaries for every gradient.
    3. Either start a fresh run under ``runs/<unix-timestamp>/`` or, when
       ``FLAGS.train_or_restore == 'R'``, prompt for the timestamp of an
       existing run and restore its latest checkpoint.
    4. Iterate over ``FLAGS.num_epochs`` epochs of training batches, evaluating
       on the validation set every ``FLAGS.evaluate_every`` steps and saving a
       checkpoint every ``FLAGS.checkpoint_every`` steps.

    Raises:
        ValueError: if ``FLAGS.use_classbind_or_not`` is neither 'Y' nor 'N'
            (checked up front, before any data is loaded).
        FileNotFoundError: if resuming and the chosen run has no checkpoint.
    """
    # Fail fast: an invalid value would otherwise only surface (as an
    # UnboundLocalError) at the first evaluation, possibly hours into training.
    if FLAGS.use_classbind_or_not not in ('Y', 'N'):
        raise ValueError("FLAGS.use_classbind_or_not must be 'Y' or 'N', got {!r}"
                         .format(FLAGS.use_classbind_or_not))

    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')

    logger.info('✔︎ Training data processing...')
    train_data = data_helpers.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('✔︎ Validation data processing...')
    validation_data = \
        data_helpers.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('Recommand padding Sequence length is: {}'.format(FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = data_helpers.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = data_helpers.pad_data(validation_data, FLAGS.pad_seq_len)

    # Per-sample "class bind" data; only consulted when use_classbind_or_not == 'Y'.
    y_validation_bind = validation_data.labels_bind

    # Build vocabulary
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and cnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Entering the session makes it the default session (as
        # ``sess.as_default()`` did) and also closes it on exit, releasing
        # device memory even if training raises.
        with tf.Session(config=session_conf) as sess:
            cnn = TextCNN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=FLAGS.num_classes,
                vocab_size=VOCAB_SIZE,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define Training procedure (constant learning rate).
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=cnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore

                # Run names are 10-digit unix timestamps.
                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input('✘ The format of your input is illegal, please re-input: ')
                logger.info('✔︎ The format of your input is legal, now loading to next step...')

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'

                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {}\n".format(out_dir))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", cnn.loss)

            # Train Summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load cnn model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)
                if checkpoint_file is None:
                    raise FileNotFoundError("No checkpoint found in {}".format(checkpoint_dir))

                # The graph built above already holds every variable of the
                # model, so restore straight into it. tf.train.import_meta_graph
                # would add a second, renamed copy of the whole model here.
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

            def train_step(x_batch, y_batch):
                """Run one optimisation step and log its loss."""
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    cnn.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, cnn.global_step, train_summary_op, cnn.loss], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                logger.info("{}: step {}, loss {:g}"
                            .format(time_str, step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation, y_validation, y_validation_bind, writer=None):
                """Evaluate the model with a single pass over the validation set.

                Returns:
                    (loss, recall, accuracy): each is the per-batch value
                    averaged over batches (unweighted by batch size).
                """
                # One epoch: repeating the pass would only recompute the same metrics.
                batches_validation = data_helpers.batch_iter(
                    list(zip(x_validation, y_validation, y_validation_bind)), 8 * FLAGS.batch_size, 1)
                eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0
                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(*batch_validation)
                    feed_dict = {
                        cnn.input_x: x_batch_validation,
                        cnn.input_y: y_batch_validation,
                        cnn.dropout_keep_prob: 1.0,
                        cnn.is_training: False
                    }
                    step, summaries, logits, cur_loss = sess.run(
                        [cnn.global_step, validation_summary_op, cnn.logits, cnn.loss], feed_dict)

                    # use_classbind_or_not was validated to be 'Y' or 'N' on entry.
                    if FLAGS.use_classbind_or_not == 'Y':
                        predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                            logits, y_batch_validation_bind, top_number=FLAGS.top_num)
                    else:
                        predicted_labels = data_helpers.get_label_using_logits(logits, top_number=FLAGS.top_num)

                    cur_rec, cur_acc = 0.0, 0.0
                    for index, predicted_label in enumerate(predicted_labels):
                        rec_inc, acc_inc = data_helpers.cal_rec_and_acc(predicted_label, y_batch_validation[index])
                        cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc

                    cur_rec = cur_rec / len(y_batch_validation)
                    cur_acc = cur_acc / len(y_batch_validation)

                    eval_loss += cur_loss
                    eval_rec += cur_rec
                    eval_acc += cur_acc
                    eval_counter += 1
                    logger.info("✔︎ validation batch {} finished.".format(eval_counter))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec = float(eval_rec / eval_counter)
                eval_acc = float(eval_acc / eval_counter)

                return eval_loss, eval_rec, eval_acc

            # Generate batches
            batches_train = data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            try:
                # Training loop. For each batch...
                for batch_train in batches_train:
                    x_batch_train, y_batch_train = zip(*batch_train)
                    train_step(x_batch_train, y_batch_train)
                    current_step = tf.train.global_step(sess, cnn.global_step)

                    if current_step % FLAGS.evaluate_every == 0:
                        logger.info("\nEvaluation:")
                        eval_loss, eval_rec, eval_acc = validation_step(x_validation, y_validation, y_validation_bind,
                                                                        writer=validation_summary_writer)
                        time_str = datetime.datetime.now().isoformat()
                        logger.info("{}: step {}, loss {:g}, rec {:g}, acc {:g}"
                                    .format(time_str, current_step, eval_loss, eval_rec, eval_acc))

                    if current_step % FLAGS.checkpoint_every == 0:
                        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        logger.info("✔︎ Saved model checkpoint to {}\n".format(path))
            finally:
                # Flush buffered events so the last summaries reach disk.
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("✔︎ Done.")
コード例 #2
0
def test_fasttext():
    """Evaluate a trained FastText model on the test set.

    Restores the latest checkpoint in ``FLAGS.checkpoint_dir``, predicts labels
    for ``FLAGS.test_data_file`` in one unshuffled pass, and logs recall and
    accuracy (per-batch values averaged over batches). The predicted labels are
    written to ``SAVE_FILE``.

    Raises:
        ValueError: if ``FLAGS.use_classbind_or_not`` is neither 'Y' nor 'N'.
        FileNotFoundError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    # Fail fast instead of hitting an unbound ``predicted_labels`` mid-run.
    if FLAGS.use_classbind_or_not not in ('Y', 'N'):
        raise ValueError("FLAGS.use_classbind_or_not must be 'Y' or 'N', got {!r}"
                         .format(FLAGS.use_classbind_or_not))

    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommand padding Sequence length is: {}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = data_helpers.load_data_and_labels(FLAGS.test_data_file,
                                                  FLAGS.num_classes,
                                                  FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test, y_test = data_helpers.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_bind = test_data.labels_bind

    # No word2vec matrix is loaded: the embedding is part of the graph and
    # variables restored from the checkpoint below.

    # Load fasttext model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)
    if checkpoint_file is None:
        raise FileNotFoundError("No checkpoint found in {}".format(FLAGS.checkpoint_dir))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Context manager: makes the session default and closes it on exit.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            logits = graph.get_operation_by_name("output/logits").outputs[0]

            # Generate batches for one epoch
            batches = data_helpers.batch_iter(
                list(zip(x_test, y_test, y_test_bind)),
                FLAGS.batch_size,
                1,
                shuffle=False)

            # Collect the predictions here
            all_predictions = []
            eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0
            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_bind = zip(
                    *batch_test)
                feed_dict = {input_x: x_batch_test, dropout_keep_prob: 1.0}
                batch_logits = sess.run(logits, feed_dict)

                # use_classbind_or_not was validated to be 'Y' or 'N' on entry.
                if FLAGS.use_classbind_or_not == 'Y':
                    predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                        batch_logits,
                        y_batch_test_bind,
                        top_number=FLAGS.top_num)
                else:
                    predicted_labels = data_helpers.get_label_using_logits(
                        batch_logits, top_number=FLAGS.top_num)

                # NOTE(review): np.append without an axis flattens its inputs,
                # so if each sample's prediction holds several labels, SAVE_FILE
                # loses the per-sample grouping. Kept as-is for output compatibility.
                all_predictions = np.append(all_predictions, predicted_labels)
                cur_rec, cur_acc = 0.0, 0.0
                for index, predicted_label in enumerate(predicted_labels):
                    rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                        predicted_label, y_batch_test[index])
                    cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc

                cur_rec = cur_rec / len(y_batch_test)
                cur_acc = cur_acc / len(y_batch_test)

                eval_rec, eval_acc, eval_counter = eval_rec + cur_rec, eval_acc + cur_acc, eval_counter + 1
                logger.info(
                    "✔︎ validation batch {} finished.".format(eval_counter))

            eval_rec = float(eval_rec / eval_counter)
            eval_acc = float(eval_acc / eval_counter)
            logger.info("☛ Recall {:g}, Accuracy {:g}".format(
                eval_rec, eval_acc))
            np.savetxt(SAVE_FILE, list(zip(all_predictions)), fmt='%s')

    logger.info("✔ Done.")
コード例 #3
0
ファイル: test.py プロジェクト: changzeng/Typhoon
def test_cnn():
    """Evaluate a trained two-input (front/behind) CNN and export it as a .pb file.

    Restores the latest checkpoint in ``FLAGS.checkpoint_dir``. It then freezes
    the output nodes into ``graph/graph-cnn-<MODEL_LOG>.pb`` and runs the test
    set in one unshuffled pass, logging the loss and accuracy of each batch.
    Predictions and top-K predictions are written to
    ``<SAVE_DIR>/result_sub_<SUBSET>.txt``.

    Raises:
        FileNotFoundError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommended padding Sequence length is: {0}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = dh.load_data_and_labels(FLAGS.test_data_file,
                                        FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test_front, x_test_behind, y_test = dh.pad_data(test_data,
                                                      FLAGS.pad_seq_len)

    # No word2vec matrix is loaded: the embedding is part of the graph and
    # variables restored from the checkpoint below.

    # Load cnn model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)
    if checkpoint_file is None:
        raise FileNotFoundError("No checkpoint found in {0}".format(FLAGS.checkpoint_dir))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Context manager: makes the session default and closes it on exit.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x_front = graph.get_operation_by_name(
                "input_x_front").outputs[0]
            input_x_behind = graph.get_operation_by_name(
                "input_x_behind").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name(
                "output/predictions").outputs[0]
            topKPreds = graph.get_operation_by_name(
                "output/topKPreds").outputs[0]
            accuracy = graph.get_operation_by_name(
                "accuracy/accuracy").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = 'output/scores|output/predictions|output/SoftMax_scores|output/topKPreds'

            # Save the .pb model file (variables frozen into constants).
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 'graph',
                                 'graph-cnn-{0}.pb'.format(MODEL_LOG),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(
                list(zip(x_test_front, x_test_behind, y_test)),
                FLAGS.batch_size,
                1,
                shuffle=False)

            # Collect the predictions here
            all_predictions = []
            all_topKPreds = []

            for index, x_test_batch in enumerate(batches):
                x_batch_front, x_batch_behind, y_batch = zip(*x_test_batch)
                feed_dict = {
                    input_x_front: x_batch_front,
                    input_x_behind: x_batch_behind,
                    input_y: y_batch,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                # One run computes every fetch from a single forward pass
                # (previously one full pass per fetched tensor).
                batch_predictions, batch_topKPreds, batch_loss, batch_acc = sess.run(
                    [predictions, topKPreds, loss, accuracy], feed_dict)

                all_predictions = np.concatenate(
                    [all_predictions, batch_predictions])
                # NOTE(review): np.append flattens; if topKPreds holds K > 1
                # entries per sample, the zip() below pairs predictions with
                # the wrong top-K entries. Kept as-is for output compatibility.
                all_topKPreds = np.append(all_topKPreds, batch_topKPreds)

                logger.info(
                    "✔︎ Test batch {0}: loss {1:g}, accuracy {2:g}.".format(
                        (index + 1), batch_loss, batch_acc))

            # exist_ok: re-running must not fail because the directory exists.
            os.makedirs(SAVE_DIR, exist_ok=True)
            np.savetxt(SAVE_DIR + '/result_sub_' + SUBSET + '.txt',
                       list(zip(all_predictions, all_topKPreds)),
                       fmt='%s')

    logger.info("✔ Done.")
コード例 #4
0
def test_cnn():
    """Predict labels for the test set with a trained two-input (front/behind) CNN.

    Restores the latest checkpoint in ``FLAGS.checkpoint_dir``, runs the test
    set in one unshuffled pass, and writes predictions plus top-K predictions
    to ``SAVE_FILE``.

    Raises:
        FileNotFoundError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommand padding Sequence length is: {}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = data_helpers.load_data_and_labels(FLAGS.test_data_file,
                                                  FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test_front, x_test_behind, y_test = data_helpers.pad_data(
        test_data, FLAGS.pad_seq_len)

    # No word2vec matrix is loaded: the embedding is part of the graph and
    # variables restored from the checkpoint below.

    # Load cnn model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)
    if checkpoint_file is None:
        raise FileNotFoundError("No checkpoint found in {}".format(FLAGS.checkpoint_dir))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Context manager: makes the session default and closes it on exit.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x_front = graph.get_operation_by_name(
                "input_x_front").outputs[0]
            input_x_behind = graph.get_operation_by_name(
                "input_x_behind").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name(
                "output/predictions").outputs[0]
            topKPreds = graph.get_operation_by_name(
                "output/topKPreds").outputs[0]

            # Generate batches for one epoch
            batches = data_helpers.batch_iter(
                list(zip(x_test_front, x_test_behind)),
                FLAGS.batch_size,
                1,
                shuffle=False)

            # Collect the predictions here
            all_predictions = []
            all_topKPreds = []

            for x_test_batch in batches:
                x_batch_front, x_batch_behind = zip(*x_test_batch)
                feed_dict = {
                    input_x_front: x_batch_front,
                    input_x_behind: x_batch_behind,
                    dropout_keep_prob: 1.0
                }
                # One run computes both fetches from a single forward pass.
                batch_predictions, batch_topKPreds = sess.run(
                    [predictions, topKPreds], feed_dict)

                all_predictions = np.concatenate(
                    [all_predictions, batch_predictions])
                # NOTE(review): np.append flattens; if topKPreds holds K > 1
                # entries per sample, the zip() below misaligns them with
                # predictions. Kept as-is for output compatibility.
                all_topKPreds = np.append(all_topKPreds, batch_topKPreds)

            np.savetxt(SAVE_FILE,
                       list(zip(all_predictions, all_topKPreds)),
                       fmt='%s')

    logger.info("✔ Done.")
コード例 #5
0
def train_cnn():
    """Train a two-input (front/behind) TextCNN model from scratch.

    All configuration comes from the module-level ``FLAGS``. Training and
    validation data are loaded and padded via ``data_helpers``, and the model
    is optimised with Adam. The whole validation set is evaluated every
    ``FLAGS.evaluate_every`` steps, and a checkpoint is saved every
    ``FLAGS.checkpoint_every`` steps. Summaries and checkpoints go under
    ``runs/<unix-timestamp>/``.
    """
    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')

    logger.info('✔︎ Training data processing...')
    train_data = data_helpers.load_data_and_labels(FLAGS.training_data_file,
                                                   FLAGS.embedding_dim)

    logger.info('✔︎ Validation data processing...')
    validation_data = data_helpers.load_data_and_labels(
        FLAGS.validation_data_file, FLAGS.embedding_dim)

    logger.info('Recommand padding Sequence length is: {}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train_front, x_train_behind, y_train = data_helpers.pad_data(
        train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation_front, x_validation_behind, y_validation = data_helpers.pad_data(
        validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and cnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Entering the session makes it the default session (as
        # ``sess.as_default()`` did) and closes it on exit.
        with tf.Session(config=session_conf) as sess:
            cnn = TextCNN(sequence_length=FLAGS.pad_seq_len,
                          num_classes=y_train.shape[1],
                          vocab_size=VOCAB_SIZE,
                          embedding_size=FLAGS.embedding_dim,
                          embedding_type=FLAGS.embedding_type,
                          filter_sizes=list(
                              map(int, FLAGS.filter_sizes.split(","))),
                          num_filters=FLAGS.num_filters,
                          l2_reg_lambda=FLAGS.l2_reg_lambda,
                          pretrained_embedding=pretrained_word2vec_matrix)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            # NOTE(review): learning rate is hard-coded to 1e-3 here.
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step,
                                                 name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(
                os.path.join(os.path.curdir, "runs", timestamp))
            logger.info("✔︎ Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train Summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge(
                [loss_summary, acc_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)

            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            def train_step(x_batch_front, x_batch_behind, y_batch):
                """Run one optimisation step and log its loss and accuracy."""
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run([
                    train_op, global_step, train_summary_op, cnn.loss,
                    cnn.accuracy
                ], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                logger.info("{}: step {}, loss {:g}, acc {:g}".format(
                    time_str, step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_batch_front,
                                x_batch_behind,
                                y_batch,
                                writer=None):
                """Evaluate the model on the given inputs in a single feed and log the metrics."""
                # NOTE(review): the whole validation set is fed at once; large
                # sets may need batching to fit in device memory.
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                # Only the tensors that are logged or summarised are fetched.
                step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run(
                    [global_step, validation_summary_op, cnn.loss, cnn.accuracy,
                     cnn.recall, cnn.precision, cnn.F1, cnn.AUC], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                logger.info(
                    "{}: step {}, loss {:g}, acc {:g}, "
                    "recall {:g}, precision {:g}, f1 {:g}, AUC {}".format(
                        time_str, step, loss, accuracy, recall, precision, f1,
                        auc))
                if writer:
                    writer.add_summary(summaries, step)

            # Generate batches
            batches = data_helpers.batch_iter(
                list(zip(x_train_front, x_train_behind, y_train)),
                FLAGS.batch_size, FLAGS.num_epochs)

            try:
                # Training loop. For each batch...
                for batch in batches:
                    x_batch_front, x_batch_behind, y_batch = zip(*batch)
                    train_step(x_batch_front, x_batch_behind, y_batch)
                    current_step = tf.train.global_step(sess, global_step)
                    if current_step % FLAGS.evaluate_every == 0:
                        logger.info("\nEvaluation:")
                        validation_step(x_validation_front,
                                        x_validation_behind,
                                        y_validation,
                                        writer=validation_summary_writer)
                    if current_step % FLAGS.checkpoint_every == 0:
                        path = saver.save(sess,
                                          checkpoint_prefix,
                                          global_step=current_step)
                        logger.info(
                            "✔︎ Saved model checkpoint to {}\n".format(path))
            finally:
                # Flush buffered events so the last summaries reach disk.
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("✔︎ Done.")