Example #1
0
def test_train():
    """Run the full Pix2Pix training loop configured by the module-level ``opt``.

    Builds the dataloader, trainer, iteration counter and visualizer from
    ``opt``. Then, for every epoch yielded by the counter:

    * trains the generator once every ``opt.D_steps_per_G`` batches and the
      discriminator on every batch;
    * prints/plots losses, displays visuals and saves a 'latest' checkpoint
      whenever the counter requests it;
    * updates the learning rate at the end of the epoch and, every
      ``opt.save_epoch_freq`` epochs or on the last epoch, saves both the
      'latest' and the epoch-numbered checkpoints.
    """
    loader = data.create_dataloader(opt)
    trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    vis = Visualizer(opt)

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)

        # Batch indices start at counter.epoch_iter rather than 0; the
        # generator schedule below is keyed on that index.
        for step_idx, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            # Generator: one update every D_steps_per_G batches.
            # Discriminator: one update on every batch.
            if step_idx % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(batch)
            trainer.run_discriminator_one_step(batch)

            if counter.needs_printing():
                latest = trainer.get_latest_losses()
                vis.print_current_errors(epoch, counter.epoch_iter, latest,
                                         counter.time_per_iter)
                vis.plot_current_errors(latest, counter.total_steps_so_far)

            if counter.needs_displaying():
                pairs = [
                    ('input_label', batch['label']),
                    ('synthesized_image', trainer.get_latest_generated()),
                    ('real_image', batch['image']),
                ]
                vis.display_current_results(OrderedDict(pairs), epoch,
                                            counter.total_steps_so_far)

            if counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)'
                      % (epoch, counter.total_steps_so_far))
                trainer.save('latest')
                counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == counter.total_epochs):
            print('saving the model at the end of epoch %d, iters %d'
                  % (epoch, counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Example #2
0
        # NOTE(review): fragment of a per-batch training loop; the enclosing
        # `for` headers that bind `epoch`, `i` and `data_i` are not shown here.
        # Training
        # train generator: one generator step every opt.D_steps_per_G batches
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator: one step on every batch
        trainer.run_discriminator_one_step(data_i)

        # Visualizations: print and plot the latest losses when the counter asks.
        # NOTE(review): unlike checkpoint saving below, printing and displaying
        # are not gated on opt.local_rank, so every rank runs them — confirm
        # that is intended.
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter)
            visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

        # Hand the input label, generated image and real image (in that order)
        # to the visualizer; this variant also passes the within-epoch iteration.
        if iter_counter.needs_displaying():
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image', trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
            visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far,
                                               iter_counter.epoch_iter)

        # Save the 'latest' checkpoint only on the process with local_rank 0.
        # NOTE(review): record_current_iter() is also skipped on other ranks,
        # so their counter state may differ from rank 0 — confirm intended.
        if iter_counter.needs_saving():
            if opt.local_rank == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

    # End of epoch (outside the batch loop): step the learning-rate schedule
    # and close out the epoch in the counter.
    trainer.update_learning_rate(epoch)
    iter_counter.record_epoch_end()
Example #3
0
def train():
    """Continue training the Pix2Pix model with a longer schedule and lower lr.

    Mutates the module-level ``opt`` before anything is built from it:
    ``opt.niter`` is increased by 20 and ``opt.lr`` is set to 2e-5. Then it
    runs the standard loop:

    * generator step every ``opt.D_steps_per_G`` batches;
    * discriminator step on every batch;
    * periodic loss printing/plotting, visual display and 'latest'
      checkpoints as requested by the IterationCounter;
    * at each epoch end, an lr update, plus 'latest' and epoch-numbered
      checkpoints every ``opt.save_epoch_freq`` epochs and on the final epoch.
    """
    # Adjust the schedule for this run. These mutate the shared ``opt``
    # object, so everything constructed from it below sees the new values.
    # NOTE(review): the original comment described this as "20 more
    # iterations"; the unit of opt.niter is defined by IterationCounter
    # (presumably epochs) — confirm.
    opt.niter = opt.niter + 20
    opt.lr = 0.00002  # original note: 1/10th of the original lr
    # NOTE(review): an earlier comment mentioned freezing model layers, but
    # nothing in this function freezes any parameters.

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # Build the trainer exactly once. It was previously constructed twice with
    # the first instance thrown away, paying for a full model build (and its
    # memory) for nothing.
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # train generator every D_steps_per_G batches
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator on every batch
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image',
                                        trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)
Example #4
0
File: train.py Project: zitkat/PDEL
def main(argv=None):
    """Train a PDEL model on the training portion of ForcedIsotropicDataset.

    Args:
        argv: command-line arguments, excluding the program name. Defaults to
            None, in which case ``sys.argv[1:]`` is used.
    """
    if argv is None:
        argv = sys.argv[1:]

    opt = parse_options(argv)
    opt.isTrain = True

    # Split with a fixed-seed generator so the training subset is identical on
    # every run; only the first part of the split is used here.
    # NOTE(review): get_split presumably converts the (.7, .1, .2) fractions
    # into train/val/test lengths — confirm.
    dataset = ForcedIsotropicDataset(root_dir=opt.dataset_path)
    split = get_split(len(dataset), (.7, .1, .2))
    data_train, _, _ = torch.utils.data.random_split(
        dataset, lengths=split, generator=torch.Generator().manual_seed(42))

    dataloader = DataLoader(data_train, batch_size=opt.batchSize, shuffle=True)

    trainer = PDELTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        # Each batch is a (time_i, data_i) pair; element 0 of each is used for
        # the snapshot below.
        for i, (time_i, data_i) in enumerate(dataloader):
            iter_counter.record_one_iteration()

            # One generator step every opt.D_steps_per_G batches,
            # one discriminator step on every batch.
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            trainer.run_discriminator_one_step(data_i)

            # The learning rate is updated once per epoch, after this loop.
            # It is deliberately NOT updated per batch: with an incremental
            # decay schedule, that would apply the per-epoch step on every
            # iteration.

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                iter_counter.record_current_errors(epoch,
                                                   iter_counter.epoch_iter,
                                                   losses,
                                                   iter_counter.time_per_iter)

            if iter_counter.needs_displaying():
                visualizer.save_paraview_snapshots(
                    epoch, iter_counter.epoch_iter, time_i[0], data_i[0],
                    trainer.get_latest_generated()[0])

            if iter_counter.needs_saving():
                iter_counter.printlog(
                    'saving the latest model '
                    f'(epoch {epoch}, '
                    f'total_steps {iter_counter.total_steps_so_far})')
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs:
            iter_counter.printlog('saving the model at the end of '
                                  f'epoch {epoch}, '
                                  f'iters {iter_counter.total_steps_so_far}')
            trainer.save('latest')
            trainer.save(epoch)
Example #5
0
        # NOTE(review): fragment of a rank-aware per-batch training loop; the
        # enclosing `for` headers and the definition of `local_rank` are not
        # shown here.
        # train generator: one step every opt.D_steps_per_G batches
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator: every batch, on every rank
        trainer.run_discriminator_one_step(data_i)

        # Visualizations. Logging, display and saving happen only when
        # local_rank == 0. Because `and` short-circuits, the needs_*() checks
        # are not evaluated at all on other ranks.
        # NOTE(review): if needs_*() update counter state, non-zero ranks skip
        # that update — confirm intended.
        if local_rank == 0 and iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter)
            visualizer.plot_current_errors(losses,
                                           iter_counter.total_steps_so_far)

        if local_rank == 0 and iter_counter.needs_displaying():
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image',
                                    trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
            visualizer.display_current_results(visuals, epoch,
                                               iter_counter.total_steps_so_far)

        # Only rank 0 writes the 'latest' checkpoint and records the iteration.
        if local_rank == 0 and iter_counter.needs_saving():
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

    # End of epoch (outside the batch loop): step the learning-rate schedule.
    trainer.update_learning_rate(epoch)
    # NOTE(review): snippet truncated — the body of this `if` is not shown.
    if local_rank == 0:
Example #6
0
def do_train(opt):
    """Train a Pix2Pix (SPADE generator) model as configured by ``opt``.

    Args:
        opt: parsed option namespace. It is passed to every component
            constructor and read here for ``D_steps_per_G`` and
            ``save_epoch_freq``.

    Per batch:
        Run one discriminator update, plus one generator update whenever the
        batch index is a multiple of ``opt.D_steps_per_G``. Losses, visuals
        and 'latest' checkpoints are produced whenever the IterationCounter
        requests them.

    Per epoch:
        Update the learning rate. Then, every ``opt.save_epoch_freq`` epochs
        and on the last epoch, save both the 'latest' and the epoch-numbered
        checkpoints.
    """
    # A debug run logged "dataset [CustomDataset] of size 2000 was created".
    loader = data.create_dataloader(opt)
    # A debug run logged SPADEGenerator and MultiscaleDiscriminator being
    # created here, plus a download of VGG19 weights.
    trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    # A debug run logged the creation of a web directory under checkpoints/.
    vis = Visualizer(opt)

    def report_losses(epoch):
        # Print and plot the most recent losses.
        losses = trainer.get_latest_losses()
        vis.print_current_errors(epoch, counter.epoch_iter, losses,
                                 counter.time_per_iter)
        vis.plot_current_errors(losses, counter.total_steps_so_far)

    def show_samples(epoch, batch):
        # Hand input label, generated image and real image to the visualizer.
        visuals = OrderedDict((
            ('input_label', batch['label']),
            ('synthesized_image', trainer.get_latest_generated()),
            ('real_image', batch['image']),
        ))
        vis.display_current_results(visuals, epoch, counter.total_steps_so_far)

    def save_latest(epoch):
        # Overwrite the 'latest' checkpoint and remember where we are.
        print('saving the latest model (epoch %d, total_steps %d)' %
              (epoch, counter.total_steps_so_far))
        trainer.save('latest')
        counter.record_current_iter()

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)
        # Each batch is a dict; a debug dump showed the keys 'label',
        # 'instance', 'image' and 'path'.
        for idx, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            if idx % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(batch)
            trainer.run_discriminator_one_step(batch)

            if counter.needs_printing():
                report_losses(epoch)
            if counter.needs_displaying():
                show_samples(epoch, batch)
            if counter.needs_saving():
                save_latest(epoch)

        trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == counter.total_epochs):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Example #7
0
def run(opt):
    """Train the model managed by ``TrainerManager`` as configured by ``opt``.

    Setup:
        Builds the dataloader, trainer, an InferenceManager used for
        evaluation, the iteration counter and the visualizer. Unless
        ``opt.debug`` is set, it also copies the current source tree into
        ``<opt.checkpoints_dir>/<opt.name>``.

    Training loop:
        Alternates generator and discriminator updates. On the counter's
        schedule it:

        * prints and plots losses;
        * displays visuals;
        * saves the 'latest' checkpoint;
        * evaluates on the training set and, if ``opt.evaluate_val_set``,
          on the validation set.

    Shutdown:
        A 'latest' checkpoint is always written on the way out (``finally``).
        Ctrl+C (KeyboardInterrupt) ends training without raising. Any other
        ``Exception`` has its traceback printed, is checkpointed, and is then
        re-raised so the caller (and the process exit status) sees the
        failure. ``SystemExit`` is not intercepted.
    """
    print("Number of GPUs used: {}".format(torch.cuda.device_count()))
    print("Current Experiment Name: {}".format(opt.name))

    # The dataloader will yield the training samples
    dataloader = data.create_dataloader(opt)

    trainer = TrainerManager(opt)
    inference_manager = InferenceManager(
        num_samples=opt.num_evaluation_samples,
        opt=opt,
        cuda=len(opt.gpu_ids) > 0,
        write_details=False,
        save_images=False)

    # For logging and visualizations
    iter_counter = IterationCounter(opt, len(dataloader))
    visualizer = Visualizer(opt)

    if not opt.debug:
        # We keep a copy of the current source code for each experiment
        copy_src(path_from="./",
                 path_to=os.path.join(opt.checkpoints_dir, opt.name))

    # Training runs inside try/finally so that a 'latest' checkpoint is saved
    # however the loop ends (normal completion, Ctrl+C, or an error).
    try:
        for epoch in iter_counter.training_epochs():
            iter_counter.record_epoch_start(epoch)
            for i, data_i in enumerate(dataloader,
                                       start=iter_counter.epoch_iter):

                # Training the generator
                if i % opt.D_steps_per_G == 0:
                    trainer.run_generator_one_step(data_i)

                # Training the discriminator
                trainer.run_discriminator_one_step(data_i)

                iter_counter.record_one_iteration()

                # Logging, plotting and visualizing
                if iter_counter.needs_printing():
                    losses = trainer.get_latest_losses()
                    visualizer.print_current_errors(
                        epoch, iter_counter.epoch_iter, losses,
                        iter_counter.time_per_iter,
                        iter_counter.total_time_so_far,
                        iter_counter.total_steps_so_far)
                    visualizer.plot_current_errors(
                        losses, iter_counter.total_steps_so_far)

                if iter_counter.needs_displaying():
                    logs = trainer.get_logs()
                    visuals = _build_visuals(opt, trainer, inference_manager,
                                             data_i)
                    visualizer.display_current_results(
                        visuals, epoch, iter_counter.total_steps_so_far, logs)

                if iter_counter.needs_saving():
                    print(
                        'Saving the latest model (epoch %d, total_steps %d)' %
                        (epoch, iter_counter.total_steps_so_far))
                    trainer.save('latest')
                    iter_counter.record_current_iter()

                if iter_counter.needs_evaluation():
                    _evaluate_and_plot(opt, trainer, inference_manager,
                                       iter_counter, visualizer, dataloader)

            trainer.update_learning_rate(epoch)
            iter_counter.record_epoch_end()

            if (epoch % opt.save_epoch_freq == 0
                    or epoch == iter_counter.total_epochs):
                print('Saving the model at the end of epoch %d, iters %d' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                trainer.save(epoch)
                iter_counter.record_current_iter()

        print('Training was successfully finished.')
    except KeyboardInterrupt:
        # Ctrl+C is the intended way to stop early: report it and let
        # `finally` checkpoint, without re-raising.
        print("KeyboardInterrupt. Shutting down.")
        print(traceback.format_exc())
    except Exception:
        # Any other failure: show it, let `finally` checkpoint, then propagate
        # instead of returning as if training had succeeded.
        print(traceback.format_exc())
        raise
    finally:
        print('Saving the model before quitting')
        trainer.save('latest')
        iter_counter.record_current_iter()


def _build_visuals(opt, trainer, inference_manager, data_i):
    """Return the ordered name -> image mapping passed to the visualizer."""
    visuals = [('input_label', data_i['label']),
               ('out_train', trainer.get_latest_generated()),
               ('real_train', data_i['image'])]
    if opt.guiding_style_image:
        visuals.append(('guiding_image', data_i['guiding_image']))
        visuals.append(('guiding_input_label', data_i['guiding_label']))

    if opt.evaluate_val_set:
        visuals += inference_validation(trainer.sr_model, inference_manager,
                                        opt)
    return OrderedDict(visuals)


def _evaluate_and_plot(opt, trainer, inference_manager, iter_counter,
                       visualizer, dataloader):
    """Record and plot FID and metrics on the training set.

    Does the same for the validation set when ``opt.evaluate_val_set`` is set.
    """
    # NOTE(review): record_fid / record_metrics return values that the
    # previous code only concatenated into a local it never used, so they are
    # ignored here; print them if they are meant to be shown.

    # Evaluate on training set
    result_train = evaluate_training_set(
        inference_manager, trainer.sr_model_on_one_gpu, dataloader)
    iter_counter.record_fid(result_train["FID"],
                            split="train",
                            num_samples=opt.num_evaluation_samples)
    iter_counter.record_metrics(result_train, split="train")
    visualizer.plot_current_errors(result_train,
                                   iter_counter.total_steps_so_far,
                                   split="train/")

    if opt.evaluate_val_set:
        # Evaluate on validation set
        result_val = evaluate_validation_set(
            inference_manager, trainer.sr_model_on_one_gpu, opt)
        iter_counter.record_fid(result_val["FID"],
                                split="validation",
                                num_samples=opt.num_evaluation_samples)
        iter_counter.record_metrics(result_val, split="validation")
        visualizer.plot_current_errors(result_val,
                                       iter_counter.total_steps_so_far,
                                       split="validation/")