def train(sv,
          sess,
          data,
          max_steps,
          display_fetches,
          display_fetches_test,
          dataTest,
          saver,
          loss,
          output_dir=a.output_dir):
    try:
        # training
        start_time = time.time()
        sess.run(data.iterator.initializer)

        #For as many steps as required
        for step in range(max_steps):
            options = None
            run_metadata = None
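            # On trace steps, record full runtime metadata so the step can be
            # profiled in TensorBoard.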
            if helpers.should(a.trace_freq, max_steps, step):
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            # Define the fetches TensorFlow evaluates at every training step.
            fetches = {
                "train": loss.trainOp,
                "global_step": sv.global_step,
            }

            # Also fetch the loss value when progress will be reported.
            if helpers.should(a.progress_freq, max_steps, step) or step <= 1:
                fetches["loss_value"] = loss.lossValue

            # Also fetch the summaries on summary steps.
            if helpers.should(a.summary_freq, max_steps, step):
                fetches["summary"] = sv.summary_op

            try:
                currentLrValue = a.lr
                if a.checkpoint is None and step < 2000:
                    # Ramp up linearly to a.lr over the first 2000 iterations
                    # so the first (noisy) gradients do not have too much
                    # impact.
                    currentLrValue = step * 0.0005 * a.lr
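                    # For example: step 0 -> 0, step 1000 -> 0.5 * a.lr,
                    # step 1999 -> ~a.lr; from step 2000 on, a.lr is used
                    # unchanged.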

                #Run the network
                results = sess.run(fetches,
                                   feed_dict={loss.lr: currentLrValue},
                                   options=options,
                                   run_metadata=run_metadata)
            except tf.errors.OutOfRangeError:
                print("Training failed with an OutOfRangeError, "
                      "probably a problem with the iterator.")
                continue

            #Get the current global step from the network results
            global_step = results["global_step"]

            if helpers.should(a.summary_freq, max_steps, step):
                # Add the summaries (rendering results) to TensorBoard if this
                # is a summary step.
                sv.summary_writer.add_summary(results["summary"], global_step)

            if helpers.should(a.trace_freq, max_steps, step):
                print("recording trace")
                sv.summary_writer.add_run_metadata(run_metadata,
                                                   "step_%d" % global_step)

            if helpers.should(a.progress_freq, max_steps, step):
                #Print information about the training
                train_epoch = math.ceil(
                    global_step / data.stepsPerEpoch
                )  # global_step will have the correct step count even if we resume from a checkpoint
                train_step = global_step - (train_epoch -
                                            1) * data.stepsPerEpoch
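                # Example: with stepsPerEpoch = 1000 and global_step = 2500,
                # this reports epoch 3, step 500.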
                imagesPerSecond = global_step * a.batch_size / (time.time() -
                                                                start_time)
                remainingMinutes = ((max_steps - global_step) *
                                    a.batch_size) / (imagesPerSecond * 60)
                print("progress  epoch %d  step %d  image/sec %0.1f" %
                      (train_epoch, train_step, imagesPerSecond))
                print("Remaining %0.1f minutes" % (remainingMinutes))
                print("loss_value", results["loss_value"])

            if helpers.should(a.save_freq, max_steps, step):
                # Save the model at the current step.
                print("saving model")
                saver.save(sess,
                           os.path.join(output_dir, "model"),
                           global_step=sv.global_step)

            if helpers.should(a.test_freq, max_steps,
                              step) or global_step == 1:
                #Run the test set against the currently training network.
                outputTestDir = os.path.join(a.output_dir, str(global_step))
                test(sess, dataTest, max_steps, display_fetches_test,
                     outputTestDir)
            if sv.should_stop():
                break
    finally:
        #Save everything and run one last test.
        saver.save(sess,
                   os.path.join(output_dir, "model"),
                   global_step=sv.global_step)
        sess.run(data.iterator.initializer)
        outputTestDir = os.path.join(a.output_dir, "final")
        test(sess, dataTest, max_steps, display_fetches_test, outputTestDir)
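
# `helpers.should` is used throughout the loop above to decide when to trace,
# summarize, print progress, save and test. A minimal sketch of the assumed
# behaviour (the usual frequency check from pix2pix-style training loops; not
# necessarily the exact implementation in helpers.py): trigger every `freq`
# steps and always on the last step.
def _should_sketch(freq, max_steps, step):
    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)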

# Variant of train(): this version also fetches the display outputs each step,
# uses a shorter (500-step) learning-rate warm-up, and wraps checkpoint saving
# and the periodic test in try/except so a failure there does not abort
# training.
def train(sv, sess, data, max_steps, display_fetches, display_fetches_test, dataTest, saver, loss, output_dir=a.output_dir):
    sess.run(data.iterator.initializer)
    try:
        # training
        start_time = time.time()

        for step in range(max_steps):
            options = None
            run_metadata = None
            if helpers.should(a.trace_freq, max_steps, step):
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": loss.trainOp,
                "global_step": sv.global_step,
            }

            if helpers.should(a.progress_freq, max_steps, step) or step <= 1:
                fetches["loss_value"] = loss.lossValue

            if helpers.should(a.summary_freq, max_steps, step):
                fetches["summary"] = sv.summary_op

            fetches["display"] = display_fetches
            try:
                currentLrValue = a.lr
                if a.checkpoint is None and step < 500:
                    # Ramp up linearly to a.lr over the first 500 iterations so the
                    # first (noisy) gradients do not have too much impact.
                    currentLrValue = step * 0.002 * a.lr

                results = sess.run(fetches, feed_dict={loss.lr: currentLrValue}, options=options, run_metadata=run_metadata)
            except tf.errors.OutOfRangeError:
                print("Training failed with an OutOfRangeError, probably a problem with the iterator.")
                continue

            global_step = results["global_step"]
            
            #helpers.saveInputs(a.output_dir, results["display"], step)

            if helpers.should(a.summary_freq, max_steps, step):
                sv.summary_writer.add_summary(results["summary"], global_step)

            if helpers.should(a.trace_freq, max_steps, step):
                print("recording trace")
                sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % global_step)

            if helpers.should(a.progress_freq, max_steps, step):
                # global_step will have the correct step count if we resume from a checkpoint
                train_epoch = math.ceil(global_step / data.stepsPerEpoch)
                train_step = global_step - (train_epoch - 1) * data.stepsPerEpoch
                imagesPerSecond = global_step * a.batch_size / (time.time() - start_time)
                remainingMinutes = ((max_steps - global_step) * a.batch_size)/(imagesPerSecond * 60)
                print("progress  epoch %d  step %d  image/sec %0.1f" % (train_epoch, global_step, imagesPerSecond))
                print("Remaining %0.1f minutes" % (remainingMinutes))
                print("loss_value", results["loss_value"])

            if helpers.should(a.save_freq, max_steps, step):
                print("saving model")
                try:
                    saver.save(sess, os.path.join(output_dir, "model"), global_step=sv.global_step)
                except Exception as e:
                    print("Didn't manage to save model (trainining continues): " + str(e))

            if helpers.should(a.test_freq, max_steps, step) or global_step == 1:
                outputTestDir = os.path.join(a.output_dir, str(global_step))
                try:
                    test(sess, dataTest, max_steps, display_fetches_test, outputTestDir)
                except Exception as e:
                    print("Didn't manage to do a recurrent test (trainining continues): " + str(e))

            if sv.should_stop():
                break
    finally:
        saver.save(sess, os.path.join(output_dir, "model"), global_step=sv.global_step)  # Does the saver still save everything?
        sess.run(data.iterator.initializer)
        outputTestDir = os.path.join(a.output_dir, "final")
        test(sess, dataTest, max_steps, display_fetches_test, outputTestDir)
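
# A minimal, hypothetical call site for train(), assuming a TF1-style setup
# similar to pix2pix-derived code. `loadData`, `loadTestData`, `createModel`
# and the attributes accessed on `model` are placeholder names, not this
# repository's actual API; the flags object `a` is assumed to exist already.
#
#     data = loadData(a)           # exposes .iterator and .stepsPerEpoch
#     dataTest = loadTestData(a)
#     model = createModel(data)    # exposes display fetches and a loss object
#                                  # with .trainOp, .lossValue and .lr
#     saver = tf.train.Saver(max_to_keep=1)
#     sv = tf.train.Supervisor(logdir=a.output_dir, save_summaries_secs=0, saver=None)
#     with sv.managed_session() as sess:
#         train(sv, sess, data, a.max_steps, model.display_fetches,
#               model.display_fetches_test, dataTest, saver, model.loss)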