Esempio n. 1
0
def cyclgan_train(opt, cycle_gan: CycleGANModel, train_loader, writer_dict):
    """Run the CycleGAN training loop over all configured epochs.

    Epochs run from ``opt.epoch_count`` to ``opt.n_epochs + opt.n_epochs_decay``
    inclusive, and every batch gets one optimisation step. Within an epoch:

    * every ``opt.print_freq`` batches, a progress line with the current
      losses is written through ``tqdm.write``;
    * whenever the running sample count reaches or passes a multiple of
      ``opt.display_freq``, the current visuals go to ``save_current_results``;
    * whenever it reaches or passes a multiple of ``opt.save_latest_freq``,
      the networks are saved with the ``'latest'`` suffix.

    At the end of every epoch:

    * when the epoch number is a multiple of ``opt.save_epoch_freq``, the
      networks are saved (as ``'latest'`` and under the epoch number);
    * the losses left on the model by the epoch's last batch are logged with
      ``writer.add_scalars``;
    * ``cycle_gan.update_learning_rate()`` is called.

    Args:
        opt: Options namespace. Reads ``epoch_count``, ``n_epochs``,
            ``n_epochs_decay``, ``batch_size``, ``print_freq``,
            ``display_freq``, ``save_latest_freq`` and ``save_epoch_freq``.
            It is also forwarded to ``save_current_results``.
        cycle_gan: The model being trained. Its ``loss_*`` attributes are read
            at the end of each epoch.
        train_loader: Iterable of batches that also supports ``len()``. Its
            length is shown as the denominator of the progress counter.
        writer_dict: Mapping with ``'writer'`` (an object providing
            ``add_scalars``, presumably a TensorBoard ``SummaryWriter``) and
            ``'train_steps'`` (int logging step, incremented once per epoch).
    """
    cycle_gan.train()

    writer = writer_dict['writer']
    total_epochs = opt.n_epochs + opt.n_epochs_decay
    total_iters = 0  # training samples processed so far, across all epochs

    def crossed_multiple(freq, before, after):
        """Return True if a multiple of ``freq`` lies in ``(before, after]``.

        ``total_iters`` advances by ``opt.batch_size`` per batch, so an
        exact-divisibility test on it can skip every multiple when the batch
        size does not divide ``freq``. Checking for a crossed boundary instead
        fires on each batch that reaches or passes a multiple, whatever the
        batch size.
        """
        return after // freq > before // freq

    for epoch in trange(opt.epoch_count, total_epochs + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0  # training samples processed in this epoch

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            # Time spent waiting for this batch. It is measured on every batch
            # so the printed value always belongs to the batch it is printed with.
            t_data = iter_start_time - iter_data_time

            prev_iters = total_iters
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = time.time() - iter_start_time
                message = "GAN: [Ep: %d/%d]" % (epoch, total_epochs)
                # NOTE(review): epoch_iter counts samples. The ratio is only
                # like-for-like if len(train_loader) is also a sample count;
                # confirm against the loader type.
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if crossed_multiple(opt.display_freq, prev_iters, total_iters):
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if crossed_multiple(opt.save_latest_freq, prev_iters, total_iters):
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                cycle_gan.save_networks('latest')

            iter_data_time = time.time()

        # The epoch range ends at total_epochs inclusive, so test the epoch
        # number itself. Testing ``epoch + 1`` would skip the final epoch
        # whenever the total is a multiple of save_epoch_freq.
        if epoch % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(epoch)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, total_epochs, time.time() - epoch_start_time))

        # Log the losses left on the model by the last batch of this epoch.
        epoch_losses = (
            ('Train/discriminator', cycle_gan.loss_D_A, cycle_gan.loss_D_B),
            ('Train/generator', cycle_gan.loss_G_A, cycle_gan.loss_G_B),
            ('Train/cycle', cycle_gan.loss_cycle_A, cycle_gan.loss_cycle_B),
            ('Train/idt', cycle_gan.loss_idt_A, cycle_gan.loss_idt_B),
        )
        for tag, loss_a, loss_b in epoch_losses:
            writer.add_scalars(tag, {"A": float(loss_a), "B": float(loss_b)},
                               train_steps)

        writer_dict['train_steps'] += 1
        cycle_gan.update_learning_rate()
Esempio n. 2
0
        # Per-batch training step. `dataset`, `visualizer`, `model`, `opt`,
        # `epoch`, `total_steps`, `epoch_iter` and `sparse_c_loss_points` come
        # from the enclosing scope, which is not shown here.
        for i, data in enumerate(dataset):

            visualizer.reset()
            # Both counters advance by one per batch (not per sample).
            total_steps += 1
            epoch_iter += 1

            # One optimisation step on this batch.
            model.set_input(data)
            model.optimize_parameters()

            # Bare file name used to label this batch's output. The path list
            # is concatenated before taking the basename, which presumably
            # assumes one image per batch -- TODO confirm.
            img_path = model.get_image_paths()
            short_path = ntpath.basename(''.join(img_path))
            # The no-op string below reads: "decide which image to display as
            # output in each epoch".
            '''
            # 决定每轮输出显示哪一张图片
            '''
            # Only the batch at which epoch_iter equals opt.display_num is
            # displayed (once per epoch, assuming epoch_iter is reset at the
            # start of each epoch -- not visible here).
            if epoch_iter == opt.display_num:
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, short_path, True)
            # The no-op string below reads: "compute the loss values".
            '''
            # 计算loss值
            '''
            # Fetch and print this batch's losses. get_current_errors() also
            # returns the sparse_c loss separately.
            errors, sparse_c_loss = model.get_current_errors()
            visualizer.print_current_errors(epoch, epoch_iter, short_path,
                                            errors)

            # Accumulate this batch's sparse_c loss in the caller-owned list.
            sparse_c_loss_points.append(sparse_c_loss)

            # Fetch and print this batch's depth errors.
            depth_errors = model.get_depth_errors()
            visualizer.print_depth_errors(epoch, epoch_iter, short_path,
                                          depth_errors)
Esempio n. 3
0
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    """Train the CycleGAN for ``opt.shared_epoch`` epochs with a fixed controller.

    The controller is put in eval mode, and its ``forward()`` is called before
    every GAN update. After each update, the generator loss ``loss_G`` and the
    summed discriminator losses ``loss_D_A + loss_D_B`` are pushed into
    ``g_loss_history`` / ``d_loss_history``. Once the generator history is
    full, training stops early ("dynamic reset") if the variance of either
    history is below ``opt.dynamic_reset_threshold``.

    Within an epoch:

    * progress is printed every ``opt.print_freq`` batches;
    * visuals are saved each time the running sample count reaches or passes
      a multiple of ``opt.display_freq``;
    * the networks are saved as ``'latest'`` each time it reaches or passes a
      multiple of ``opt.save_latest_freq``.

    At the end of each epoch:

    * every ``opt.save_epoch_freq`` epochs, the networks are saved as
      ``'latest'``;
    * the last batch's losses are logged via ``writer.add_scalars``.

    Args:
        opt: Options namespace. Reads ``shared_epoch``, ``batch_size``,
            ``print_freq``, ``display_freq``, ``save_latest_freq``,
            ``save_epoch_freq`` and ``dynamic_reset_threshold``. It is also
            forwarded to ``save_current_results``.
        cycle_gan: The model being trained.
        cycle_controller: Controller whose ``forward()`` runs before every
            GAN update.
        train_loader: Iterable of batches that also supports ``len()``.
        g_loss_history: Running statistics of the generator loss.
        d_loss_history: Running statistics of the summed discriminator losses.
        writer_dict: Mapping with ``'writer'`` (object providing
            ``add_scalars``, presumably a TensorBoard ``SummaryWriter``) and
            ``'train_steps'`` (int logging step, incremented once per
            completed epoch).

    Returns:
        bool: True if training stopped early because of the dynamic-reset
        check (both loss histories are cleared in that case). Otherwise False.
    """
    cycle_gan.train()
    cycle_controller.eval()

    writer = writer_dict['writer']
    total_iters = 0  # training samples processed in this call, across epochs

    def crossed_multiple(freq, before, after):
        """Return True if a multiple of ``freq`` lies in ``(before, after]``.

        ``total_iters`` advances by ``opt.batch_size`` per batch, so an
        exact-divisibility test on it can skip every multiple when the batch
        size does not divide ``freq``. Checking for a crossed boundary instead
        fires on each batch that reaches or passes a multiple, whatever the
        batch size.
        """
        return after // freq > before // freq

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0  # training samples processed in this epoch

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            # Time spent waiting for this batch. It is measured on every batch
            # so the printed value always belongs to the batch it is printed with.
            t_data = iter_start_time - iter_data_time

            prev_iters = total_iters
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_controller.forward()

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = time.time() - iter_start_time
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                # NOTE(review): epoch_iter counts samples. The ratio is only
                # like-for-like if len(train_loader) is also a sample count.
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if crossed_multiple(opt.display_freq, prev_iters, total_iters):
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            # Early stop once the loss histories have (nearly) stopped moving.
            if g_loss_history.is_full():
                threshold = opt.dynamic_reset_threshold
                if (g_loss_history.get_var() < threshold
                        or d_loss_history.get_var() < threshold):
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return True

            # Cache the latest model every <save_latest_freq> samples. The
            # message announces a save, so actually perform it (with the same
            # 'latest' suffix used by the per-epoch save below).
            if crossed_multiple(opt.save_latest_freq, prev_iters, total_iters):
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                cycle_gan.save_networks('latest')

            iter_data_time = time.time()

        # `epoch` is 0-based here, hence the +1.
        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')

        # Report against the epoch count this loop actually runs.
        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch, time.time() - epoch_start_time))

        # Log the losses left on the model by the last batch of this epoch.
        epoch_losses = (
            ('Train/discriminator', cycle_gan.loss_D_A, cycle_gan.loss_D_B),
            ('Train/generator', cycle_gan.loss_G_A, cycle_gan.loss_G_B),
            ('Train/cycle', cycle_gan.loss_cycle_A, cycle_gan.loss_cycle_B),
            ('Train/idt', cycle_gan.loss_idt_A, cycle_gan.loss_idt_B),
        )
        for tag, loss_a, loss_b in epoch_losses:
            writer.add_scalars(tag, {"A": float(loss_a), "B": float(loss_b)},
                               train_steps)

        writer_dict['train_steps'] += 1

    return False
Esempio n. 4
0
# Test-time loading: one image per batch (the test code supports nothing
# else), in dataset order, with flipping disabled.
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True

# NOTE(review): loader, model and visualizer are built at import time; consider
# moving them under the __main__ guard if this module is ever imported.
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    # Results page lives in <result_root_dir>/<variable>/<variable_value>/<phase>.
    web_dir = os.path.join(opt.result_root_dir, opt.variable,
                           opt.variable_value, opt.phase)
    webpage = html.HTML(web_dir,
                        'Experiment = GAN2C, Phase = test, Epoch = latest')

    # Inference, one image at a time. Depth errors are reported with a
    # 1-based image index.
    for image_no, batch in enumerate(dataset, start=1):
        model.set_input(batch)
        model.test()
        visuals = model.get_current_visuals()

        paths = model.get_image_paths()
        print('process image... %s' % paths)
        visualizer.save_images(webpage, visuals, paths)

        file_name = ntpath.basename(''.join(paths))
        visualizer.test_depth_errors(image_no, file_name,
                                     model.get_depth_errors())

    webpage.save()
Esempio n. 5
0
    # Force deterministic, single-image loading for testing.
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    model = CycleGANModel(
        opt)  # always a CycleGANModel built from the options (opt.model is not consulted here)
    model.setup(
        opt)  # model setup from the options (presumably loads networks and creates schedulers)
    # Create the output directory that save_images() writes into.
    image_dir = create_results_dir(opt)
    # Optionally test in eval mode. This only affects layers like batchnorm and dropout.
    # For [CycleGAN]: it should not affect CycleGAN, which (per upstream) uses instancenorm without dropout.
    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply the model to the first opt.num_test images (batch_size is 1).
            break
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()  # get image paths
        if i % 5 == 0:  # log progress every 5th image (saving below happens for every image)
            print('processing (%04d)-th image... %s' % (i, img_path))
        # Save this image's visuals into image_dir.
        save_images(opt,
                    image_dir,
                    visuals,
                    img_path,
                    aspect_ratio=opt.aspect_ratio,
                    width=opt.display_winsize)