コード例 #1
0
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda(
            ).detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
コード例 #2
0
def main():
    """Parse CLI options and train a MUNIT or UNIT model.

    Loads the experiment YAML given by ``--config``, builds the trainer
    selected by ``--trainer`` plus the four loaders returned by
    ``get_all_data_loaders``, then alternates discriminator and generator
    updates over paired batches of domains A and B. At the intervals set in
    the config it logs losses (``log_iter``), writes numbered sample grids and
    an HTML index (``image_save_iter``), refreshes the ``train_current`` grid
    (``image_display_iter``) and saves checkpoints (``snapshot_save_iter``).

    The process ends with ``sys.exit('Finish training')`` once
    ``config['max_iter']`` iterations have run, or with
    ``sys.exit("Only support MUNIT|UNIT")`` for an unknown ``--trainer``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    # Plain boolean flag. The former ``default='150000'`` made
    # ``opts.resume`` a truthy string whenever the flag was omitted, so the
    # flag was a no-op and every run tried to resume (which presumably fails
    # when no checkpoint exists yet). The number was never used to pick an
    # iteration: ``trainer.resume`` only receives the checkpoint directory and
    # the config. Pass ``--resume`` explicitly to continue a run.
    parser.add_argument("--resume", action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    # Let cuDNN benchmark convolution algorithms and reuse the fastest one;
    # this pays off when input sizes stay constant across iterations.
    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    # Fixed batches made of the first ``display_size`` samples of each split;
    # the same images are sampled at every dump so progress is comparable.
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    try:
        # Start training (or continue from a checkpoint with --resume)
        iterations = trainer.resume(checkpoint_directory,
                                    hyperparameters=config) if opts.resume else 0
        while True:
            for images_a, images_b in zip(train_loader_a, train_loader_b):
                trainer.update_learning_rate()
                images_a = images_a.cuda().detach()
                images_b = images_b.cuda().detach()

                with Timer("Elapsed time in update: %f"):
                    # Main training code
                    trainer.dis_update(images_a, images_b, config)
                    trainer.gen_update(images_a, images_b, config)
                    # Wait for queued GPU work so the Timer measures it.
                    torch.cuda.synchronize()

                # Dump training stats in log file
                if (iterations + 1) % config['log_iter'] == 0:
                    print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                    write_loss(iterations, trainer, train_writer)

                # Write numbered image grids and the HTML index
                if (iterations + 1) % config['image_save_iter'] == 0:
                    with torch.no_grad():
                        test_image_outputs = trainer.sample(
                            test_display_images_a, test_display_images_b)
                        train_image_outputs = trainer.sample(
                            train_display_images_a, train_display_images_b)
                    write_2images(test_image_outputs, display_size,
                                  image_directory, 'test_%08d' % (iterations + 1))
                    write_2images(train_image_outputs, display_size,
                                  image_directory, 'train_%08d' % (iterations + 1))
                    # HTML
                    write_html(output_directory + "/index.html", iterations + 1,
                               config['image_save_iter'], 'images')

                # Refresh the rolling "train_current" preview
                if (iterations + 1) % config['image_display_iter'] == 0:
                    with torch.no_grad():
                        image_outputs = trainer.sample(train_display_images_a,
                                                       train_display_images_b)
                    write_2images(image_outputs, display_size, image_directory,
                                  'train_current')

                # Save network weights
                if (iterations + 1) % config['snapshot_save_iter'] == 0:
                    trainer.save(checkpoint_directory, iterations)

                iterations += 1
                if iterations >= max_iter:
                    sys.exit('Finish training')
    finally:
        # Runs on every exit path, including the SystemExit raised by
        # sys.exit above, so buffered TensorBoard events are flushed to disk.
        train_writer.close()
コード例 #3
0
def main(argv):
    """Train a MUNIT model from command-line arguments.

    Args:
        argv: argument list handed to the module-level ``parser``; its
            ``parse_args`` returns an ``(opts, args)`` pair (presumably an
            ``optparse.OptionParser`` -- confirm). Options read here:
            ``config``, ``log``, ``outputs`` and ``resume``.

    Loads the YAML experiment config, builds a ``MUNIT_Trainer`` and the four
    data loaders, then alternates discriminator/generator updates until
    ``config['max_iter']`` iterations have run. Along the way it logs losses,
    writes numbered sample images plus an HTML index, refreshes a rolling
    ``gen.jpg`` preview and saves checkpoints. Returns ``None``.
    """
    (opts, args) = parser.parse_args(argv)
    cudnn.benchmark = True
    model_name = os.path.splitext(os.path.basename(opts.config))[0]

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    # Interval for refreshing the rolling preview ``gen.jpg``. Configs without
    # ``image_display_iter`` keep refreshing it together with the numbered
    # snapshots, as before.
    image_display_iter = config.get('image_display_iter',
                                    config['image_save_iter'])

    # Setup model and data loader
    trainer = MUNIT_Trainer(config)
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    # Fixed display batches (first ``display_size`` samples of each split).
    # ``volatile=True`` is the legacy (pre-0.4) PyTorch way of marking inputs
    # as inference-only so no autograd graph is built for them.
    test_display_images_a = Variable(torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    test_display_images_b = Variable(torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    train_display_images_a = Variable(torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)
    train_display_images_b = Variable(torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)

    # Setup logger and output folders
    train_writer = tensorboard.SummaryWriter(os.path.join(
        opts.log, model_name))
    output_directory = os.path.join(opts.outputs, model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory) if opts.resume else 0
    while True:
        for images_a, images_b in izip(train_loader_a, train_loader_b):
            trainer.update_learning_rate()
            images_a, images_b = Variable(images_a.cuda()), Variable(
                images_b.cuda())

            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write numbered images
            if (iterations + 1) % config['image_save_iter'] == 0:
                # Test set images
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_test%08d.jpg' % (image_directory, iterations + 1))
                # Train set images
                image_outputs = trainer.sample(train_display_images_a,
                                               train_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_train%08d.jpg' % (image_directory, iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

            # Refresh the rolling preview. This used to re-test
            # ``image_save_iter``, a copy of the condition above. As a result
            # it re-sampled the test set right after the numbered dump and
            # could never run at its own interval.
            if (iterations + 1) % image_display_iter == 0:
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(image_outputs, display_size,
                             '%s/gen.jpg' % image_directory)

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                return
コード例 #4
0
iterations = trainer.resume(checkpoint_directory,
                            hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_target_a,
             images_target_b) in enumerate(zip(train_loader_a,
                                               train_loader_b)):
        trainer.update_learning_rate()
        images_a = images_target_a[0].cuda().detach()
        target_a = images_target_a[1].cuda().detach()
        images_b = images_target_b[0].cuda().detach()
        target_b = images_target_b[1].cuda().detach()
        ids_b = images_target_b[2]
        # Main training code
        with Timer("Elapsed time in update: %f"):
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config, target_a,
                               iterations)

            if iterations >= config['train_seg_iters']:
                trainer.seg_update(images_a, images_b, target_a, target_b)

        # Dump training stats in log file
        if (iterations) % config['log_iter'] == 0:
            print('Iteration: %08d/%08d' % (iterations, max_iter))
            losses = write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations) % config['image_save_iter'] == 0:
            with torch.no_grad():
                train_image_outputs = trainer.sample(train_display_images_a,
                                                     train_display_images_b)
                test_image_outputs = trainer.sample(test_display_images_a,
コード例 #5
0
                 (images_as, images_bs, mask_s, sem_a, sem_b)) in enumerate(
                     zip(train_loader_a_w_mask, train_loader_b_w_mask,
                         synthetic_loader)):
            with Timer("Elapsed time in update s: %f"):
                trainer.update_learning_rate()
                images_a, images_b = images_a.cuda().detach(), images_b.cuda(
                ).detach()
                mask_a, mask_b = mask_a.cuda().detach(), mask_b.cuda().detach()
                images_as, images_bs, mask_s = images_as.cuda().detach(
                ), images_bs.cuda().detach(), mask_s.cuda().detach()

                # Main training code
                trainer.dis_update(images_a, images_b, config, comet_exp)
                #Gen update
                if (iterations + 1) % config["ratio_disc_gen"] == 0:
                    trainer.gen_update(images_a, images_b, config, mask_a,
                                       mask_b, comet_exp)
                #Domain classifier update
                if config["domain_adv_w"] > 0:
                    trainer.domain_classifier_update(images_a, images_b,
                                                     config, comet_exp)
                #Domain classifier s,r update
                if trainer.use_classifier_sr and (
                        iterations +
                        1) % config["adaptation"]["classif_frequency"] == 0:
                    print(iterations + 1)
                    trainer.domain_classifier_sr_update(
                        images_a, images_b, False,
                        config["adaptation"]["dfeat_lambda"], iterations + 1,
                        comet_exp)

                #Output domain classifier s,r update
コード例 #6
0
ファイル: train.py プロジェクト: adrienju/MUNIT
              if opts.resume else 0)

if config["semantic_w"] == 0:
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda(
            ).detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config, comet_exp)
                if (iterations + 1) % config["ratio_disc_gen"] == 0:
                    trainer.gen_update(images_a,
                                       images_b,
                                       config,
                                       comet_exp=comet_exp)
                if config["domain_adv_w"] > 0:
                    trainer.domain_classifier_update(images_a, images_b,
                                                     config, comet_exp)
                torch.cuda.synchronize()

            # Write images
            if (iterations + 1) % config["image_save_iter"] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(
                    test_image_outputs,
コード例 #7
0
# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (data_a, data_b) in enumerate(zip(train_loader_a, train_loader_b)): # to iterate along both lists
        trainer.update_learning_rate()
        images_a = data_a
        images_b = data_b
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

        # Main training code
        trainer.dis_update(images_a, images_b, config)        
        if config['dis']['gan_type'] == 'ralsgan':
            images_rand_a = random_sample_a.__next__()
            images_rand_b = random_sample_b.__next__()
            images_rand_a, images_rand_b = Variable(images_rand_a.cuda()), Variable(images_rand_b.cuda())
            trainer.gen_update(images_a, images_b, config, images_rand_a, images_rand_b)
        else:
            trainer.gen_update(images_a, images_b, config)
        torch.cuda.synchronize()
        trainer.update_iter()

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Model: %s, Iteration: %08d/%08d" % (model_name, iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Training logs
        if (iterations + 1) % config['image_save_iter'] == 0:
            iter_directory = os.path.join(output_directory+'/images', 'iter_'+str(iterations + 1).zfill(8))
            if not os.path.exists(iter_directory):
                print("Creating directory: {}".format(iter_directory))
コード例 #8
0
def main():
    """Parse CLI options and train a MUNIT or UNIT model (GPU 0 by default).

    Loads the experiment YAML given by ``--config``, builds the trainer
    selected by ``--trainer`` and the four data loaders, then alternates
    discriminator and generator updates over paired batches of domains A
    and B. At the intervals set in the config it logs losses (``log_iter``),
    writes numbered sample grids and an HTML index (``image_save_iter``),
    refreshes the ``train_current`` grid (``image_display_iter``) and saves
    checkpoints (``snapshot_save_iter``).

    The process ends with ``sys.exit('Finish training')`` once
    ``config['max_iter']`` iterations have run, or with
    ``sys.exit("Only support MUNIT|UNIT")`` for an unknown ``--trainer``.
    """
    import os

    # Default to GPU 0, but respect a device mask the caller already set.
    # This is done before the torch/project imports below, so the mask is in
    # place before anything they run at import time could touch CUDA.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    import argparse
    import shutil
    import sys

    import tensorboardX
    import torch
    import torch.backends.cudnn as cudnn

    from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
    from trainer import MUNIT_Trainer, UNIT_Trainer

    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    # Let cuDNN's auto-tuner benchmark convolution algorithms and pick the
    # fastest one for the current configuration. See
    # https://www.pytorchtutorial.com/when-should-we-set-cudnn-benchmark-to-true/
    #   1. If input dimensions/types change little between iterations, this
    #      improves throughput.
    #   2. If the input changes every iteration, cuDNN repeats the search each
    #      time, which slows training down instead.
    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    # Fixed batches made of the first ``display_size`` samples of each split.
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    try:
        # Start training (or continue from a checkpoint with --resume)
        iterations = trainer.resume(checkpoint_directory,
                                    hyperparameters=config) if opts.resume else 0
        while True:
            # Python 3's ``zip`` is already lazy; no itertools.izip shim needed.
            for images_a, images_b in zip(train_loader_a, train_loader_b):
                trainer.update_learning_rate()
                images_a = images_a.cuda().detach()
                images_b = images_b.cuda().detach()

                with Timer("Elapsed time in update: %f"):
                    # Main training code
                    trainer.dis_update(images_a, images_b, config)
                    trainer.gen_update(images_a, images_b, config)
                    # Wait for queued GPU work so the Timer measures it.
                    torch.cuda.synchronize()

                # Dump training stats in log file
                if (iterations + 1) % config['log_iter'] == 0:
                    print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                    write_loss(iterations, trainer, train_writer)

                # Write numbered image grids and the HTML index
                if (iterations + 1) % config['image_save_iter'] == 0:
                    with torch.no_grad():
                        test_image_outputs = trainer.sample(
                            test_display_images_a, test_display_images_b)
                        train_image_outputs = trainer.sample(
                            train_display_images_a, train_display_images_b)
                    write_2images(test_image_outputs, display_size,
                                  image_directory, 'test_%08d' % (iterations + 1))
                    write_2images(train_image_outputs, display_size,
                                  image_directory, 'train_%08d' % (iterations + 1))
                    # HTML
                    write_html(output_directory + "/index.html", iterations + 1,
                               config['image_save_iter'], 'images')

                # Refresh the rolling "train_current" preview
                if (iterations + 1) % config['image_display_iter'] == 0:
                    with torch.no_grad():
                        image_outputs = trainer.sample(train_display_images_a,
                                                       train_display_images_b)
                    write_2images(image_outputs, display_size, image_directory,
                                  'train_current')

                # Save network weights
                if (iterations + 1) % config['snapshot_save_iter'] == 0:
                    trainer.save(checkpoint_directory, iterations)

                iterations += 1
                if iterations >= max_iter:
                    sys.exit('Finish training')
    finally:
        # Runs on every exit path, including the SystemExit raised by
        # sys.exit above, so buffered TensorBoard events are flushed to disk.
        train_writer.close()
コード例 #9
0
        labels_2 = labels_list[index_2]

        use_1 = use_list[index_1]
        use_2 = use_list[index_2]

        images_1, images_2 = Variable(images_1.cuda()), Variable(images_2.cuda())

        # Main training code.
        if (ep + 1) <= int(0.75 * epochs):

            # If in Full Training mode.
            trainer.set_sup_trainable(True)
            trainer.set_gen_trainable(True)

            trainer.dis_update(images_1, images_2, index_1, index_2, config)
            trainer.gen_update(images_1, images_2, index_1, index_2, config)

        else:

            # If in Supervision Tuning mode.
            trainer.set_sup_trainable(True)
            trainer.set_gen_trainable(False)

        labels_1 = labels_1.to(dtype=torch.long)
        labels_1[labels_1 > 0] = 1
        labels_1 = Variable(labels_1.cuda(), requires_grad=False)

        labels_2 = labels_2.to(dtype=torch.long)
        labels_2[labels_2 > 0] = 1
        labels_2 = Variable(labels_2.cuda(), requires_grad=False)
コード例 #10
0
        trainer.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
        # labels_a, labels_b = labels_a.cuda().detach(), labels_b.cuda().detach()
        labels_b = labels_b.cuda().detach()
        images_a_limited, labels_a_limited = images_a_limited.cuda().detach(
        ), labels_a_limited.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code
            # time_start_iter = time()

            trainer.dis_update(images_a, images_b, config)
            # time_dis = time()
            # print(f'Dis: {time_dis - time_start_iter}', end=" ")

            trainer.gen_update(images_a, [images_b, labels_b], config,
                               [images_a_limited, labels_a_limited])
            # time_gen = time()
            # print(f'Gen: {time_gen - time_dis}', end=" ")

            trainer.cla_update([images_a_limited, labels_a_limited],
                               [images_b, labels_b])
            # time_con_cla = time()
            # print(f'Cla: {time_con_cla - time_gen}')

            torch.cuda.synchronize()

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)
            sys.exit('Stop it!!')
コード例 #11
0
ファイル: train.py プロジェクト: phonx/MUNIT
model_name = os.path.splitext(os.path.basename(opts.config))[0]
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
output_directory = os.path.join(opts.output_path + "/outputs", model_name)
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = Variable(images_a.cuda()), Variable(images_b.cuda())

        # Main training code
        trainer.dis_update(images_a, images_b, config)
        trainer.gen_update(images_a, images_b, config)

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
            # Test set images
            image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            # Train set images
            image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
コード例 #12
0
            os.path.join(output_directory,
                         'config.yaml'))  # copy config file to output folder

# Start training
iterations = trainer.resume(checkpoint_directory,
                            hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a,
             images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

        with Timer("update_time: %f"):
            # Main training code
            loss_dis_total = trainer.dis_update(images_a, images_b, config)
            loss_gen_total, loss_recon_x, loss_recon_s, loss_recon_c, loss_cycrecon, loss_vgg = trainer.gen_update(
                images_a, images_b, config)
            torch.cuda.synchronize()
        # loss_dis_total = trainer.dis_update(images_a, images_b, config)
        # loss_gen_total, loss_recon_x, loss_recon_s, loss_recon_c, loss_cycrecon, loss_vgg= trainer.gen_update(images_a, images_b, config)
        print(
            " | dis_los: %9f | gen_los_total: %9f | recon_x_los: %9f | recon_s_los: %9f | recon_c_los: %9f | cycle_los: %9f | vgg_los: %9f "
            % (loss_dis_total, loss_gen_total, loss_recon_x, loss_recon_s,
               loss_recon_c, loss_cycrecon, loss_vgg))

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
コード例 #13
0
ファイル: train.py プロジェクト: WangJerry95/MUNIT
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
output_directory = os.path.join(opts.output_path + "/outputs", model_name)
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a, (images_b,label_b)) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b, label_b= images_a.cuda().detach(), images_b.cuda().detach(), label_b.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code
            trainer.dis_update(images_a, images_b, config, label_b=label_b)
            trainer.gen_update(images_a, images_b, config, label_b=label_b)
            torch.cuda.synchronize()

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
            with torch.no_grad():
                test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
コード例 #14
0
def main(argv):
    """Train a MUNIT model, logging to a text log file and TensorBoard.

    Args:
        argv: argument list handed to the module-level ``parser``; its
            ``parse_args`` returns an ``(opts, args)`` pair. Options read
            here: ``config``, ``resume`` and ``model_path``; ``opts`` is also
            forwarded to ``MUNIT_Trainer``.

    Prepares the output folders named by ``config['output_root']`` and
    ``config['experiment_name']``, copies the config next to the logs, then
    alternates discriminator/generator updates until ``config['max_iter']``
    iterations have run. Every ``log_iter`` steps it writes loss scalars and
    validation samples to TensorBoard. Every ``print_iter`` steps it logs the
    total losses, and every ``snapshot_save_iter`` steps it saves weights.
    Returns ``None``.
    """
    opts, _ = parser.parse_args(argv)
    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']

    # Setup logger and output folders
    output_subfolders = prepare_logging_folders(config['output_root'],
                                                config['experiment_name'])
    logger = create_logger(
        os.path.join(output_subfolders['logs'], 'train_log.log'))
    shutil.copy(opts.config,
                os.path.join(
                    output_subfolders['logs'],
                    'config.yaml'))  # copy config file to output folder

    tb_logger = tensorboard_logger.Logger(output_subfolders['logs'])

    logger.info('============ Initialized logger ============')
    logger.info('Config File: {}'.format(opts.config))

    # Setup model and data loader
    trainer = MUNIT_Trainer(config, opts)
    trainer.cuda()
    loaders = get_all_data_loaders(config)
    # One fixed validation batch is reused for every TensorBoard image dump.
    val_display_images = next(iter(loaders['val']))
    logger.info('Test images: {}'.format(val_display_images['A_paths']))
    # The batch never changes, so move it to the GPU once here instead of on
    # every logging step inside the loop.
    val_images_a = Variable(val_display_images['A'].cuda())
    val_images_b = Variable(val_display_images['B'].cuda())

    # Start training
    iterations = trainer.resume(opts.model_path,
                                hyperparameters=config) if opts.resume else 0

    while True:
        for images in loaders['train']:
            trainer.update_learning_rate()
            images_a = Variable(images['A'].cuda())
            images_b = Variable(images['B'].cuda())

            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)

            # Dump training stats and validation samples to TensorBoard
            if (iterations + 1) % config['log_iter'] == 0:
                for tag, value in trainer.loss.items():
                    tb_logger.scalar_summary(tag, value, iterations)

                val_output_imgs = trainer.sample(val_images_a, val_images_b)

                # Lay each batch out side by side: split along dim 0 and
                # concatenate the pieces along dim 2.
                tb_imgs = [torch.cat(torch.unbind(imgs, 0), dim=2)
                           for imgs in val_output_imgs.values()]

                tb_logger.image_summary(list(val_output_imgs.keys()), tb_imgs,
                                        iterations)

            if (iterations + 1) % config['print_iter'] == 0:
                logger.info(
                    "Iteration: {:08}/{:08} Discriminator Loss: {:.4f} Generator Loss: {:.4f}"
                    .format(iterations + 1, max_iter, trainer.loss['D/total'],
                            trainer.loss['G/total']))

            # Image files are not written to disk; sample images reach
            # TensorBoard through the log_iter branch above.

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(output_subfolders['models'], iterations)

            iterations += 1
            if iterations >= max_iter:
                return