    def __init__(self, model_path):
        # Initialize TensorFlow (sets up the default graph and session).
        tfutil.init_tf()

        print('[MD] Loading InsightFace model...')
        with tf.io.gfile.GFile(model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def,
                            input_map=None,
                            return_elements=None,
                            name="")

        self.image_input = tf.get_default_graph().get_tensor_by_name('data:0')
        embedding = tf.get_default_graph().get_tensor_by_name('output:0')
        embedding_norm = tf.norm(embedding, axis=1, keepdims=True)
        self.embedding = tf.div(embedding,
                                embedding_norm,
                                name='norm_embedding')

        self.target_emb = tf.placeholder(tf.float32,
                                         shape=[None, 512],
                                         name='target_emb_input')
        # Cosine similarity (summed over the batch) and L2 distance between
        # the current embedding and the target embedding.
        self.cos_loss = tf.reduce_sum(
            tf.multiply(self.embedding, self.target_emb))
        self.l2_loss = tf.norm(self.embedding - self.target_emb)
        # Gradient of the L2 loss w.r.t. the input image.
        self.grads_op = tf.gradients(self.l2_loss, self.image_input)

        # self.fdict = {keep_prob:1.0, is_train:False}
        self.fdict = {}
        self.sess = tf.get_default_session()
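
A minimal usage sketch for the loader above, assuming it lives in a wrapper class (FaceEmbedder is a hypothetical name) and that tfutil.init_tf() installs a default session; the NCHW 112x112 input layout is the usual InsightFace convention, not something this snippet confirms:

# Hypothetical usage; FaceEmbedder, the model path, and the input layout
# are assumptions, not part of the original snippet.
import numpy as np

model = FaceEmbedder('./models/insightface.pb')
faces = np.zeros((1, 3, 112, 112), dtype=np.float32)  # stand-in batch
emb = model.sess.run(model.embedding,
                     feed_dict={model.image_input: faces, **model.fdict})
print(emb.shape)  # (1, 512), one L2-normalized row per face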
Example no. 2
    def infer_stack(args):
        tmp = Image.open(args.stack)
        h,w = np.shape(tmp)
        N = tmp.n_frames

        imgs = np.zeros((N, 3, h, w))
        for i in range(N):
            tmp.seek(i)
            # Replicate the single grayscale frame across all three channels.
            imgs[i, 0, :, :] = np.array(tmp)
            imgs[i, 1, :, :] = np.array(tmp)
            imgs[i, 2, :, :] = np.array(tmp)
        imgs = imgs.astype("float32")
        imgs = imgs / 255.0 - 0.5

        tfutil.init_tf(tf_config)
        net = util.load_snapshot(args.network)

        res = np.empty((N, h, w), dtype="uint16")
        for i in range(N):
            res[i,:,:] = util.infer_image_pp(net, imgs[i,:,:,:])

        #tmp = Image.fromarray(res[0,:,:,:].transpose([1,2,0]).astype("uint8"))
        tmp = Image.fromarray(res[0,:,:])
        tmp.save(args.out, format="tiff",
                 append_images=[Image.fromarray(res[i,:,:]) for i in range(1, res.shape[0])],
                 save_all=True)
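
A hedged driver sketch for infer_stack above, assuming it is module-level; the flag names simply mirror the attributes the function reads (args.stack, args.network, args.out):

# Hypothetical CLI wrapper; flag names mirror what infer_stack reads.
import argparse

parser = argparse.ArgumentParser(description='Denoise a multi-frame TIFF stack.')
parser.add_argument('--stack', required=True, help='input multi-frame TIFF')
parser.add_argument('--network', required=True, help='trained network snapshot')
parser.add_argument('--out', required=True, help='output TIFF path')
infer_stack(parser.parse_args())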
Example no. 3
def init_tf(random_seed=1234):
    """Initialize TF."""
    print('Initializing TensorFlow...\n')
    np.random.seed(random_seed)
    tfutil.init_tf({
        'graph_options.place_pruned_graph': True,
        'gpu_options.allow_growth': True
    })
Example no. 4
def infer_image(network_snapshot: str, image: str, out_image: str):
    tfutil.init_tf(config.tf_config)
    net = util.load_snapshot(network_snapshot)
    im = PIL.Image.open(image).convert('RGB')
    arr = np.array(im, dtype=np.float32)
    reshaped = arr.transpose([2, 0, 1]) / 255.0 - 0.5
    pred255 = util.infer_image(net, reshaped)
    t = pred255.transpose([1, 2, 0])  # [RGB, H, W] -> [H, W, RGB]
    PIL.Image.fromarray(t, 'RGB').save(out_image)
    print('Inferred image saved in', out_image)
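
Called directly, the helper needs only the snapshot and image paths (the paths below are placeholders):

# Placeholder paths; the snapshot format is whatever util.load_snapshot expects.
infer_image('results/network_final.pkl', 'noisy.png', 'denoised.png')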
Example no. 5
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict, network_snapshot: str):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**dataset)

    ctx = dnnlib.RunContext(submit_config, config)

    tfutil.init_tf(config.tf_config)

    with tf.device("/gpu:0"):
        net = util.load_snapshot(network_snapshot)
        validation_set.evaluate(net, 0, noise_augmenter.add_validation_noise_np)
    ctx.close()
Example no. 6
def validate(submit_config: submit.SubmitConfig, tf_config: dict, noise: dict,
             dataset: dict, network_snapshot: str):
    noise_augmenter = noise.func(**noise.func_kwargs)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**dataset)

    ctx = RunContext(submit_config)

    tfutil.init_tf(tf_config)

    with tf.device("/gpu:0"):
        net = util.load_snapshot(network_snapshot)
        validation_set.evaluate(net, 0,
                                noise_augmenter.add_validation_noise_np)
    ctx.close()
Example no. 7
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict,
             network_snapshot: str):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**dataset)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = load_snapshot(network_snapshot)
        validation_set.evaluate(net, 0,
                                noise_augmenter.add_validation_noise_np)
    ctx.close()
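
The noise and dataset dicts that these validate() variants unpack are func_name-style configs consumed by dnnlib.util.call_func_by_name and ValidationSet.load; a plausible shape, with illustrative keys only:

# Illustrative configs; the exact keys depend on the surrounding codebase.
noise = dict(func_name='train.AugmentGaussian',
             train_stddev_rng_range=(0.0, 50.0),
             validation_stddev=25)
dataset = dict(dataset_dir='datasets/kodak')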
Example no. 8
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(
                    tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/eta_sec',
                                   (time_train / eval_interval) *
                                   (iteration_count - i))),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
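
In dnnlib-based repos a train() like this is typically launched through dnnlib.submit_run; a hedged sketch, where run_dir_root, the 'train.train' function path, and all hyperparameter values are assumptions:

# Hedged launch sketch; the function path and values are assumptions.
submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'
dnnlib.submit_run(submit_config, 'train.train',
                  iteration_count=300000, eval_interval=1000,
                  minibatch_size=4, learning_rate=0.0003,
                  ramp_down_perc=0.3, noise=noise,
                  validation_config=dict(dataset_dir='datasets/kodak'),
                  train_tfrecords='datasets/train.tfrecords',
                  noise2noise=True)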
Example no. 9
def train(
        submit_config: submit.SubmitConfig,
        iteration_count: int,
        eval_interval: int,
        minibatch_size: int,
        learning_rate: float,
        ramp_down_perc: float,
        noise: dict,
        tf_config: dict,
        net_config: dict,
        optimizer_config: dict,
        validation_config: dict,
        train_tfrecords: str):

    # **dict at a call site unpacks the dictionary's entries and passes
    # them to the called function as named arguments.
    noise_augmenter = noise.func(**noise.func_kwargs)
    validation_set = ValidationSet(submit_config)
    # Load all images for validation as numpy arrays to the images attribute of the validation set.
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = run_context.RunContext(submit_config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(tf_config)

    # Creates the data set from the specified path to a generated tfrecords file containing all training images.
    # Data set will be split into minibatches of the given size and augment the noise with given noise function.
    # Use the dataset_tool_tf to create this tfrecords file.
    dataset_iter = create_dataset(train_tfrecords, minibatch_size, noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = Network(**net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        # Placeholder for the learning rate. This will get ramped down dynamically.
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        # Defines the expression(s) that creates the network input.
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)  # Split over multiple GPUs
        # clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # --------------------------------------------------------------------------------------------
    # Optimizer initialization and setup:

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = Optimizer(learning_rate=lrate_in, **optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            # Create a clone for this network for other gpus to work on.
            net_gpu = net if gpu == 0 else net.clone()

            # Create the output expression by giving the input expression into the network.
            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            # Create the error function as the MSE between the target tensor and the denoised network output.
            meansq_error = tf.reduce_mean(tf.square(noisy_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()  # Defines the update function of the optimizer.

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.results_dir)
    summary_log.add_graph(tf.get_default_graph())

    # --------------------------------------------------------------------------------------------
    # Training and some milestone evaluation starts:

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update()  # TODO: why parameterized in reference?

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate to draw one minibatch of inputs; this executes the operations defined in the dataset iterator.
            # Evaluates the noisy input and clean target tensor ops to numpy arrays for the minibatch.
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            # Runs the noisy images through the network without training it. It is just for observing/evaluating.
            # net.run expects numpy arrays to run through this network.
            denoised = net.run(source_mb)
            # array shape: [minibatch_size, channel_size, height, width]
            util.save_image(submit_config, denoised[0], "img_{0}_y_pred.png".format(i))
            util.save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            util.save_image(submit_config, source_mb[0], "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)

            print('iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                autosummary('Timing/iter', i),
                dnnlib.util.format_time(autosummary('Timing/total_sec', time_total)),
                autosummary('Timing/sec_per_eval', time_train),
                autosummary('Timing/sec_per_iter', time_train / eval_interval),
                autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update()
            time_maintenance = ctx.get_last_update_interval() - time_train

        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate)
        # Apply the lrate value to the lrate_in placeholder for the optimizer.
        tfutil.run([train_step], {lrate_in: lrate})  # Run the training update through the network in our session.

    print("Elapsed time: {0}".format(dutil.format_time(ctx.get_time_since_start())))
    util.save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
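
compute_ramped_down_lrate is called in the training loops above but never shown; in the NVIDIA noise2noise reference it is a smooth cosine ramp over the final ramp_down_perc of training, roughly:

# Sketch of the ramp-down helper, after the noise2noise reference code.
import numpy as np

def compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate):
    ramp_down_start_iter = iteration_count * (1 - ramp_down_perc)
    if i >= ramp_down_start_iter:
        t = ((i - ramp_down_start_iter) / ramp_down_perc) / iteration_count
        smooth = (0.5 + np.cos(t * np.pi) / 2) ** 2  # eases from 1 to 0
        return learning_rate * smooth
    return learning_rate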
Example no. 10
    import moviepy.editor  # pip install moviepy
    moviepy.editor.VideoClip(make_frame,
                             duration=duration_sec).write_videofile(
                                 os.path.join(result_subdir, mp4),
                                 fps=mp4_fps,
                                 codec='png',
                                 bitrate=mp4_bitrate)
    with open(os.path.join(result_subdir, mp4 + '-keyframes.txt'),
              'w') as file:
        file.write(str(latents_idx))


## -----------------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    import datetime
    import time
    print(datetime.datetime.now(), int(time.time()))
    np.random.seed(int(time.time()))
    tfutil.init_tf()
    generate_fake_images(0, num_pngs=3500)
    #generate_interpolation_video(12, grid_size=[1,1], random_seed=int(time.time()), mp4_fps=25, duration_sec=300.0)
    keyframes = [
        75, 125, 300, 375, 450, 550, 600, 700, 775, 800, 850, 900, 925, 975,
        1075, 1200, 1350, 1450, 1575, 1700, 1800, 1900, 2050, 2150, 2250, 2300,
        2375, 2475, 2800, 2875, 3000, 3050, 3125, 3225, 3300, 3400
    ]
    generate_keyframed_video(0, keyframes, mp4_fps=30)
    print('Exiting...')
    print(datetime.datetime.now())
Example no. 11
def main():

    starttime = int(time.time())
    parser = argparse.ArgumentParser(
        description='Render from StyleGAN2 saved models (pkl files).',
        #epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--network_pkl',
        help='The pkl file to render from (the model checkpoint).',
        default=None,
        metavar='MODEL.pkl',
        required=True)
    parser.add_argument(
        '--grid_x',
        help=
        'Number of images to render horizontally (each frame will have rows of X images, default: 1).',
        default=1,
        metavar='X',
        type=int)
    parser.add_argument(
        '--grid_y',
        help=
        'Number of images to render vertically (each frame will have cols of Y images, default: 1).',
        default=1,
        metavar='Y',
        type=int)
    parser.add_argument(
        '--png_sequence',
        help='If set, outputs a folder of png frames instead of a video (default: off).',
        action='store_true')
    parser.add_argument(
        '--image_shrink',
        help=
        'Render in 1/[image_shrink] resolution (fast, useful for quick previews)',
        default=1,
        type=int)
    parser.add_argument(
        '--image_zoom',
        help=
        'Zoom on the output image (seems like just more video pixels, but no true upscaling)',
        default=1,
        type=float)
    parser.add_argument('--duration_sec',
                        help='Length of video to render in seconds.',
                        default=30.0,
                        type=float)
    parser.add_argument('--mp4_fps',
                        help='Frames per second for video rendering',
                        default=30,
                        type=float)
    parser.add_argument(
        '--smoothing_sec',
        help=
        'Gaussian kernel size in seconds to blend video frames (higher value = less change, lower value = more erratic, default: 1.0)',
        default=1.0,
        type=float)
    parser.add_argument(
        '--truncation_psi',
        help=
        'Truncation parameter (1 = normal, lower values overfit to look more like originals, higher values underfit to be more abstract, recommendation: 0.5-2)',
        default=1,
        type=float)
    parser.add_argument('--randomize_noise',
                        help='If set, adds noise to vary rendered images.',
                        action='store_true')
    parser.add_argument(
        '--filename',
        help='Filename for rendering output, defaults to pkl filename',
        default=None)
    parser.add_argument(
        '--mp4_codec',
        help='Video codec to use with moviepy (i.e. libx264, libx265, mpeg4)',
        default='libx265')
    parser.add_argument('--mp4_bitrate',
                        help='Bitrate to use with moviepy (i.e. 16M)',
                        default='16M')
    parser.add_argument('--random_seed',
                        help='Seed to initialize the latent generation.',
                        default=starttime,
                        type=int)
    parser.add_argument(
        '--minibatch_size',
        help=
        'Size of batch rendering (doesn\'t seem to have effects but left in anyway)',
        default=8,
        type=int)

    args = parser.parse_args()

    tfutil.init_tf()

    generate_interpolation_video(network_pkl=args.network_pkl,
                                 grid_size=[args.grid_x, args.grid_y],
                                 png_sequence=args.png_sequence,
                                 image_shrink=args.image_shrink,
                                 image_zoom=args.image_zoom,
                                 duration_sec=args.duration_sec,
                                 smoothing_sec=args.smoothing_sec,
                                 truncation_psi=args.truncation_psi,
                                 randomize_noise=args.randomize_noise,
                                 filename=args.filename,
                                 mp4_fps=args.mp4_fps,
                                 mp4_codec=args.mp4_codec,
                                 mp4_bitrate=args.mp4_bitrate,
                                 random_seed=args.random_seed,
                                 minibatch_size=args.minibatch_size)
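
A typical invocation, assuming the script above is saved as render_video.py (the file and pkl names are placeholders):

# python render_video.py --network_pkl network-snapshot.pkl \
#     --grid_x 2 --grid_y 2 --duration_sec 10.0 --mp4_fps 30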
Example no. 12
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)
    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32,
                                            name='lrate_in',
                                            shape=[])

        #print("DEBUG train:", "dataset iter got called")
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)
    radii = np.arange(128).reshape(128, 1)  # image size 256, binning = 3
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128,
                                       np.arange(0, 256), np.arange(0, 256),
                                       20)
    print("radial masks shape:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)  # (128, 1, 256, 256)
    #radial_masks = np.squeeze(np.stack((radial_masks,) * 3, -1)) # 43, 3, 256, 256
    #radial_masks = radial_masks.transpose([0, 3, 1, 2])
    radial_masks = radial_masks.astype(np.complex64)
    radial_masks = tf.expand_dims(radial_masks, 1)

    rn = tf.compat.v1.placeholder_with_default(radial_masks,
                                               [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)
    freq_nyq = int(np.floor(int(256) / 2.0))

    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / max(spatial_freq)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
            print(noisy_input_split[gpu].get_shape(),
                  rn_split[gpu].get_shape())
            if noise2noise:
                meansq_error = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu],
                    spatial_freq) - fourier_ring_correlation(
                        noisy_target_split[gpu] - denoised_2,
                        noisy_input_split[gpu] - denoised_1, rn_split[gpu],
                        spatial_freq)
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised_1))
            # Create an autosummary that will average over all GPUs
            #tf.summary.histogram(name, var)
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break
        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", noisy_input.dtype, noisy_input[0][0].dtype)
            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

            save_snapshot(submit_config, net, str(i))
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
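
The radial_mask helper that builds the ring masks above is not shown; given the call radial_mask(r, cx, cy, x, y, delta) via np.apply_along_axis, a plausible implementation is a boolean annulus per radius (a sketch under that assumption):

# Hedged sketch of radial_mask; r arrives as a length-1 array from
# np.apply_along_axis, and delta sets the ring width.
import numpy as np

def radial_mask(r, cx, cy, x, y, delta):
    dist2 = (x[np.newaxis, :] - cx) ** 2 + (y[:, np.newaxis] - cy) ** 2
    outer = dist2 <= (r[0] + delta) ** 2
    inner = dist2 > r[0] ** 2
    return outer & inner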