Ejemplo n.º 1
0
def build_train_savers(variables_to_add=None):
    """Create the savers used during training.

    Args:
        variables_to_add: optional list of extra variables (e.g. the global
            step) passed to ``variables_to_save`` so that they are saved
            alongside the default set. Defaults to no extra variables.

    Returns:
        train_saver: saver to use to log the training model
            (keeps the 2 most recent checkpoints).
        best_saver: saver used to save the best model
            (keeps only the most recent checkpoint).
    """
    # Build a fresh list per call instead of using a shared mutable default,
    # which would leak state between calls if variables_to_save ever mutated
    # its argument.
    if variables_to_add is None:
        variables_to_add = []
    variables = variables_to_save(variables_to_add)
    train_saver = tf.train.Saver(variables, max_to_keep=2)
    best_saver = tf.train.Saver(variables, max_to_keep=1)
    return train_saver, best_saver
def _add_input_output_summary(images, reconstructions):
    """Log an image summary showing input images next to their reconstructions."""
    with tf.variable_scope("visualization"):
        # int() guards against math.floor returning a float (Python 2),
        # which would break the slice bounds below.
        grid_side = int(math.floor(math.sqrt(BATCH_SIZE)))
        num_cells = grid_side**2
        # Move axis 0 (presumably the batch axis) last so put_kernels_on_grid
        # can tile the images; keep only as many as fit in the square grid.
        inputs = put_kernels_on_grid(
            tf.transpose(images, perm=(1, 2, 3, 0))[:, :, :, 0:num_cells],
            grid_side)
        outputs = put_kernels_on_grid(
            tf.transpose(reconstructions,
                         perm=(1, 2, 3, 0))[:, :, :, 0:num_cells],
            grid_side)
    # NOTE(review): ``tf.concat(2, [...])`` is the TF < 1.0 argument order
    # (concat_dim, values). TF >= 1.0 expects ``tf.concat(values, axis)``;
    # confirm against the installed TensorFlow version.
    tf_log(
        tf.summary.image('input_output',
                         tf.concat(2, [inputs, outputs]),
                         max_outputs=1))


def train():
    """Train model, checkpointing every epoch and saving the best model.

    Builds the training graph (inputs, model, loss, optimizer, summaries,
    savers). Unless RESTART is set, it resumes from the latest checkpoint in
    LOG_DIR. It then runs the optimizer from the current global step up to
    MAX_STEPS.

    - Every 10 steps it prints and logs the training loss and summaries.
    - At the end of every epoch, and at the last step, it saves a training
      checkpoint and computes the validation error. When the validation
      error improves, it saves the model to BEST_MODEL_DIR.
    - Training stops early if the loss becomes NaN.

    Returns:
        The best validation error observed during this call (``inf`` if no
        validation was run).
    """
    # NOTE(review): this starts at inf on every call, so after resuming from
    # a checkpoint the first validation always overwrites BEST_MODEL_DIR,
    # even if an earlier run had reached a lower error.
    best_validation_error_value = float('inf')

    with tf.Graph().as_default(), tf.device(TRAIN_DEVICE):
        global_step = tf.Variable(0, trainable=False, name="global_step")

        # Get the (distorted) training images; the second value returned
        # (presumably the labels) is not needed for reconstruction.
        images, _ = DATASET.distorted_inputs(BATCH_SIZE)

        # Build a Graph that computes the reconstructions from the
        # inference model.
        is_training_, reconstructions = MODEL.get(images,
                                                  train_phase=True,
                                                  l2_penalty=L2_PENALTY)

        # Display original images next to reconstructed images.
        _add_input_output_summary(images, reconstructions)

        # Calculate loss.
        loss = MODEL.loss(reconstructions, images)
        # Reconstruction error: the same scalar summary is fed with either the
        # train loss or the validation error and written to the matching log.
        error_ = tf.placeholder(tf.float32, shape=())
        error = tf.summary.scalar('error', error_)

        if LR_DECAY:
            # Decay the learning rate exponentially based on the number of steps.
            learning_rate = tf.train.exponential_decay(INITIAL_LR,
                                                       global_step,
                                                       STEPS_PER_DECAY,
                                                       LR_DECAY_FACTOR,
                                                       staircase=True)
        else:
            learning_rate = tf.constant(INITIAL_LR)

        # NOTE(review): learning_rate is only logged here; it is not passed to
        # OPTIMIZER, so the logged value reflects the real one only if
        # OPTIMIZER was built with the same schedule elsewhere. Confirm.
        tf_log(tf.summary.scalar('learning_rate', learning_rate))
        train_op = OPTIMIZER.minimize(loss, global_step=global_step)

        # Savers for the periodic training checkpoints and for the best model.
        train_saver, best_saver = build_train_savers([global_step])

        # Read the collection only after every op has added its own
        # summaries to it.
        train_summaries = tf.summary.merge(
            tf.get_collection_ref(MODEL_SUMMARIES))

        # Build an initialization operation to run below.
        init = tf.variables_initializer(tf.global_variables() +
                                        tf.local_variables())

        # Start running operations on the Graph.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)

            # Start the queue runners with a coordinator.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # try/finally: the input threads must be stopped and joined even
            # if restoring, training or evaluation raises.
            try:
                if not RESTART:  # continue from the saved checkpoint
                    # Restore the previous session if one exists.
                    checkpoint = tf.train.latest_checkpoint(LOG_DIR)
                    if checkpoint:
                        train_saver.restore(sess, checkpoint)
                    else:
                        print("[I] Unable to restore from checkpoint")

                train_log = tf.summary.FileWriter(
                    os.path.join(LOG_DIR, str(InputType.train)),
                    graph=sess.graph)
                validation_log = tf.summary.FileWriter(
                    os.path.join(LOG_DIR, str(InputType.validation)),
                    graph=sess.graph)
                try:
                    # Restart from the (possibly restored) global step.
                    old_gs = sess.run(global_step)
                    for step in range(old_gs, MAX_STEPS):
                        start_time = time.time()
                        _, loss_value = sess.run(
                            [train_op, loss], feed_dict={is_training_: True})
                        duration = time.time() - start_time

                        if np.isnan(loss_value):
                            print('Model diverged with loss = NaN')
                            break

                        # Update logs every 10 iterations.
                        if step % 10 == 0:
                            # Guard against a zero duration on coarse timers.
                            examples_per_sec = (BATCH_SIZE / duration
                                                if duration > 0 else
                                                float('inf'))
                            sec_per_batch = float(duration)

                            format_str = (
                                '{}: step {}, loss = {:.2f} '
                                '({:.1f} examples/sec; {:.3f} sec/batch)')
                            print(
                                format_str.format(datetime.now(), step,
                                                  loss_value,
                                                  examples_per_sec,
                                                  sec_per_batch))
                            # Log train error and summaries.
                            error_line, summary_line = sess.run(
                                [error, train_summaries],
                                feed_dict={
                                    error_: loss_value,
                                    is_training_: True
                                })
                            train_log.add_summary(error_line,
                                                  global_step=step)
                            train_log.add_summary(summary_line,
                                                  global_step=step)

                        # Save the model checkpoint at the end of every epoch
                        # (and at the last step), then evaluate train and
                        # validation performance.
                        end_of_epoch = (step > 0 and
                                        step % STEPS_PER_EPOCH == 0)
                        if end_of_epoch or (step + 1) == MAX_STEPS:
                            checkpoint_path = os.path.join(
                                LOG_DIR, 'model.ckpt')
                            train_saver.save(sess, checkpoint_path,
                                             global_step=step)

                            # Validation error (evaluate.error is given
                            # LOG_DIR, presumably to load the checkpoint just
                            # written; confirm).
                            validation_error_value = evaluate.error(
                                LOG_DIR,
                                MODEL,
                                DATASET,
                                InputType.validation,
                                device=EVAL_DEVICE)

                            validation_line = sess.run(
                                error,
                                feed_dict={error_: validation_error_value})
                            validation_log.add_summary(validation_line,
                                                       global_step=step)

                            print(
                                '{} ({}): train error = {} validation error = {}'
                                .format(datetime.now(),
                                        int(step / STEPS_PER_EPOCH),
                                        loss_value, validation_error_value))
                            if (validation_error_value <
                                    best_validation_error_value):
                                best_validation_error_value = (
                                    validation_error_value)
                                best_saver.save(
                                    sess,
                                    os.path.join(BEST_MODEL_DIR,
                                                 'model.ckpt'),
                                    global_step=step)
                    # end of for
                finally:
                    validation_log.close()
                    train_log.close()
            finally:
                # When done (or on error), ask the threads to stop and
                # wait for them to finish.
                coord.request_stop()
                coord.join(threads)
    return best_validation_error_value