Ejemplo n.º 1
0
def test_train():
    """Run the full Pix2Pix training loop driven by the module-level ``opt``.

    Builds the dataloader, trainer, iteration counter and visualizer, then for
    each epoch:

    * steps the generator on every ``opt.D_steps_per_G``-th batch and the
      discriminator on every batch;
    * prints/plots losses, displays sample images and saves a 'latest'
      checkpoint whenever the iteration counter asks for it;
    * updates the learning rate, and every ``opt.save_epoch_freq`` epochs
      (and on the final epoch) writes both 'latest' and an epoch-numbered
      checkpoint.
    """
    loader = data.create_dataloader(opt)
    model_trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    vis = Visualizer(opt)

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)

        # enumerate() numbering starts at counter.epoch_iter; the loader
        # itself still yields from its first batch.
        for step, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            # Generator only on every D_steps_per_G-th step; discriminator
            # on every step.
            if step % opt.D_steps_per_G == 0:
                model_trainer.run_generator_one_step(batch)
            model_trainer.run_discriminator_one_step(batch)

            if counter.needs_printing():
                latest_losses = model_trainer.get_latest_losses()
                vis.print_current_errors(epoch, counter.epoch_iter,
                                         latest_losses, counter.time_per_iter)
                vis.plot_current_errors(latest_losses,
                                        counter.total_steps_so_far)

            if counter.needs_displaying():
                panels = [
                    ('input_label', batch['label']),
                    ('synthesized_image', model_trainer.get_latest_generated()),
                    ('real_image', batch['image']),
                ]
                vis.display_current_results(OrderedDict(panels), epoch,
                                            counter.total_steps_so_far)

            if counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, counter.total_steps_so_far))
                model_trainer.save('latest')
                counter.record_current_iter()

        model_trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        periodic_save = epoch % opt.save_epoch_freq == 0
        if periodic_save or epoch == counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, counter.total_steps_so_far))
            model_trainer.save('latest')
            model_trainer.save(epoch)

    print('Training was successfully finished.')
Ejemplo n.º 2
0
# Top-level training-script fragment. `opt` and `dataloader` are defined
# before this point (not visible here). The fragment appears truncated: it
# ends right after fetching `losses`, which is not used in the visible code.

# create trainer for our model
trainer = trainers.create_trainer(opt)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))

# create tool for visualization
# Distributed runs use ApexVisualizer when opt.no_all_gather_outputs is set,
# otherwise AsyncVisualizer; single-process runs use the plain Visualizer.
if opt.distributed:
    if opt.no_all_gather_outputs:
        visualizer = ApexVisualizer(opt)
    else:
        visualizer = AsyncVisualizer(opt)
else:
    visualizer = Visualizer(opt)

for epoch in iter_counter.training_epochs():
    iter_counter.record_epoch_start(epoch)
    # enumerate() numbering starts at iter_counter.epoch_iter; the dataloader
    # itself still yields from its first batch.
    for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()

        # Training
        # train generator (only on every opt.D_steps_per_G-th step)
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator (every step)
        trainer.run_discriminator_one_step(data_i)

        # Visualizations
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
Ejemplo n.º 3
0
def train():
    """Fine-tune the Pix2Pix model for extra epochs at a lowered learning rate.

    Mutates the module-level ``opt`` before anything is built:
    ``opt.niter`` is extended by 20 and ``opt.lr`` is set to 0.00002.

    It then runs the standard loop:

    * generator every ``opt.D_steps_per_G``-th batch, discriminator every
      batch;
    * periodic printing/plotting/display and 'latest' checkpoints;
    * a learning-rate update per epoch;
    * 'latest' plus epoch-numbered checkpoints every ``opt.save_epoch_freq``
      epochs and on the final epoch.

    NOTE(review): an earlier comment said "freeze necessary model layers",
    but nothing in this function freezes parameters. Confirm whether that is
    handled inside Pix2PixTrainer or was never implemented.
    """
    # Adjust the schedule before anything consumes `opt`.
    opt.niter = opt.niter + 20  # 20 more units of opt.niter (original comment: "iterations")
    opt.lr = 0.00002  # original comment: 1/10th of the original lr

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # Build the trainer exactly once, after `opt` has been adjusted; a second,
    # immediately discarded construction only doubles model setup cost.
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # train generator (only on every opt.D_steps_per_G-th step)
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator (every step)
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image',
                                        trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Ejemplo n.º 4
0
def main_worker(gpu, world_size, idx_server, opt):
    """Run NCCL-distributed Pix2Pix training in one GPU process.

    Args:
        gpu: local GPU index on this server. Stored in ``opt.gpu`` and used
            to compute the global rank.
        world_size: number of GPUs per node, kept as ``ngpus_per_node``. The
            global world size given to ``init_process_group`` is taken from
            ``opt.world_size`` instead.
        idx_server: index of this server. The global rank is
            ``idx_server * ngpus_per_node + gpu``.
        opt: option namespace. It is mutated: ``opt.gpu`` is set.

    Only rank 0 writes checkpoints. The process group is destroyed after
    training finishes.
    """
    print('Use GPU: {} for training'.format(gpu))
    ngpus_per_node = world_size
    world_size = opt.world_size
    rank = idx_server * ngpus_per_node + gpu
    opt.gpu = gpu
    dist.init_process_group(backend='nccl',
                            init_method=opt.dist_url,
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(opt.gpu)

    # load the dataset (world_size and rank are passed so the loader can,
    # presumably, hand each rank its own shard)
    dataloader = data.create_dataloader(opt, world_size, rank)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader), world_size, rank)

    # create tool for visualization
    visualizer = Visualizer(opt, rank)

    for epoch in iter_counter.training_epochs():
        # Set the sampler's epoch so that (for a DistributedSampler) the
        # shuffle order differs between epochs.
        dataloader.sampler.set_epoch(epoch)

        iter_counter.record_epoch_start(epoch)

        # Reset each epoch: an epoch that yields no batches must not reuse a
        # stale batch, or hit an unbound name, in the display below.
        data_i = None
        for data_i in dataloader:
            iter_counter.record_one_iteration()

            # The generator is stepped on every batch here;
            # opt.D_steps_per_G is not consulted.
            trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

        # Display the last batch of the epoch (once per epoch).
        if data_i is not None:
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image',
                                    trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
            visualizer.display_current_results(visuals, epoch,
                                               iter_counter.total_steps_so_far)

        # Only rank 0 writes the per-epoch 'latest' checkpoint.
        if rank == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs) and (rank == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save(epoch)

    print('Training was successfully finished.')
    # Release the NCCL communicator and process-group resources.
    dist.destroy_process_group()
Ejemplo n.º 5
0
Archivo: train.py Proyecto: zitkat/PDEL
def main(argv=None):
    """Train a PDELTrainer on the training split of ForcedIsotropicDataset.

    Args:
        argv: command-line arguments without the program name. When ``None``
            (now also the default), ``sys.argv[1:]`` is used.
    """
    if argv is None:
        argv = sys.argv[1:]

    opt = parse_options(argv)
    opt.isTrain = True

    # Split with a fixed generator seed so the same training subset is
    # drawn every run. The (.7, .1, .2) fractions are presumably
    # train/val/test; only the first part is used here.
    dataset = ForcedIsotropicDataset(root_dir=opt.dataset_path)
    split = get_split(len(dataset), (.7, .1, .2))
    data_train, _, _ = torch.utils.data.random_split(
        dataset, lengths=split, generator=torch.Generator().manual_seed(42))

    dataloader = DataLoader(data_train, batch_size=opt.batchSize, shuffle=True)

    trainer = PDELTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        # Each batch is a (time, data) pair.
        for i, (time_i, data_i) in enumerate(dataloader):
            iter_counter.record_one_iteration()

            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            trainer.run_discriminator_one_step(data_i)

            # The learning-rate schedule is advanced once per epoch, after
            # this loop. Calling update_learning_rate() here as well would
            # apply the per-epoch decay once per batch.

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                iter_counter.record_current_errors(epoch,
                                                   iter_counter.epoch_iter,
                                                   losses,
                                                   iter_counter.time_per_iter)

            if iter_counter.needs_displaying():
                # Snapshot only the first sample of the batch.
                visualizer.save_paraview_snapshots(
                    epoch, iter_counter.epoch_iter, time_i[0], data_i[0],
                    trainer.get_latest_generated()[0])

            if iter_counter.needs_saving():
                iter_counter.printlog(
                    'saving the latest model '
                    f'(epoch {epoch}, '
                    f'total_steps {iter_counter.total_steps_so_far})')
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs:
            iter_counter.printlog('saving the model at the end of '
                                  f'epoch {epoch}, '
                                  f'iters {iter_counter.total_steps_so_far}')
            trainer.save('latest')
            trainer.save(epoch)
Ejemplo n.º 6
0
def do_train(opt):
    """Train a Pix2Pix model with the options in ``opt``.

    Per epoch, the generator is stepped on every ``opt.D_steps_per_G``-th
    batch and the discriminator on every batch. Losses are printed and
    plotted, sample images are displayed, and a 'latest' checkpoint is saved
    whenever the iteration counter asks for it. After each epoch the learning
    rate is updated. Every ``opt.save_epoch_freq`` epochs, and on the final
    epoch, both 'latest' and an epoch-numbered checkpoint are written.

    Each batch is a dict. In a captured debug run it held the keys 'label',
    'instance', 'image' and 'path': label and image were 4-D tensors, and
    path was a list of image file names.
    """
    loader = data.create_dataloader(opt)
    model_trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    vis = Visualizer(opt)

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)
        for step, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            update_generator = step % opt.D_steps_per_G == 0
            if update_generator:
                model_trainer.run_generator_one_step(batch)
            model_trainer.run_discriminator_one_step(batch)

            if counter.needs_printing():
                current_losses = model_trainer.get_latest_losses()
                vis.print_current_errors(epoch, counter.epoch_iter,
                                         current_losses,
                                         counter.time_per_iter)
                vis.plot_current_errors(current_losses,
                                        counter.total_steps_so_far)

            if counter.needs_displaying():
                images = OrderedDict()
                images['input_label'] = batch['label']
                images['synthesized_image'] = model_trainer.get_latest_generated()
                images['real_image'] = batch['image']
                vis.display_current_results(images, epoch,
                                            counter.total_steps_so_far)

            if counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, counter.total_steps_so_far))
                model_trainer.save('latest')
                counter.record_current_iter()

        model_trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        is_save_epoch = epoch % opt.save_epoch_freq == 0
        if is_save_epoch or epoch == counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, counter.total_steps_so_far))
            model_trainer.save('latest')
            model_trainer.save(epoch)

    print('Training was successfully finished.')
Ejemplo n.º 7
0
Archivo: train.py Proyecto: sj-li/TSIT
# Top-level training-script fragment (uses `opt`, `data`, `tqdm`, `sys`
# from outside this fragment). It appears truncated: it ends right after
# fetching `losses`, which is not used in the visible code.

# print options to help debugging
# (this echoes the raw command line, sys.argv, rather than the parsed opt)
print(' '.join(sys.argv))

# load the dataset
dataloader = data.create_dataloader(opt)

# create trainer for our model
trainer = Pix2PixTrainer(opt)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))

# create tool for visualization
visualizer = Visualizer(opt)

# Both the epoch iterator and the dataloader are wrapped in tqdm progress bars.
for epoch in tqdm(iter_counter.training_epochs()):
    iter_counter.record_epoch_start(epoch)
    # enumerate() numbering starts at iter_counter.epoch_iter; the dataloader
    # itself still yields from its first batch.
    for i, data_i in enumerate(tqdm(dataloader), start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()

        # Training
        # train generator (only on every opt.D_steps_per_G-th step)
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator (every step)
        trainer.run_discriminator_one_step(data_i)

        # Visualizations
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
Ejemplo n.º 8
0
def do_train(opt):
    """Train a Pix2Pix model, optionally tracking FID on the 'test' split.

    When ``opt.train_eval`` is set, every ``opt.eval_epoch_freq`` epochs the
    model:

    * runs in inference mode over a 'test'-phase dataloader;
    * saves the synthesized images to an HTML results page;
    * computes Inception statistics of those images and their Frechet
      distance to statistics preloaded from
      ``datasets/train_mu_si/<dataset_mode>/{m,s}.npy``;
    * saves a 'best' checkpoint whenever that FID improves.

    Args:
        opt: option namespace. ``opt.no_flip``, ``opt.phase`` and
            ``opt.isTrain`` are temporarily changed while the validation
            loader is built. ``opt.use_vae`` is temporarily disabled during
            evaluation. All are restored afterwards.
    """
    # print options to help debugging
    print(' '.join(sys.argv))

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    # Best FID seen so far; 1000 acts as "no score yet".
    FID_score = 1000
    if opt.train_eval:
        # Build the validation loader with test-time options, then restore
        # the training options.
        original_flip = opt.no_flip
        opt.no_flip = True
        opt.phase = 'test'
        opt.isTrain = False
        dataloader_val = data.create_dataloader(opt)
        val_visualizer = Visualizer(opt)
        # create a webpage that summarizes all the results
        web_dir = os.path.join(opt.results_dir, opt.name,
                               '%s_%s' % (opt.phase, opt.which_epoch))
        webpage = html.HTML(web_dir,
                            'Experiment = %s, Phase = %s, Epoch = %s' %
                            (opt.name, opt.phase, opt.which_epoch))
        opt.phase = 'train'
        opt.isTrain = True
        opt.no_flip = original_flip
        # process for calculating FID scores
        from inception import InceptionV3
        from fid_score import calculate_fid_given_paths
        import pathlib
        # NOTE(review): calculate_activation_statistics and
        # calculate_frechet_distance (used below) are not imported here;
        # confirm they are available at module scope.
        # define the inceptionV3
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[opt.eval_dims]
        eval_model = InceptionV3([block_idx]).cuda()
        # load real images distributions on the training set
        mu_np_root = os.path.join('datasets/train_mu_si', opt.dataset_mode, 'm.npy')
        st_np_root = os.path.join('datasets/train_mu_si', opt.dataset_mode, 's.npy')
        m0, s0 = np.load(mu_np_root), np.load(st_np_root)
        # load previous best FID when resuming
        if opt.continue_train:
            fid_record_dir = os.path.join(opt.checkpoints_dir, opt.name, 'fid.txt')
            if os.path.isfile(fid_record_dir):
                FID_score, _ = np.loadtxt(fid_record_dir, delimiter=',', dtype=float)
            else:
                # A run resumed from a checkpoint without an FID record has
                # no previous best; keep the default instead of crashing.
                print('no FID record found at %s; using FID_score=%s' %
                      (fid_record_dir, FID_score))

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # train generator (only on every opt.D_steps_per_G-th step)
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator (every step)
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                if opt.train_eval:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter, FID_score)
                else:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter(FID_score)

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.eval_epoch_freq == 0 and opt.train_eval:
            # generate fake images for the validation split
            trainer.pix2pix_model.eval()
            print('start evalidation .... ')
            # Evaluation runs with opt.use_vae disabled; remember whether
            # to re-enable it afterwards.
            restore_vae = bool(opt.use_vae)
            if restore_vae:
                opt.use_vae = False
            try:
                for i, data_i in enumerate(dataloader_val):
                    if data_i['label'].size()[0] != opt.batchSize:
                        # A short final batch is padded up to batchSize by
                        # repair_data; one less than half full is skipped.
                        if opt.batchSize > 2*data_i['label'].size()[0]:
                            print('batch size is too large')
                            break
                        data_i = repair_data(data_i, opt.batchSize)
                    generated = trainer.pix2pix_model(data_i, mode='inference')
                    img_path = data_i['path']
                    for b in range(generated.shape[0]):
                        visuals = OrderedDict([('input_label', data_i['label'][b]),
                                               ('synthesized_image', generated[b])])
                        val_visualizer.save_images(webpage, visuals, img_path[b:b + 1])
                webpage.save()
            finally:
                # Always return the model and options to their training
                # configuration, even if evaluation fails.
                trainer.pix2pix_model.train()
                if restore_vae:
                    opt.use_vae = True
            # calculate the FID score of the synthesized images
            fake_path = pathlib.Path(os.path.join(web_dir, 'images/synthesized_image/'))
            files = list(fake_path.glob('*.jpg')) + list(fake_path.glob('*.png'))
            m1, s1 = calculate_activation_statistics(files, eval_model, 1, opt.eval_dims, True, images=None)
            fid_value = calculate_frechet_distance(m0, s0, m1, s1)
            visualizer.print_eval_fids(epoch, fid_value, FID_score)
            # save the best model if necessary
            if fid_value < FID_score:
                FID_score = fid_value
                trainer.save('best')

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Ejemplo n.º 9
0
def run(opt):
    """Train the model described by ``opt``, with periodic FID/metric evaluation.

    The loop is wrapped so that the model is always saved as 'latest' on the
    way out. On Ctrl+C (``KeyboardInterrupt``) or ``SystemExit`` the traceback
    is printed and the function returns. Any other exception is re-raised
    after the model has been saved, so callers and the process exit status
    see the failure.

    Args:
        opt: option namespace for the experiment.
    """
    print("Number of GPUs used: {}".format(torch.cuda.device_count()))
    print("Current Experiment Name: {}".format(opt.name))

    # The dataloader will yield the training samples
    dataloader = data.create_dataloader(opt)

    trainer = TrainerManager(opt)
    inference_manager = InferenceManager(
        num_samples=opt.num_evaluation_samples,
        opt=opt,
        cuda=len(opt.gpu_ids) > 0,
        write_details=False,
        save_images=False)

    # For logging and visualizations
    iter_counter = IterationCounter(opt, len(dataloader))
    visualizer = Visualizer(opt)

    if not opt.debug:
        # We keep a copy of the current source code for each experiment
        copy_src(path_from="./",
                 path_to=os.path.join(opt.checkpoints_dir, opt.name))

    # The `finally` clause guarantees the model is saved however the loop
    # ends (normal completion, Ctrl+C, or an error).
    try:
        for epoch in iter_counter.training_epochs():
            iter_counter.record_epoch_start(epoch)
            for i, data_i in enumerate(dataloader,
                                       start=iter_counter.epoch_iter):

                # Training the generator
                if i % opt.D_steps_per_G == 0:
                    trainer.run_generator_one_step(data_i)

                # Training the discriminator
                trainer.run_discriminator_one_step(data_i)

                iter_counter.record_one_iteration()

                # Logging, plotting and visualizing
                if iter_counter.needs_printing():
                    losses = trainer.get_latest_losses()
                    visualizer.print_current_errors(
                        epoch, iter_counter.epoch_iter, losses,
                        iter_counter.time_per_iter,
                        iter_counter.total_time_so_far,
                        iter_counter.total_steps_so_far)
                    visualizer.plot_current_errors(
                        losses, iter_counter.total_steps_so_far)

                if iter_counter.needs_displaying():
                    logs = trainer.get_logs()
                    visuals = [('input_label', data_i['label']),
                               ('out_train', trainer.get_latest_generated()),
                               ('real_train', data_i['image'])]
                    if opt.guiding_style_image:
                        visuals.append(
                            ('guiding_image', data_i['guiding_image']))
                        visuals.append(
                            ('guiding_input_label', data_i['guiding_label']))

                    if opt.evaluate_val_set:
                        validation_output = inference_validation(
                            trainer.sr_model, inference_manager, opt)
                        visuals += validation_output
                    visuals = OrderedDict(visuals)
                    visualizer.display_current_results(
                        visuals, epoch, iter_counter.total_steps_so_far, logs)

                if iter_counter.needs_saving():
                    print(
                        'Saving the latest model (epoch %d, total_steps %d)' %
                        (epoch, iter_counter.total_steps_so_far))
                    trainer.save('latest')
                    iter_counter.record_current_iter()

                if iter_counter.needs_evaluation():
                    # Evaluate on training set
                    result_train = evaluate_training_set(
                        inference_manager, trainer.sr_model_on_one_gpu,
                        dataloader)
                    info = iter_counter.record_fid(
                        result_train["FID"],
                        split="train",
                        num_samples=opt.num_evaluation_samples)
                    info += os.linesep + iter_counter.record_metrics(
                        result_train, split="train")
                    visualizer.plot_current_errors(
                        result_train,
                        iter_counter.total_steps_so_far,
                        split="train/")

                    if opt.evaluate_val_set:
                        # Evaluate on validation set
                        result_val = evaluate_validation_set(
                            inference_manager, trainer.sr_model_on_one_gpu,
                            opt)
                        info += os.linesep + iter_counter.record_fid(
                            result_val["FID"],
                            split="validation",
                            num_samples=opt.num_evaluation_samples)
                        info += os.linesep + iter_counter.record_metrics(
                            result_val, split="validation")
                        visualizer.plot_current_errors(
                            result_val,
                            iter_counter.total_steps_so_far,
                            split="validation/")

            trainer.update_learning_rate(epoch)
            iter_counter.record_epoch_end()

            if (epoch % opt.save_epoch_freq == 0
                    or epoch == iter_counter.total_epochs):
                print('Saving the model at the end of epoch %d, iters %d' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                trainer.save(epoch)
                iter_counter.record_current_iter()

        print('Training was successfully finished.')
    except (KeyboardInterrupt, SystemExit):
        print("KeyboardInterrupt. Shutting down.")
        print(traceback.format_exc())
    # Other exceptions are deliberately not caught: they propagate (with
    # their traceback) after the `finally` below has saved the model.
    finally:
        print('Saving the model before quitting')
        trainer.save('latest')
        iter_counter.record_current_iter()