import math
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Project-level helpers and hyper-parameters (variables_to_save,
# put_kernels_on_grid, tf_log, InputType, MODEL_SUMMARIES, the evaluate
# module, MODEL, DATASET, and the upper-case constants used below) are
# assumed to be imported/defined elsewhere in this module.


def build_train_savers(variables_to_add=None):
    """Add variables_to_add to the collection of variables to save.
    Returns:
        train_saver: saver to use to log the training model
        best_saver: saver used to save the best model
    """
    # Avoid the mutable-default-argument pitfall: a shared list default
    # would accumulate values across calls.
    if variables_to_add is None:
        variables_to_add = []
    variables = variables_to_save(variables_to_add)
    # Keep the 2 most recent training checkpoints and the single best model.
    train_saver = tf.train.Saver(variables, max_to_keep=2)
    best_saver = tf.train.Saver(variables, max_to_keep=1)
    return train_saver, best_saver
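# NOTE: `variables_to_save` is a project helper that is not shown in this
# file. A minimal sketch of what such a helper could look like (an
# assumption for illustration, not the project's actual implementation):
#
#   def variables_to_save(additional_variables):
#       """Return the trainable variables plus the given extra variables."""
#       return tf.trainable_variables() + additional_variables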

def train():
    """Train the model.
    Returns:
        best validation error. Saves the best model as a side effect.
    """
    best_validation_error_value = float('inf')

    with tf.Graph().as_default(), tf.device(TRAIN_DEVICE):
        global_step = tf.Variable(0, trainable=False, name="global_step")

        # Get images and labels for CIFAR-10.
        images, _ = DATASET.distorted_inputs(BATCH_SIZE)

        # Build a Graph that computes the reconstruction predictions from
        # the inference model.
        is_training_, reconstructions = MODEL.get(
            images, train_phase=True, l2_penalty=L2_PENALTY)

        # Display original images next to reconstructed images.
        with tf.variable_scope("visualization"):
            grid_side = math.floor(math.sqrt(BATCH_SIZE))
            inputs = put_kernels_on_grid(
                tf.transpose(
                    images, perm=(1, 2, 3, 0))[:, :, :, 0:grid_side**2],
                grid_side)
            outputs = put_kernels_on_grid(
                tf.transpose(
                    reconstructions,
                    perm=(1, 2, 3, 0))[:, :, :, 0:grid_side**2],
                grid_side)

        tf_log(
            tf.summary.image(
                'input_output',
                # tf.concat takes the value list first in TF >= 1.0; the
                # old (concat_dim, values) argument order raises an error.
                tf.concat([inputs, outputs], axis=2),
                max_outputs=1))

        # Calculate loss.
        loss = MODEL.loss(reconstructions, images)

        # Reconstruction error: fed at logging time with the measured value.
        error_ = tf.placeholder(tf.float32, shape=())
        error = tf.summary.scalar('error', error_)

        if LR_DECAY:
            # Decay the learning rate exponentially based on the number
            # of steps.
            learning_rate = tf.train.exponential_decay(
                INITIAL_LR,
                global_step,
                STEPS_PER_DECAY,
                LR_DECAY_FACTOR,
                staircase=True)
        else:
            learning_rate = tf.constant(INITIAL_LR)
        tf_log(tf.summary.scalar('learning_rate', learning_rate))
        train_op = OPTIMIZER.minimize(loss, global_step=global_step)

        # Create the train and best-model savers; saving global_step too
        # lets a restored session resume from the right step.
        train_saver, best_saver = build_train_savers([global_step])

        # Read the collection only after every op has added its own
        # summaries to the train_summaries collection.
        train_summaries = tf.summary.merge(
            tf.get_collection_ref(MODEL_SUMMARIES))

        # Build an initialization operation to run below.
        init = tf.variables_initializer(tf.global_variables() +
                                        tf.local_variables())

        # Start running operations on the Graph.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)

            # Start the queue runners with a coordinator.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            if not RESTART:  # continue from the saved checkpoint
                # Restore the previous session if a checkpoint exists.
                checkpoint = tf.train.latest_checkpoint(LOG_DIR)
                if checkpoint:
                    train_saver.restore(sess, checkpoint)
                else:
                    print("[I] Unable to restore from checkpoint")

            train_log = tf.summary.FileWriter(
                os.path.join(LOG_DIR, str(InputType.train)),
                graph=sess.graph)
            validation_log = tf.summary.FileWriter(
                os.path.join(LOG_DIR, str(InputType.validation)),
                graph=sess.graph)

            # Extract the previous global step value.
            old_gs = sess.run(global_step)

            # Restart from where we were.
            for step in range(old_gs, MAX_STEPS):
                start_time = time.time()
                _, loss_value = sess.run(
                    [train_op, loss], feed_dict={is_training_: True})
                duration = time.time() - start_time

                if np.isnan(loss_value):
                    print('Model diverged with loss = NaN')
                    break

                # Update logs every 10 iterations.
                if step % 10 == 0:
                    num_examples_per_step = BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = ('{}: step {}, loss = {:.2f} '
                                  '({:.1f} examples/sec; {:.3f} sec/batch)')
                    print(
                        format_str.format(datetime.now(), step, loss_value,
                                          examples_per_sec, sec_per_batch))

                    # Log the train error and the summaries.
                    train_error_summary_line, train_summary_line = sess.run(
                        [error, train_summaries],
                        feed_dict={
                            error_: loss_value,
                            is_training_: True
                        })
                    train_log.add_summary(
                        train_error_summary_line, global_step=step)
                    train_log.add_summary(
                        train_summary_line, global_step=step)

                # Save the model checkpoint at the end of every epoch and
                # evaluate the train and validation performance.
                if (step > 0 and step % STEPS_PER_EPOCH == 0) or (
                        step + 1) == MAX_STEPS:
                    checkpoint_path = os.path.join(LOG_DIR, 'model.ckpt')
                    train_saver.save(sess, checkpoint_path, global_step=step)

                    # Validation error.
                    validation_error_value = evaluate.error(
                        LOG_DIR,
                        MODEL,
                        DATASET,
                        InputType.validation,
                        device=EVAL_DEVICE)

                    summary_line = sess.run(
                        error, feed_dict={error_: validation_error_value})
                    validation_log.add_summary(summary_line, global_step=step)

                    print('{} ({}): train error = {} validation error = {}'.
                          format(datetime.now(),
                                 int(step / STEPS_PER_EPOCH), loss_value,
                                 validation_error_value))
                    if validation_error_value < best_validation_error_value:
                        best_validation_error_value = validation_error_value
                        best_saver.save(
                            sess,
                            os.path.join(BEST_MODEL_DIR, 'model.ckpt'),
                            global_step=step)
            # end of for

            validation_log.close()
            train_log.close()

            # When done, ask the threads to stop.
            coord.request_stop()
            # Wait for threads to finish.
            coord.join(threads)

        return best_validation_error_value
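

if __name__ == '__main__':
    # Minimal usage sketch (an assumption for illustration: the module's
    # hyper-parameter constants and the MODEL/DATASET objects are assumed
    # to be configured elsewhere in the project before this runs).
    best_error = train()
    print('Best validation error: {}'.format(best_error))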