Example #1
def test_train():
    # print options to help debugging
    # print(' '.join(sys.argv))

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image', trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or \
           epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
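
test_train reads opt from module scope, so it only runs inside a module that parses options first. A minimal sketch of that surrounding setup, mirroring the imports shown in Example #3 (the options.train_options path is an assumption taken from the SPADE layout):

# Sketch of the module-level setup test_train relies on (assumed, mirroring Example #3).
import sys
from collections import OrderedDict

import data
from options.train_options import TrainOptions
from trainers.pix2pix_trainer import Pix2PixTrainer
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer

opt = TrainOptions().parse()

if __name__ == '__main__':
    test_train()
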
Example #2
import tensorflow as tf
from dotmap import DotMap  # DataLoader and Trainer are project-local types defined elsewhere

def build_model_and_get_trainer(config: DotMap, data_loader: DataLoader,
                                strategy: tf.distribute.Strategy) -> Trainer:
    model_structure = config.model.structure

    print('Create the model')
    if model_structure == 'pix2pix':
        with strategy.scope():
            generator = get_generator_model(config)
            discriminator = get_discriminator_model(config)

        trainer = Pix2PixTrainer(generator=generator,
                                 discriminator=discriminator,
                                 data_loader=data_loader,
                                 strategy=strategy,
                                 config=config)

        return trainer
    else:
        raise ValueError(f"unknown model structure {model_structure}")
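
A hedged usage sketch for the function above; tf.distribute.MirroredStrategy is standard TensorFlow, but load_config, get_data_loader, and the trainer.train entry point are assumed helpers, not shown in the snippet:

# Hypothetical driver for build_model_and_get_trainer; only the strategy
# call is standard TensorFlow, everything else is an assumed project helper.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()   # one replica per visible GPU
config = load_config('configs/pix2pix.yaml')  # assumed helper returning a DotMap
data_loader = get_data_loader(config)         # assumed helper
trainer = build_model_and_get_trainer(config, data_loader, strategy)
trainer.train()                               # assumed Trainer entry point
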
Example #3
import sys

from options.train_options import TrainOptions
import data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer

# parse options
opt = TrainOptions().parse()

# print options to help debugging
print(' '.join(sys.argv))

# load the dataset
dataloader = data.create_dataloader(opt)

# create trainer for our model
trainer = Pix2PixTrainer(opt)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))

# create tool for visualization
visualizer = Visualizer(opt)

for epoch in iter_counter.training_epochs():
    iter_counter.record_epoch_start(epoch)
    for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()

        # Training
        # train generator
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator
        trainer.run_discriminator_one_step(data_i)
Example #4
def train():
    # create trainer for our model and freeze necessary model layers
    opt.niter = opt.niter + 20  # train for 20 more epochs (niter is counted in epochs)
    opt.lr = 0.00002  # 1/10th of the original lr

    # Proceed with training.

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image',
                                        trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or \
           epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)
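
The opening comment promises to freeze model layers, but the snippet never does. A minimal sketch of what that could look like, assuming the trainer exposes pix2pix_model.netG as in the SPADE codebase and that the earliest generator blocks are the ones to fix (both assumptions):

# Hypothetical freezing step; attribute names (pix2pix_model, netG) follow
# the SPADE codebase but are not shown in the snippet above.
for name, param in trainer.pix2pix_model.netG.named_parameters():
    if name.startswith(('fc', 'head_0')):  # assumed: freeze the earliest generator blocks
        param.requires_grad = False
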
Example #5
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--print_loss_freq', type=int, default=40)
    parser.add_argument('--eval_mode', type=bool, default=False)  # note: type=bool parses any non-empty string as True; action='store_true' is the usual fix

    # data loader
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=1)

    # model hyperparameters
    parser.add_argument('--inner_channels', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--norm', type=str, default='batch')

    # training hyperparameters
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--decay_ratio', type=float, default=0.5)
    parser.add_argument('--lamb', type=float, default=10.0)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)

    args = parser.parse_args()

    trainer = None
    if args.model == 'pix2pix':
        trainer = Pix2PixTrainer(args=args)

    if args.op == 'train':
        trainer.train()
    else:
        trainer.test(args.checkpoint)
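
The snippet begins mid-parser, so the --model, --op, and --checkpoint flags it reads are defined above the cut. A hedged reconstruction of that missing head, with flag names inferred only from how args is used below (shown unindented as a standalone sketch):

# Assumed head of the argument parser (cut off above); flag names inferred
# from the uses of args.model, args.op, and args.checkpoint.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pix2pix')
parser.add_argument('--op', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--checkpoint', type=str, default=None)
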
Example #6
def main_worker(gpu, world_size, idx_server, opt):
    print('Use GPU: {} for training'.format(gpu))
    ngpus_per_node = world_size   # processes launched on this node
    world_size = opt.world_size   # total number of processes across all nodes
    rank = idx_server * ngpus_per_node + gpu
    opt.gpu = gpu
    dist.init_process_group(backend='nccl',
                            init_method=opt.dist_url,
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(opt.gpu)

    # load the dataset
    dataloader = data.create_dataloader(opt, world_size, rank)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader), world_size, rank)

    # create tool for visualization
    visualizer = Visualizer(opt, rank)

    for epoch in iter_counter.training_epochs():
        # set epoch for data sampler
        dataloader.sampler.set_epoch(epoch)

        iter_counter.record_epoch_start(epoch)

        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

        # display once per epoch, using the last batch seen in the loop above
        visuals = OrderedDict([('input_label', data_i['label']),
                               ('synthesized_image',
                                trainer.get_latest_generated()),
                               ('real_image', data_i['image'])])
        visualizer.display_current_results(visuals, epoch,
                                           iter_counter.total_steps_so_far)

        if rank == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs) and (rank == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save(epoch)

    print('Training was successfully finished.')
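
main_worker is written for one process per GPU. A minimal launch sketch using torch.multiprocessing.spawn, assuming opt is the parsed options object with world_size and dist_url already set and a single node (idx_server = 0):

# Hypothetical single-node launcher for main_worker; opt is assumed to be
# the parsed options with world_size and dist_url already set.
import torch
import torch.multiprocessing as mp

if __name__ == '__main__':
    ngpus_per_node = torch.cuda.device_count()
    mp.spawn(main_worker,
             nprocs=ngpus_per_node,
             args=(ngpus_per_node, 0, opt))  # spawn supplies the gpu index as the first argument
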
Example #7
import os
import sys

import data
from options.train_options import TrainOptions
from trainers.pix2pix_trainer import Pix2PixTrainer
from util.iter_counter import IterationCounter

opt = TrainOptions().parse()

# print options to help debugging
print(' '.join(sys.argv))

#torch.manual_seed(0)
# load the dataset
dataloader = data.create_dataloader(opt)
len_dataloader = len(dataloader)
dataloader.dataset[11]  # touch one sample up front (likely a debugging leftover)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))

# create trainer for our model
trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch)

save_root = os.path.join(os.path.dirname(opt.checkpoints_dir), 'output', opt.name)
for epoch in iter_counter.training_epochs():
    opt.epoch = epoch
    if not opt.maskmix:
        print('inject nothing')
    elif opt.maskmix and opt.noise_for_mask and epoch > opt.mask_epoch:
        print('inject noise')
    else:
        print('inject mask')
    print('real_reference_probability is :{}'.format(dataloader.dataset.real_reference_probability))
    print('hard_reference_probability is :{}'.format(dataloader.dataset.hard_reference_probability))
    iter_counter.record_epoch_start(epoch)
    for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()
Example #8
def do_train(opt):
    dataloader = data.create_dataloader(opt)
    # dataset [CustomDataset] of size 2000 was created

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)
    # Network [SPADEGenerator] was created. Total number of parameters: 92.5 million. To see the architecture, do print(network).
    # Network [MultiscaleDiscriminator] was created. Total number of parameters: 5.6 million. To see the architecture, do print(network).
    # Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)
    # create web directory ./checkpoints/ipdb_test/web...

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            # data_i =
            # {'label': tensor([[[[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
            #           [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
            #           [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
            #           ...,
            #           [ 0.,  0.,  0.,  ..., 13., 13., 13.],
            #           [ 0.,  0.,  0.,  ..., 13., 13., 13.],
            #           [ 0.,  0.,  0.,  ..., 13., 13., 13.]]]]), 'instance': tensor([0]), 'image': tensor([[[[-1.0000, -1.0000, -0.9922,  ...,  0.5529,  0.5529,  0.5529],
            #           [-1.0000, -1.0000, -0.9922,  ...,  0.5529,  0.5529,  0.5529],
            #           [-1.0000, -0.9922, -0.9843,  ...,  0.5529,  0.5529,  0.5529],
            #           ...,
            #           [ 0.4118,  0.4275,  0.4118,  ..., -0.7490, -0.7333, -0.7020],
            #           [ 0.4196,  0.4039,  0.4196,  ..., -0.7020, -0.7804, -0.7255],
            #           [ 0.4039,  0.4196,  0.4588,  ..., -0.6784, -0.7333, -0.6941]],

            #          [[-0.9529, -0.9686, -0.9843,  ...,  0.5843,  0.5843,  0.5843],
            #           [-0.9529, -0.9686, -0.9843,  ...,  0.5843,  0.5843,  0.5843],
            #           [-0.9608, -0.9686, -0.9765,  ...,  0.5843,  0.5843,  0.5843],
            #           ...,
            #           [ 0.4431,  0.4588,  0.4431,  ..., -0.8510, -0.8353, -0.8039],
            #           [ 0.4510,  0.4353,  0.4510,  ..., -0.8039, -0.8824, -0.8275],
            #           [ 0.4353,  0.4510,  0.4902,  ..., -0.7725, -0.8275, -0.7882]],

            #          [[-0.9843, -1.0000, -1.0000,  ...,  0.6549,  0.6549,  0.6549],
            #           [-0.9843, -1.0000, -1.0000,  ...,  0.6549,  0.6549,  0.6549],
            #           [-0.9922, -1.0000, -0.9922,  ...,  0.6549,  0.6549,  0.6549],
            #           ...,
            #           [ 0.5294,  0.5451,  0.5294,  ..., -0.9216, -0.8980, -0.8667],
            #           [ 0.5373,  0.5216,  0.5373,  ..., -0.8824, -0.9529, -0.8980],
            #           [ 0.5216,  0.5373,  0.5765,  ..., -0.8667, -0.9216, -0.8824]]]]), 'path': ['../../Celeb_subset/train/images/8516.jpg']}
            iter_counter.record_one_iteration()

            # Training
            # train generator
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image',
                                        trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or \
        epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Example #9
def do_train(opt):
    # print options to help debugging
    print(' '.join(sys.argv))

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    if opt.train_eval:
        # val_opt = TestOptions().parse()
        original_flip = opt.no_flip
        opt.no_flip = True
        opt.phase = 'test'
        opt.isTrain = False
        dataloader_val = data.create_dataloader(opt)
        val_visualizer = Visualizer(opt)
        # create a webpage that summarizes all the results
        web_dir = os.path.join(opt.results_dir, opt.name,
                            '%s_%s' % (opt.phase, opt.which_epoch))
        webpage = html.HTML(web_dir,
                            'Experiment = %s, Phase = %s, Epoch = %s' %
                            (opt.name, opt.phase, opt.which_epoch))
        opt.phase = 'train'
        opt.isTrain = True
        opt.no_flip = original_flip
        # set up for calculating FID scores
        from inception import InceptionV3
        from fid_score import calculate_activation_statistics, calculate_frechet_distance
        import pathlib
        # instantiate the InceptionV3 feature extractor
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[opt.eval_dims]
        eval_model = InceptionV3([block_idx]).cuda()
        # load the real-image activation statistics of the training set
        mu_np_root = os.path.join('datasets/train_mu_si', opt.dataset_mode, 'm.npy')
        st_np_root = os.path.join('datasets/train_mu_si', opt.dataset_mode, 's.npy')
        m0, s0 = np.load(mu_np_root), np.load(st_np_root)
        # load previous best FID
        if opt.continue_train:
            fid_record_dir = os.path.join(opt.checkpoints_dir, opt.name, 'fid.txt')
            FID_score, _ = np.loadtxt(fid_record_dir, delimiter=',', dtype=float)
        else:
            FID_score = 1000
    else:
        FID_score = 1000

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                if opt.train_eval:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter, FID_score)
                else:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

            # if iter_counter.needs_displaying():
            #     visuals = OrderedDict([('input_label', data_i['label']),
            #                            ('synthesized_image', trainer.get_latest_generated()),
            #                            ('real_image', data_i['image'])])
            #     visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter(FID_score)

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.eval_epoch_freq == 0 and opt.train_eval:
            # generate fake image
            trainer.pix2pix_model.eval()
            print('start evaluation ....')
            # temporarily disable the VAE path during inference
            if opt.use_vae:
                flag = True
                opt.use_vae = False
            else:
                flag = False
            for i, data_i in enumerate(dataloader_val):
                if data_i['label'].size()[0] != opt.batchSize:
                    if opt.batchSize > 2*data_i['label'].size()[0]:
                        print('batch size is too large')
                        break
                    data_i = repair_data(data_i, opt.batchSize)  # project-local helper that pads the final partial batch
                generated = trainer.pix2pix_model(data_i, mode='inference')
                img_path = data_i['path']
                for b in range(generated.shape[0]):
                    visuals = OrderedDict([('input_label', data_i['label'][b]),
                                           ('synthesized_image', generated[b])])
                    val_visualizer.save_images(webpage, visuals, img_path[b:b + 1])
            webpage.save()
            trainer.pix2pix_model.train()
            if flag:
                opt.use_vae = True
            # calculate FID score
            fake_path = pathlib.Path(os.path.join(web_dir, 'images/synthesized_image/'))
            files = list(fake_path.glob('*.jpg')) + list(fake_path.glob('*.png'))
            m1, s1 = calculate_activation_statistics(files, eval_model, 1, opt.eval_dims, True, images=None)
            fid_value = calculate_frechet_distance(m0, s0, m1, s1)
            visualizer.print_eval_fids(epoch, fid_value, FID_score)
            # save the best model if necessary
            if fid_value < FID_score:
                FID_score = fid_value
                trainer.save('best')

        if epoch % opt.save_epoch_freq == 0 or \
        epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
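
For reference, the FID tracked above is the Fréchet distance between two Gaussians fitted to real and generated Inception activations. A minimal sketch of that computation, equivalent in spirit to the calculate_frechet_distance call in the snippet:

# Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
# ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):  # sqrtm may return negligible imaginary parts
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)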