Ejemplo n.º 1
0
def predict(sess, model, input_seqs, reset_seqs):
    """ Run the model once over a set of sequences and return label sequences.

    Args:
        sess: A Session.
        model: An LSTMModel.
        input_seqs: A list of input sequences, each a float32 NumPy array with
            shape `[duration, input_size]`.
        reset_seqs: A list of reset sequences, each a bool NumPy array with
            shape `[duration, 1]`.

    Returns:
        A list of prediction sequences, each a NumPy array with shape
        `[duration, 1]`, containing predicted labels for each time step.
    """

    durations = [len(seq) for seq in input_seqs]

    # The batch size equals the number of sequences, and only the first sweep
    # produced by the generator is evaluated.
    sweep_gen = data.sweep_generator([input_seqs, reset_seqs],
                                     batch_size=len(input_seqs))
    input_sweep, reset_sweep = next(sweep_gen)

    feed_dict = {
        model.inputs: input_sweep,
        model.resets: reset_sweep,
        model.training: False,
    }
    logit_sweep = sess.run(model.logits, feed_dict=feed_dict)

    # Keep only the first `duration` steps of each row (the sweep is
    # presumably padded to a common length), then pick the best class.
    prediction_seqs = []
    for logits, duration in zip(logit_sweep, durations):
        labels = np.argmax(logits[:duration], axis=1)
        prediction_seqs.append(labels.reshape(-1, 1))

    return prediction_seqs
def predict(sess, model, input_seqs, reset_seqs):
    """ Compute prediction sequences from input sequences.

    Args:
        sess: A Session.
        model: An LSTMModel.
        input_seqs: A list of input sequences, each a float32 NumPy array with
            shape `[duration, input_size]`.
        reset_seqs: A list of reset sequences, each a bool NumPy array with
            shape `[duration, 1]`.

    Returns:
        A list of prediction sequences, each a NumPy array with shape
        `[duration, 1]`, containing predicted labels for each time step.
    """

    batch_size = len(input_seqs)
    seq_durations = [len(seq) for seq in input_seqs]
    # The built-in next() works on both Python 2 and 3; the generator's
    # `.next()` method exists only on Python 2.
    input_sweep, reset_sweep = next(data.sweep_generator(
        [input_seqs, reset_seqs], batch_size=batch_size))

    logit_sweep = sess.run(model.logits, feed_dict={model.inputs: input_sweep,
                                                    model.resets: reset_sweep,
                                                    model.training: False})

    # Keep only the first `duration` steps of each row before taking the
    # per-step argmax over the class axis.
    logit_seqs = [seq[:duration]
                  for (seq, duration) in zip(logit_sweep, seq_durations)]
    prediction_seqs = [np.argmax(seq, axis=1).reshape(-1, 1)
                       for seq in logit_seqs]

    return prediction_seqs
Ejemplo n.º 3
0
def test_predict(sess, model, input_seqs, reset_seqs, sample_times,
                 test_label_seqs):
    """ Predict repeatedly on one sweep and count mismatches against labels.

    The model is run `sample_times` times on the same sweep. The per-run
    argmax results are averaged and compared with `test_label_seqs`.

    Args:
        sess: A Session.
        model: An LSTMModel.
        input_seqs: A list of input sequences, each a float32 NumPy array with
            shape `[duration, input_size]`.
        reset_seqs: A list of reset sequences, each a bool NumPy array with
            shape `[duration, 1]`.
        sample_times: An integer. The number of times the model is run.
        test_label_seqs: The labels to score against. It is used through
            `.argmax(axis=1)`, so it presumably has to be a single NumPy
            array rather than a list -- TODO confirm with the caller.

    Returns:
        A tuple `(prediction_seqs, err)`. `prediction_seqs` is the list of
        label sequences (each with shape `[duration, 1]`) from the last run,
        or an empty list if `sample_times` is 0. `err` is a float holding the
        number of positions where the averaged result and the labels
        disagree.
    """
    batch_size = len(input_seqs)
    seq_durations = [len(seq) for seq in input_seqs]

    print('batch_size,seq_durations:', batch_size, seq_durations)
    # The built-in next() works on both Python 2 and 3; the generator's
    # `.next()` method exists only on Python 2.
    input_sweep, reset_sweep = next(data.sweep_generator(
        [input_seqs, reset_seqs], batch_size=batch_size))
    print('input_sweep.shape:', input_sweep.shape)

    # One row of results per run.
    # NOTE(review): the last dimension is input_sweep.shape[1], while the
    # values stored below come from an argmax over axis 1 of the logits.
    # These only line up if the shapes happen to broadcast -- confirm the
    # intended axis and buffer shape.
    y_ = np.zeros((sample_times, batch_size, input_sweep.shape[1]))

    # Bound up front so the return below is valid even when no run happens.
    prediction_seqs = []
    for sample_id in range(sample_times):
        logit_sweep = sess.run(model.logits,
                               feed_dict={
                                   model.inputs: input_sweep,
                                   model.resets: reset_sweep,
                                   model.training: False
                               })

        logit_seqs = [
            seq[:duration]
            for (seq, duration) in zip(logit_sweep, seq_durations)
        ]
        prediction_seqs = [
            np.argmax(seq, axis=1).reshape(-1, 1) for seq in logit_seqs
        ]
        print(np.asarray(logit_sweep).shape, np.asarray(prediction_seqs).shape)
        y_[sample_id] = np.asarray(np.argmax(logit_sweep, axis=1))

    mean_y = y_.mean(axis=0)
    # Evaluate against the labels: count positions whose argmax differs.
    y = test_label_seqs
    err = 0.0 + np.count_nonzero(
        np.not_equal(mean_y.argmax(axis=1), y.argmax(axis=1)))

    return prediction_seqs, err
Ejemplo n.º 4
0
def predict(sess, model, input_seqs, reset_seqs):
    """ Compute top-3 prediction sequences and softmax outputs.

    Args:
        sess: A Session.
        model: An LSTMModel.
        input_seqs: A list of input sequences, each a float32 NumPy array with
            shape `[duration, input_size]`.
        reset_seqs: A list of reset sequences, each a bool NumPy array with
            shape `[duration, 1]`.

    Returns:
        A tuple `(prediction_seqs, softmax_dur)`. `prediction_seqs` is a list
        with one NumPy array per sequence holding, for each time step, the
        indices of the three highest-scoring classes in descending order of
        logit (shape `[duration, 3]` when there are at least three classes).
        `softmax_dur` is a list with the model's softmax output for each
        sequence, cut to that sequence's duration.
    """
    batch_size = len(input_seqs)
    seq_durations = [len(seq) for seq in input_seqs]
    # The built-in next() works on both Python 2 and 3; the generator's
    # `.next()` method exists only on Python 2.
    input_sweep, reset_sweep = next(data.sweep_generator(
        [input_seqs, reset_seqs], batch_size=batch_size))
    logit_sweep, softmax = sess.run(
        [model.logits, model.softmax],
        feed_dict={
            model.inputs: input_sweep,
            model.resets: reset_sweep,
            model.training: False
        })

    # Keep only the first `duration` steps of each row.
    softmax_dur = [
        seq[:duration] for (seq, duration) in zip(softmax, seq_durations)
    ]
    logit_seqs = [
        seq[:duration] for (seq, duration) in zip(logit_sweep, seq_durations)
    ]
    # Sort classes by descending logit within each time step and keep the
    # first three columns. Slicing with `[:3]` instead would keep the first
    # three time steps, not the three best classes.
    prediction_seqs = [(-seq).argsort(axis=1)[:, :3] for seq in logit_seqs]
    print("prediction_seqs:", prediction_seqs[0].shape)
    print("logit_seqs:", logit_seqs[0].shape)
    return (prediction_seqs, softmax_dur)
Ejemplo n.º 5
0
def predict(sess, input_seqs, reset_seqs,path,entropy=True):
    """ Restore a saved model and evaluate it on a set of sequences.

    The meta graph and variables are loaded from `path + 'model.ckpt.meta'`
    and `path + 'model.ckpt'`, so `path` must already end with a path
    separator (or be a filename prefix).

    NOTE(review): the meta graph is imported into the default graph on every
    call -- confirm callers do not invoke this repeatedly on the same graph.

    Args:
        sess: A Session into which the checkpoint is restored.
        input_seqs: A list of input sequences; `len(seq)` is taken as each
            sequence's duration.
        reset_seqs: A list of reset sequences matching `input_seqs`.
        path: A string. Prefix prepended to the checkpoint filenames.
        entropy: A bool. If True, the "logits/Softmax:0" tensor is evaluated
            and returned as-is; otherwise "logits/logits:0" is evaluated and
            reduced to per-step labels.

    Returns:
        A list with one NumPy array per sequence, cut to that sequence's
        duration. With `entropy=True` each array is the evaluated softmax
        output; with `entropy=False` each array has shape `[duration, 1]`
        and holds the argmax label for each time step.
    """
    saver = tf.train.import_meta_graph(path + 'model.ckpt.meta')
    saver.restore(sess, path + 'model.ckpt')

    graph = tf.get_default_graph()
    if entropy:
        feature_map = graph.get_tensor_by_name("logits/Softmax:0")
    else:
        feature_map = graph.get_tensor_by_name("logits/logits:0")
    inputs = graph.get_tensor_by_name("inputs:0")
    resets = graph.get_tensor_by_name("resets:0")
    training = graph.get_tensor_by_name("training:0")

    batch_size = len(input_seqs)
    seq_durations = [len(seq) for seq in input_seqs]
    # The built-in next() works on both Python 2 and 3; the generator's
    # `.next()` method exists only on Python 2.
    input_sweep, reset_sweep = next(data.sweep_generator(
        [input_seqs, reset_seqs], batch_size=batch_size))
    softmax = sess.run(feature_map, feed_dict={inputs: input_sweep,
                                               resets: reset_sweep,
                                               training: False})
    # Keep only the first `duration` steps of each row.
    softmax_dur = [seq[:duration]
                   for (seq, duration) in zip(softmax, seq_durations)]
    if not entropy:
        # In this mode the evaluated tensor holds logits; return labels.
        prediction_seqs = [np.argmax(seq, axis=1).reshape(-1, 1)
                           for seq in softmax_dur]
        softmax_dur = prediction_seqs
    return softmax_dur
Ejemplo n.º 6
0
def train(sess, model, optimizer, log_dir, batch_size, num_sweeps_per_summary,
          num_sweeps_per_save, train_input_seqs, train_reset_seqs,
          train_label_seqs, test_input_seqs, test_reset_seqs, test_label_seqs):
    """ Run the training loop, periodically writing summaries and checkpoints.

    `log_dir` is deleted and recreated if it already exists, so never point
    it at a directory holding anything you want to keep.

    Args:
        sess: A TensorFlow `Session`.
        model: An `LSTMModel`.
        optimizer: An `Optimizer`.
        log_dir: A string. The full path to the log directory.
        batch_size: An integer. The number of sequences in a batch.
        num_sweeps_per_summary: An integer. The number of sweeps between
            summaries.
        num_sweeps_per_save: An integer. The number of sweeps between saves.
        train_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        train_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        train_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        test_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
    """

    # Exponential moving average of the training loss, exported as a summary.
    loss_averager = tf.train.ExponentialMovingAverage(decay=0.5)
    update_train_loss_ema = loss_averager.apply([model.loss])
    tf.summary.scalar('train_loss_ema', loss_averager.average(model.loss))

    # The metrics are computed outside the graph and fed in through these
    # placeholders so that they appear in the merged summary.
    metric_names = ('train_accuracy', 'train_edit_dist',
                    'test_accuracy', 'test_edit_dist')
    metric_placeholders = [
        tf.placeholder(tf.float32, name=name) for name in metric_names
    ]

    tf.summary.scalar('learning_rate', optimizer.learning_rate)
    for placeholder in metric_placeholders:
        tf.summary.scalar(placeholder.op.name, placeholder)

    summary_op = tf.summary.merge_all()

    # Start from an empty log directory.
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    summary_writer = tf.summary.FileWriter(logdir=log_dir, graph=sess.graph)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    status_path = os.path.join(log_dir, 'status.txt')
    label_path = os.path.join(log_dir, 'test_label_seqs.pkl')
    pred_path = os.path.join(log_dir, 'test_prediction_seqs.pkl')
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    num_sweeps_visited = 0
    start_time = time.time()
    train_gen = data.sweep_generator(
        [train_input_seqs, train_reset_seqs, train_label_seqs],
        batch_size=batch_size,
        shuffle=True,
        num_sweeps=None)

    while num_sweeps_visited <= optimizer.num_train_sweeps:

        if num_sweeps_visited % num_sweeps_per_summary == 0:
            # Score the current model on both splits.
            train_prediction_seqs = models.predict(
                sess, model, train_input_seqs, train_reset_seqs)
            train_acc, train_dist = metrics.compute_metrics(
                train_prediction_seqs, train_label_seqs)
            test_prediction_seqs = models.predict(
                sess, model, test_input_seqs, test_reset_seqs)
            test_acc, test_dist = metrics.compute_metrics(
                test_prediction_seqs, test_label_seqs)

            metric_values = (train_acc, train_dist, test_acc, test_dist)
            summary = sess.run(
                summary_op,
                feed_dict=dict(zip(metric_placeholders, metric_values)))
            summary_writer.add_summary(summary, global_step=num_sweeps_visited)

            # One-line status: minutes elapsed, sweep count, then accuracy
            # and edit distance for train and test. The file is overwritten
            # each time.
            with open(status_path, 'w') as f:
                elapsed_minutes = (time.time() - start_time) / 60
                line = ''.join([
                    '%05.1f      ' % elapsed_minutes,
                    '%04d      ' % num_sweeps_visited,
                    '%.6f  %08.3f     ' % (train_acc, train_dist),
                    '%.6f  %08.3f     ' % (test_acc, test_dist),
                ])
                print(line, file=f)

            with open(label_path, 'wb') as f:
                cPickle.dump(test_label_seqs, f)

            with open(pred_path, 'wb') as f:
                cPickle.dump(test_prediction_seqs, f)

            # Save a visualization of the test predictions for this sweep.
            vis_path = os.path.join(
                log_dir, 'test_visualizations_%06d.png' % num_sweeps_visited)
            fig, axes = data.visualize_predictions(test_prediction_seqs,
                                                   test_label_seqs,
                                                   model.target_size)
            axes[0].set_title(line)
            plt.tight_layout()
            plt.savefig(vis_path)
            plt.close(fig)

        if num_sweeps_visited % num_sweeps_per_save == 0:
            saver.save(sess, checkpoint_path)

        train_inputs, train_resets, train_labels = next(train_gen)
        # Without the squeeze the targets would have shape
        # [batch_size, duration, 1, num_classes].
        train_targets = data.one_hot(
            train_labels, model.target_size).squeeze(axis=2)

        fetches = [
            optimizer.optimize_op,
            update_train_loss_ema,
            optimizer.num_sweeps_visited,
        ]
        step_feed = {
            model.inputs: train_inputs,
            model.resets: train_resets,
            model.targets: train_targets,
            model.training: True,
        }
        # Only the sweep counter is needed from the fetched values.
        _, _, num_sweeps_visited = sess.run(fetches, feed_dict=step_feed)
Ejemplo n.º 7
0
def train(sess, model, optimizer, log_dir, batch_size, num_sweeps_per_summary,
          num_sweeps_per_save, train_input_seqs, train_reset_seqs,
          train_label_seqs, test_input_seqs, test_reset_seqs, test_label_seqs,
          args):
    """ Train a model and export summaries.

    `log_dir` will be *replaced* if it already exists, so it certainly
    shouldn't be anything generic like `/home/user`.

    At every summary step the test set is predicted `args.sample_times`
    times. The softmax outputs are averaged over those runs, and the entropy
    of the averaged distribution is computed for each time step.

    Args:
        sess: A TensorFlow `Session`.
        model: An `LSTMModel`.
        optimizer: An `Optimizer`.
        log_dir: A string. The full path to the log directory.
        batch_size: An integer. The number of sequences in a batch.
        num_sweeps_per_summary: An integer. The number of sweeps between
            summaries.
        num_sweeps_per_save: An integer. The number of sweeps between saves.
        train_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        train_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        train_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        test_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        args: An object containing processed arguments as attributes. Only
            `args.sample_times` (the number of repeated test predictions)
            is read here.
    """
    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    update_train_loss_ema = ema.apply([model.loss])
    train_loss_ema = ema.average(model.loss)
    tf.summary.scalar('train_loss_ema', train_loss_ema)

    # The metrics are computed outside the graph and fed in through these
    # placeholders so that they appear in the merged summary.
    train_accuracy = tf.placeholder(tf.float32, name='train_accuracy')
    train_edit_dist = tf.placeholder(tf.float32, name='train_edit_dist')
    test_accuracy = tf.placeholder(tf.float32, name='test_accuracy')
    test_edit_dist = tf.placeholder(tf.float32, name='test_edit_dist')

    tf.summary.scalar('learning_rate', optimizer.learning_rate)
    for value in [
            train_accuracy, train_edit_dist, test_accuracy, test_edit_dist
    ]:
        tf.summary.scalar(value.op.name, value)

    summary_op = tf.summary.merge_all()

    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    summary_writer = tf.summary.FileWriter(logdir=log_dir, graph=sess.graph)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    num_sweeps_visited = 0
    start_time = time.time()
    train_gen = data.sweep_generator(
        [train_input_seqs, train_reset_seqs, train_label_seqs],
        batch_size=batch_size,
        shuffle=True,
        num_sweeps=None)
    while num_sweeps_visited <= optimizer.num_train_sweeps:

        if num_sweeps_visited % num_sweeps_per_summary == 0:
            test_prediction_seqs = []
            # The second value returned by predict is not needed for the
            # training split.
            train_prediction_seqs, _ = models.predict(
                sess, model, train_input_seqs, train_reset_seqs)
            train_accuracy_, train_edit_dist_, train_confusion_matrix = (
                metrics.compute_metrics(train_prediction_seqs,
                                        train_label_seqs, log_dir, 'train'))
            print('test_input_seqs:', len(test_input_seqs))

            # Buffers with one slice per repeated prediction of the test set.
            # NOTE(review): the class dimension is hard-coded to 10 --
            # presumably it should follow the model's number of classes;
            # confirm before changing.
            no_of_samples = sum(len(seq) for seq in test_input_seqs)
            entropy_matrix = np.zeros(
                (args.sample_times, no_of_samples, 10))
            # entropy_matrix_1 keeps the predicted labels of every run. It is
            # only needed for a variation-ratio uncertainty estimate, which
            # is currently disabled.
            entropy_matrix_1 = np.zeros(
                (args.sample_times, no_of_samples, 1))
            # Iterate exactly as many times as the buffers have slices. A
            # fixed count would overrun them when `args.sample_times` is
            # smaller, and leave all-zero slices in the mean when larger.
            for sample_id in range(args.sample_times):
                test_prediction_seqs, softmax_seqs = models.predict(
                    sess, model, test_input_seqs, test_reset_seqs)
                entropy_matrix[sample_id] = np.vstack(softmax_seqs)
                entropy_matrix_1[sample_id] = np.vstack(test_prediction_seqs)

            # Entropy of the softmax distribution averaged over the runs.
            entropy_matrix_mean = np.mean(entropy_matrix, axis=0)
            # `where=` skips zero probabilities. `out=` must be supplied so
            # the skipped entries hold 0 rather than uninitialised memory;
            # they then contribute 0 * 0 = 0 to the sum below.
            entropy_log = np.log2(entropy_matrix_mean,
                                  out=np.zeros_like(entropy_matrix_mean),
                                  where=(entropy_matrix_mean != 0.0))
            entropy_log_mul = entropy_matrix_mean * entropy_log
            entropy = np.sum(entropy_log_mul, axis=1) * -1
            # Standardise to zero mean and unit variance across time steps.
            normalized_entropy = (entropy - np.mean(entropy, axis=0)) / np.std(
                entropy, axis=0)

            test_accuracy_, test_edit_dist_, test_confusion_matrix = (
                metrics.compute_metrics(test_prediction_seqs,
                                        test_label_seqs, log_dir, 'test'))
            summary = sess.run(summary_op,
                               feed_dict={
                                   train_accuracy: train_accuracy_,
                                   train_edit_dist: train_edit_dist_,
                                   test_accuracy: test_accuracy_,
                                   test_edit_dist: test_edit_dist_
                               })
            print('kris_po:: num_sweeps_visited:', num_sweeps_visited)
            summary_writer.add_summary(summary, global_step=num_sweeps_visited)
            summary_writer.add_summary(train_confusion_matrix,
                                       global_step=num_sweeps_visited)
            summary_writer.add_summary(test_confusion_matrix,
                                       global_step=num_sweeps_visited)

            # One-line status file, overwritten at every summary step.
            status_path = os.path.join(log_dir, 'status.txt')
            with open(status_path, 'w') as f:
                line = '%05.1f      ' % ((time.time() - start_time) / 60)
                line += '%04d      ' % num_sweeps_visited
                line += '%.6f  %08.3f     ' % (train_accuracy_,
                                               train_edit_dist_)
                line += '%.6f  %08.3f     ' % (test_accuracy_, test_edit_dist_)
                print(line, file=f)

            label_path = os.path.join(log_dir, 'test_label_seqs.pkl')
            with open(label_path, 'wb') as f:
                cPickle.dump(test_label_seqs, f)

            pred_path = os.path.join(log_dir, 'test_prediction_seqs.pkl')
            with open(pred_path, 'wb') as f:
                cPickle.dump(test_prediction_seqs, f)

            # NOTE(review): the visualization is written only at sweep 1197,
            # a hard-coded value -- presumably the final sweep of one
            # specific run; confirm whether this should be configurable.
            if num_sweeps_visited == 1197:
                vis_filename = 'test_visualizations_%06d.png' % num_sweeps_visited
                vis_path = os.path.join(log_dir, vis_filename)
                fig, axes = data.visualize_predictions(test_prediction_seqs,
                                                       test_label_seqs,
                                                       model.target_size,
                                                       normalized_entropy)
                axes[0].set_title(line)
                plt.tight_layout()
                plt.savefig(vis_path)
                plt.close(fig)

        if num_sweeps_visited % num_sweeps_per_save == 0:
            saver.save(sess, os.path.join(log_dir, 'model.ckpt'))

        # The built-in next() works on both Python 2 and 3; the generator's
        # `.next()` method exists only on Python 2.
        train_inputs, train_resets, train_labels = next(train_gen)
        # We squeeze here because otherwise the targets would have shape
        # [batch_size, duration, 1, num_classes].
        train_targets = data.one_hot(train_labels, model.target_size)
        train_targets = train_targets.squeeze(axis=2)

        _, _, num_sweeps_visited = sess.run(
            [
                optimizer.optimize_op, update_train_loss_ema,
                optimizer.num_sweeps_visited
            ],
            feed_dict={
                model.inputs: train_inputs,
                model.resets: train_resets,
                model.targets: train_targets,
                model.training: True
            })
def train(sess, model, optimizer, log_dir, batch_size, num_sweeps_per_summary,
          num_sweeps_per_save, train_input_seqs, train_reset_seqs,
          train_label_seqs, test_input_seqs, test_reset_seqs, test_label_seqs):
    """ Train a model and export summaries.

    `log_dir` will be *replaced* if it already exists, so it certainly
    shouldn't be anything generic like `/home/user`.

    NOTE(review): this function uses the pre-1.0 TensorFlow summary API
    (`tf.scalar_summary`, `tf.pack`, `tf.merge_all_summaries`,
    `tf.train.SummaryWriter`, `tf.initialize_all_variables`). Confirm the
    targeted TensorFlow version before running it.

    Args:
        sess: A TensorFlow `Session`.
        model: An `LSTMModel`.
        optimizer: An `Optimizer`.
        log_dir: A string. The full path to the log directory.
        batch_size: An integer. The number of sequences in a batch.
        num_sweeps_per_summary: An integer. The number of sweeps between
            summaries.
        num_sweeps_per_save: An integer. The number of sweeps between saves.
        train_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        train_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        train_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_input_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, input_size]`.
        test_reset_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
        test_label_seqs: A list of 2-D NumPy arrays, each with shape
            `[duration, 1]`.
    """

    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    update_train_loss_ema = ema.apply([model.loss])
    train_loss_ema = ema.average(model.loss)
    tf.scalar_summary('train_loss_ema', train_loss_ema)

    # The metrics are computed outside the graph and fed in through these
    # placeholders so that they appear in the merged summary.
    train_accuracy = tf.placeholder(tf.float32, name='train_accuracy')
    train_edit_dist = tf.placeholder(tf.float32, name='train_edit_dist')
    test_accuracy = tf.placeholder(tf.float32, name='test_accuracy')
    test_edit_dist = tf.placeholder(tf.float32, name='test_edit_dist')
    values = [train_accuracy, train_edit_dist, test_accuracy, test_edit_dist]
    tags = [value.op.name for value in values]
    tf.scalar_summary('learning_rate', optimizer.learning_rate)
    tf.scalar_summary(tags, tf.pack(values))

    summary_op = tf.merge_all_summaries()

    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    summary_writer = tf.train.SummaryWriter(logdir=log_dir, graph=sess.graph)
    saver = tf.train.Saver()

    sess.run(tf.initialize_all_variables())

    num_sweeps_visited = 0
    start_time = time.time()
    train_gen = data.sweep_generator(
        [train_input_seqs, train_reset_seqs, train_label_seqs],
        batch_size=batch_size, shuffle=True, num_sweeps=None)
    while num_sweeps_visited <= optimizer.num_train_sweeps:

        if num_sweeps_visited % num_sweeps_per_summary == 0:

            train_prediction_seqs = models.predict(
                sess, model, train_input_seqs, train_reset_seqs)
            train_accuracy_, train_edit_dist_ = metrics.compute_metrics(
                train_prediction_seqs, train_label_seqs)
            test_prediction_seqs = models.predict(
                sess, model, test_input_seqs, test_reset_seqs)
            test_accuracy_, test_edit_dist_ = metrics.compute_metrics(
                test_prediction_seqs, test_label_seqs)
            summary = sess.run(summary_op,
                               feed_dict={train_accuracy: train_accuracy_,
                                          train_edit_dist: train_edit_dist_,
                                          test_accuracy: test_accuracy_,
                                          test_edit_dist: test_edit_dist_})
            summary_writer.add_summary(summary, global_step=num_sweeps_visited)

            # One-line status file, overwritten at every summary step.
            status_path = os.path.join(log_dir, 'status.txt')
            with open(status_path, 'w') as f:
                line = '%05.1f      ' % ((time.time() - start_time)/60)
                line += '%04d      ' % num_sweeps_visited
                line += '%.6f  %08.3f     ' % (train_accuracy_,
                                               train_edit_dist_)
                line += '%.6f  %08.3f     ' % (test_accuracy_,
                                               test_edit_dist_)
                print(line, file=f)

            # Pickle files are opened in binary mode. Text mode fails on
            # Python 3 and can corrupt the data on Windows.
            label_path = os.path.join(log_dir, 'test_label_seqs.pkl')
            with open(label_path, 'wb') as f:
                cPickle.dump(test_label_seqs, f)

            pred_path = os.path.join(log_dir, 'test_prediction_seqs.pkl')
            with open(pred_path, 'wb') as f:
                cPickle.dump(test_prediction_seqs, f)

            vis_filename = 'test_visualizations_%06d.png' % num_sweeps_visited
            vis_path = os.path.join(log_dir, vis_filename)
            fig, axes = data.visualize_predictions(test_prediction_seqs,
                                                   test_label_seqs,
                                                   model.target_size)
            axes[0].set_title(line)
            plt.tight_layout()
            plt.savefig(vis_path)
            plt.close(fig)

        if num_sweeps_visited % num_sweeps_per_save == 0:
            saver.save(sess, os.path.join(log_dir, 'model.ckpt'))

        # The built-in next() works on both Python 2 and 3; the generator's
        # `.next()` method exists only on Python 2.
        train_inputs, train_resets, train_labels = next(train_gen)
        # We squeeze here because otherwise the targets would have shape
        # [batch_size, duration, 1, num_classes].
        train_targets = data.one_hot(train_labels, model.target_size)
        train_targets = train_targets.squeeze(axis=2)

        _, _, num_sweeps_visited = sess.run(
            [optimizer.optimize_op,
             update_train_loss_ema,
             optimizer.num_sweeps_visited],
            feed_dict={model.inputs: train_inputs,
                       model.resets: train_resets,
                       model.targets: train_targets,
                       model.training: True})