Code example #1
File: test.py Project: zergey/MUNIT
def main(argv):
    (opts, args) = parser.parse_args(argv)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)
    if not os.path.exists(opts.output_folder):
        os.makedirs(opts.output_folder)

    # Load experiment setting
    config = get_config(opts.config)
    style_dim = config['gen']['style_dim']
    opts.num_style = 1 if opts.style != '' else opts.num_style

    # Setup model and data loader
    trainer = MUNIT_Trainer(config)
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
    trainer.cuda()
    trainer.eval()
    encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode  # encode function
    style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode  # style encode function
    decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode  # decode function

    transform = transforms.Compose([
        transforms.Resize(config['new_size']),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    image = Variable(transform(Image.open(
        opts.input).convert('RGB')).unsqueeze(0).cuda(),
                     volatile=True)
    style_image = Variable(transform(Image.open(
        opts.style).convert('RGB')).unsqueeze(0).cuda(),
                           volatile=True) if opts.style != '' else None

    # Start testing
    style_rand = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(),
                          volatile=True)
    content, _ = encode(image)
    if opts.style != '':
        _, style = style_encode(style_image)
    else:
        style = style_rand
    for j in range(opts.num_style):
        s = style[j].unsqueeze(0)
        outputs = decode(content, s)
        outputs = (outputs + 1) / 2.
        path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j))
        vutils.save_image(outputs.data, path, padding=0, normalize=True)
    if not opts.output_only:
        # also save input images
        vutils.save_image(image.data,
                          os.path.join(opts.output_folder, 'input.jpg'),
                          padding=0,
                          normalize=True)
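
Note: the Variable(..., volatile=True) idiom above is PyTorch <= 0.3. Since PyTorch 0.4, Variable is merged into Tensor and inference is wrapped in torch.no_grad() instead. Below is a minimal sketch of the equivalent modern inference step; it reuses the transform, encode, decode, style_dim, and opts defined in the example and is not part of the original code:

# Sketch only: PyTorch >= 0.4 equivalent of the volatile=True inference above.
with torch.no_grad():
    image = transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda()
    style = torch.randn(opts.num_style, style_dim, 1, 1).cuda()
    content, _ = encode(image)
    for j in range(opts.num_style):
        outputs = (decode(content, style[j].unsqueeze(0)) + 1) / 2.
        path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j))
        vutils.save_image(outputs, path, padding=0, normalize=True)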
Code example #2
File: test_batch.py Project: zergey/MUNIT
def main(argv):
    (opts, args) = parser.parse_args(argv)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)
    if not os.path.exists(opts.output_folder):
        os.makedirs(opts.output_folder)

    # Load experiment setting
    config = get_config(opts.config)
    input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
    style_dim = config['gen']['style_dim']

    # Setup model and data loader
    data_loader = get_data_loader_folder(opts.input_folder,
                                         1,
                                         False,
                                         input_dim == 1,
                                         crop=False)
    trainer = MUNIT_Trainer(config)
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
    trainer.cuda()
    trainer.eval()
    encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode  # encode function
    decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode  # decode function

    # Start testing
    style_fixed = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(),
                           volatile=True)
    for i, images in enumerate(data_loader):
        images = Variable(images.cuda(), volatile=True)
        content, _ = encode(images)
        style = style_fixed if opts.synchronized else Variable(
            torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True)
        for j in range(opts.num_style):
            s = style[j].unsqueeze(0)
            outputs = decode(content, s)
            outputs = (outputs + 1) / 2.
            path = os.path.join(opts.output_folder,
                                'input{:03d}_output{:03d}.jpg'.format(i, j))
            vutils.save_image(outputs.data, path, padding=0, normalize=True)
        if not opts.output_only:
            # also save input images
            vutils.save_image(images.data,
                              os.path.join(opts.output_folder,
                                           'input{:03d}.jpg'.format(i)),
                              padding=0,
                              normalize=True)
Code example #3
opts = parser.parse_args()

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
if not os.path.exists(opts.output_folder):
    os.makedirs(opts.output_folder)

# Load experiment setting
opts.num_style = 1 if opts.style != '' else opts.num_style

if opts.trainer == 'MUNIT':
    from trainer import MUNIT_Trainer
    from utils import get_all_data_loaders, get_config
    config = get_config(opts.config)
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    from trainer import UNIT_Trainer
    from utils import get_config
    config = get_config(opts.config)
    trainer = UNIT_Trainer(config)
elif opts.trainer == 'CDUNIT':
    from cd_trainer import CDUNIT_Trainer
    from cd_utils import get_all_data_loaders, get_config
    config = get_config(opts.config)
    trainer = CDUNIT_Trainer(config)
elif opts.trainer == 'SECUNIT':
    from secunit_trainer import SECUNIT_Trainer
    from secunit_utils import get_all_data_loaders, get_config
    config = get_config(opts.config)
    trainer = SECUNIT_Trainer(config)
Code example #4
File: main.py Project: VentusXu09/MyGAN
import argparse
import os

import torch
import yaml

from trainer import MUNIT_Trainer
from utils import get_config, get_loader, prepare_sub_folder, Timer, write_loss, write_2images

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config.yaml', help='Path to the config file.')
parser.add_argument('--output_path', type=str, default='.', help="outputs path")
opts = parser.parse_args()

config = get_config(opts.config)
max_iter = config['max_iter']
display_size = config['display_size']

trainer = MUNIT_Trainer(config)

trainer.cuda()

# Note: loader "a" reads subfolder 'b' and loader "b" reads subfolder 'a'.
data_loader_a = get_loader(os.path.join(config['data_root'], 'b'),
                        config['crop_image_height'], config['new_size'], config['batch_size'],
                        'afhq', 'train', 1)
data_loader_b = get_loader(os.path.join(config['data_root'], 'a'),
                        config['crop_image_height'], config['new_size'], config['batch_size'],
                        'afhq', 'train', 1)


train_display_images_a = torch.stack([data_loader_a.dataset[i][0] for i in range(display_size)]).cuda()
train_display_images_b = torch.stack([data_loader_b.dataset[i][0] for i in range(display_size)]).cuda()

# Setup logger and output folders
Code example #5
File: train.py Project: EdisonCCL/IOSUDA
config['snapshot_dir']=opts.snapshot_dir
config['snapshot_save_iter']=opts.snapshot_save_iter
config['sample_C']=opts.sample_C
config['trim']=opts.trim
config['sample_B']=opts.sample_B
config['sample_D']=opts.sample_D
config['sample_A']=opts.sample_A
config['batch_size']=opts.batch_size
config['transform_A']=opts.transform_A
config['transform_B']=opts.transform_B
config['transform_C']=opts.transform_C
config['transform_D']=opts.transform_D
config['weight_temp']=opts.weight_temp
config['recon_x_cyc_w']=opts.recon_x_cyc_w
# Setup model and data loader.
trainer = MUNIT_Trainer(config, resume_epoch=opts.resume, snapshot_dir=opts.snapshot_dir)
trainer.cuda()

dataset_letters = eval(opts.dataset_letters)  # parse the list of dataset letters passed as a string
samples = list()
dataset_probs = list()
augmentation = list()
for i in range(config['n_datasets']):
    samples.append(config['sample_' + dataset_letters[i]])
    augmentation.append(config['transform_' + dataset_letters[i]])

train_loader_list, test_loader_list = get_all_data_loaders(config, config['n_datasets'], samples, augmentation, config['trim'], opts.dataset_letters)

loader_sizes = list()

for l in train_loader_list:
Code example #6
parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")

opts = parser.parse_args()

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
opts.num_style = 1 if opts.style != '' else opts.num_style

# Setup model and data loader
config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT")

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except Exception:
    # Fall back to converting a checkpoint saved with PyTorch 0.3
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])

trainer.cuda()
trainer.eval()
encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function
Code example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    # resume option (an earlier default was '730000'). Note: with a truthy
    # string default, opts.resume is enabled even without passing --resume.
    parser.add_argument("--resume", default='150000', action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda(
            ).detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

            if (iterations + 1) % config['image_display_iter'] == 0:
                with torch.no_grad():
                    image_outputs = trainer.sample(train_display_images_a,
                                                   train_display_images_b)
                write_2images(image_outputs, display_size, image_directory,
                              'train_current')

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
Code example #8
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
# if not os.path.exists(opts.output_folder):
# os.makedirs(opts.output_folder)

# Load experiment setting
config = get_config(opts.config)
config['gen']['no_style_enc'] = opts.no_style_enc

# Setup model and data loader
config['vgg_w'] = 0

if opts.trainer == 'MUNIT':
    config['gen']['style_dim'] = config['gen']['z_num']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

# Configure the checkpoint path from output_folder
checkpoint_path = find_latest_model_file(os.path.join(opts.output_folder,
                                                      'checkpoints'),
                                         opts.checkpoint,
                                         keyword='gen')

try:
    state_dict = torch.load(checkpoint_path)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
Code example #9
opts = parser.parse_args()

if comet_exp is not None:
    comet_exp.log_asset(file_data=opts.config, file_name="config.yaml")
    comet_exp.log_parameter("git_hash", opts.git_hash)

cudnn.benchmark = True
# Load experiment setting
config = get_config(opts.config)
max_iter = config["max_iter"]
display_size = config["display_size"]
config["vgg_model_path"] = opts.output_path

# Setup model and data loader
if opts.trainer == "MUNIT":
    trainer = MUNIT_Trainer(config)
elif opts.trainer == "UNIT":
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")
trainer.cuda()

train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
    config)

if config["semantic_w"] > 0:
    train_loader_a_w_mask = get_data_loader_mask_and_im(
        config["data_list_train_a"],
        config["data_list_train_a_seg"],
        config["batch_size"],
        True,
Code example #10
    opts.config = 'configs/handwriting_online.yaml'
    opts.check_files = True
    print("Running on Galois, config {}, check_files {}".format(opts.config, opts.check_files))

cudnn.benchmark = True

# Load experiment setting
config = get_config(opts.config)
print(config)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader
if opts.trainer == 'MUNIT':
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

#trainer.cuda()

#train_loader_a, train_loader_b, test_loader_a, test_loader_b, folders = get_all_data_loaders(config)
(train_loader_a, tr_a), (train_loader_b, tr_b), (test_loader_a, test_a), (test_loader_b, test_b), folders = get_all_data_loaders_better(config)

if opts.check_files and False:  # 'and False' intentionally disables this check
    print("Checking files...")
    for folder in folders:
        print(folder)
        utils.check_files(folder)
Code example #11
def main():
    from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
    import argparse
    from torch.autograd import Variable
    from trainer import MUNIT_Trainer, UNIT_Trainer
    import torch.backends.cudnn as cudnn
    import torch

    # try:
    #     from itertools import izip as zip
    # except ImportError:  # will be 3.x series
    #     pass

    import os
    import sys
    import tensorboardX
    import shutil

    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True
    '''
    Note: https://www.pytorchtutorial.com/when-should-we-set-cudnn-benchmark-to-true/
        In most cases, setting this flag lets cuDNN's built-in auto-tuner search
        for the most efficient algorithms for the current configuration.
        1. If the network's input sizes and types vary little, setting
           torch.backends.cudnn.benchmark = True can improve efficiency;
        2. If the inputs change at every iteration, cuDNN re-runs the search
           every time, which can actually reduce efficiency.
    '''

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda(
            ).detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

            if (iterations + 1) % config['image_display_iter'] == 0:
                with torch.no_grad():
                    image_outputs = trainer.sample(train_display_images_a,
                                                   train_display_images_b)
                write_2images(image_outputs, display_size, image_directory,
                              'train_current')

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
Code example #12
def setup(opts):
    generator_checkpoint_path = opts['generator_checkpoint']
    # generator_checkpoint_path = './checkpoints/ffhq2ladiescrop.pt'

    # Load experiment settings
    config = {
        'image_save_iter': 10000,
        'image_display_iter': 100,
        'display_size': 16,
        'snapshot_save_iter': 10000,
        'log_iter': 100,
        'max_iter': 1000000,
        'batch_size': 1,
        'weight_decay': 0.0001,
        'beta1': 0.5,
        'beta2': 0.999,
        'init': 'kaiming',
        'lr': 0.0001,
        'lr_policy': 'step',
        'step_size': 100000,
        'gamma': 0.5,
        'gan_w': 1,
        'recon_x_w': 10,
        'recon_s_w': 1,
        'recon_c_w': 1,
        'recon_x_cyc_w': 10,
        'vgg_w': 0,
        'gen': {
            'dim': 64,
            'mlp_dim': 256,
            'style_dim': 8,
            'activ': 'relu',
            'n_downsample': 2,
            'n_res': 4,
            'pad_type': 'reflect'
        },
        'dis': {
            'dim': 64,
            'norm': 'none',
            'activ': 'lrelu',
            'n_layer': 4,
            'gan_type': 'lsgan',
            'num_scales': 3,
            'pad_type': 'reflect'
        },
        'input_dim_a': 3,
        'input_dim_b': 3,
        'num_workers': 8,
        'new_size': 1024,
        'crop_image_height': 400,
        'crop_image_width': 400,
        'data_root': './datasets/ffhq2ladies/'
    }

    # Setup model and data loader
    trainer = MUNIT_Trainer(config)

    state_dict = torch.load(generator_checkpoint_path)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])

    return {'model': trainer, 'config': config}
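
For reference, a hypothetical call to the setup() function above might look like the following; the checkpoint path is illustrative only (borrowed from the commented-out line in the example):

# Hypothetical usage of setup(); the checkpoint path is illustrative only.
state = setup({'generator_checkpoint': './checkpoints/ffhq2ladiescrop.pt'})
trainer, config = state['model'], state['config']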
Code example #13
            list_images.append(image)
            list_classes_to_take.remove(label)
            # print(list_classes_to_take)
    return torch.stack(list_images).cuda()
    # train_display_images_a = torch.stack([loader.dataset[i][0]]).cuda()


# Load experiment setting
config = get_config(opts.config)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader
if opts.trainer == 'MUNIT':
    trainer = MUNIT_Trainer(config)
# elif opts.trainer == 'UNIT':
#     trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")
trainer.cuda()
train_loader_a, train_loader_a_limited, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
    config)
# train_display_images_a_temp = torch.stack([train_loader_a.dataset[i][0] for i in range(display_size)]).cuda()
# train_display_images_b = torch.stack([train_loader_b.dataset[i][0] for i in range(display_size)]).cuda()
# test_display_images_a = torch.stack([test_loader_a.dataset[i][0] for i in range(display_size)]).cuda()
# test_display_images_b = torch.stack([test_loader_b.dataset[i][0] for i in range(display_size)]).cuda()

train_display_images_a = get_display_images(train_loader_a)
train_display_images_b = get_display_images(train_loader_b)
test_display_images_a = get_display_images(test_loader_a)
Code example #14
File: train.py Project: phonx/MUNIT
parser.add_argument('--output_path', type=str, default='.', help="outputs path")
parser.add_argument("--resume", action="store_true")
parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
opts = parser.parse_args()

cudnn.benchmark = True

# Load experiment setting
config = get_config(opts.config)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader
if opts.trainer == 'MUNIT':
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")
trainer.cuda()
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
train_display_images_a = Variable(torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda(), volatile=True)
train_display_images_b = Variable(torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda(), volatile=True)
test_display_images_a = Variable(torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda(), volatile=True)
test_display_images_b = Variable(torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda(), volatile=True)

# Setup logger and output folders
model_name = os.path.splitext(os.path.basename(opts.config))[0]
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
output_directory = os.path.join(opts.output_path + "/outputs", model_name)
Code example #15
config['no_rec_s'] = opts.no_rec_s

if opts.vgg_w != -1:
    config['vgg_w'] = opts.vgg_w
if opts.cyc_rec_weight != -1:
    config['recon_x_cyc_w'] = opts.cyc_rec_weight
if opts.ne_weight != -1:
    config['loss_eg_weight'] = opts.ne_weight
if opts.rec_c_weight != -1:
    config['recon_c_w'] = opts.rec_c_weight

# Setup model and data loader
if opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
elif opts.trainer == 'MUNIT':
    trainer = MUNIT_Trainer(config)

trainer.cuda()

Dataset = choose_dataset(config['dataset_name'])
dataset = Dataset(config['data_root'], config, split='train')

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config['batch_size'],
                                         drop_last=True,
                                         shuffle=True,
                                         num_workers=int(
                                             config['num_workers']))

# Setup logger and output folders
model_name = os.path.splitext(os.path.basename(opts.config))[0] + '_vgg_%s_%s' \
Code example #16
File: test.py Project: lconet/CoDAGANs
                    help="outputs path")
parser.add_argument('--load', type=int, default=400)
parser.add_argument('--snapshot_dir', type=str, default='.')
opts = parser.parse_args()

cudnn.benchmark = True

# Load experiment setting.
config = get_config(opts.config)
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader.
if config['trainer'] == 'MUNIT':
    trainer = MUNIT_Trainer(config,
                            resume_epoch=opts.load,
                            snapshot_dir=opts.snapshot_dir)
elif config['trainer'] == 'UNIT':
    trainer = UNIT_Trainer(config,
                           resume_epoch=opts.load,
                           snapshot_dir=opts.snapshot_dir)
else:
    sys.exit("Only support MUNIT|UNIT.")
    os.exit()

trainer.cuda()

dataset_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
samples = list()
dataset_probs = list()
augmentation = list()
Code example #17

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
if not os.path.exists(opts.output_folder):
    os.makedirs(opts.output_folder)

# Load experiment setting
config = get_config(opts.config)
opts.num_style = 1 if opts.style != '' else opts.num_style


# Setup model and data loader
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
elif opts.trainer == 'AttnMUNIT':
    trainer = AttnMUNIT_Trainer(config, opts.discriminator, opts.attention, opts.concat_type)
else:
    sys.exit("Only support AttnMUNIT|MUNIT|UNIT")
# elif opts.trainer == 'AdaINAttnMUNIT':
#     trainer = AdaINAttnMUNIT_Trainer(config)

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except Exception:  # fall back to converting a PyTorch 0.3 checkpoint
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
Code example #18
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
if not os.path.exists(opts.output_folder):
    os.makedirs(opts.output_folder)

device = 'cuda:%d' % opts.gpu_id
# Load experiment setting
config = get_config(opts.config)
opts.num_style = 1 if opts.style != '' else opts.num_style

# Setup model and data loader
config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config, device)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config, device)
else:
    sys.exit("Only support MUNIT|UNIT")

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except Exception:  # fall back to converting a PyTorch 0.3 checkpoint
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])

trainer.cuda(device)
Code example #19
parser.add_argument('--config', type=str, default='configs/CXR_lungs', help='Path to the config file.')
parser.add_argument('--output_path', type=str, default='.', help="Outputs path.")
parser.add_argument('--resume', type=int, default=-1)
parser.add_argument('--snapshot_dir', type=str, default='.')
opts = parser.parse_args()

cudnn.benchmark = True

# Load experiment setting.
config = get_config(opts.config)
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader.
if config['trainer'] == 'MUNIT':
    trainer = MUNIT_Trainer(config, resume_epoch=opts.resume, snapshot_dir=opts.snapshot_dir)
elif config['trainer'] == 'UNIT':
    trainer = UNIT_Trainer(config, resume_epoch=opts.resume, snapshot_dir=opts.snapshot_dir)
else:
    sys.exit("Only support MUNIT|UNIT.")
    os.exit()

trainer.cuda()

dataset_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
samples = list()
dataset_probs = list()
augmentation = list()
for i in range(config['n_datasets']):
    samples.append(config['sample_' + dataset_letters[i]])
    dataset_probs.append(config['prob_' + dataset_letters[i]])
Code example #20
File: test_batch.py Project: phonx/MUNIT
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size_a'], crop=False)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")


state_dict = torch.load(opts.checkpoint)
trainer.gen_a.load_state_dict(state_dict['a'])
trainer.gen_b.load_state_dict(state_dict['b'])
trainer.cuda()
trainer.eval()
encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function
decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function

if opts.trainer == 'MUNIT':
Code example #21
File: train.py Project: adrienju/MUNIT
opts = parser.parse_args()

if comet_exp is not None:
    comet_exp.log_asset(file_data=opts.config, file_name="config.yaml")
    comet_exp.log_parameter("git_hash", opts.git_hash)

cudnn.benchmark = True
# Load experiment setting
config = get_config(opts.config)
max_iter = config["max_iter"]
display_size = config["display_size"]
config["vgg_model_path"] = opts.output_path

# Setup model and data loader
if opts.trainer == "MUNIT":
    trainer = MUNIT_Trainer(config)

elif opts.trainer == "UNIT":
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")
trainer.cuda()

print(config)

train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
    config)
# If the semantic loss is enabled, also build loaders that return segmentation masks
if config["semantic_w"] > 0:
    train_loader_a_w_mask = get_data_loader_mask_and_im(
        config["data_list_train_a"],
Code example #22
def main(argv):
    (opts, args) = parser.parse_args(argv)
    cudnn.benchmark = True
    model_name = os.path.splitext(os.path.basename(opts.config))[0]

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']

    # Setup model and data loader
    trainer = MUNIT_Trainer(config)
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    test_display_images_a = Variable(torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    test_display_images_b = Variable(torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    train_display_images_a = Variable(torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)
    train_display_images_b = Variable(torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)

    # Setup logger and output folders
    train_writer = tensorboard.SummaryWriter(os.path.join(
        opts.log, model_name))
    output_directory = os.path.join(opts.outputs, model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(izip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = Variable(images_a.cuda()), Variable(
                images_b.cuda())

            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                # Test set images
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_test%08d.jpg' % (image_directory, iterations + 1))
                # Train set images
                image_outputs = trainer.sample(train_display_images_a,
                                               train_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_train%08d.jpg' % (image_directory, iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')
            if (iterations + 1) % config['image_display_iter'] == 0:
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(image_outputs, display_size,
                             '%s/gen.jpg' % image_directory)

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                return
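
Note: izip in the loop above comes from Python 2's itertools. Under Python 3 the built-in zip is already lazy, so the loop header maps directly onto it; a sketch (not from the original):

# Python 3: the built-in zip is lazy, replacing itertools.izip.
for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
    pass  # training step proceeds exactly as in the example above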
Code example #23
File: test_batch.py Project: couver-v/MUNIT
    for param in inception.parameters():
        param.requires_grad = False
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder,
                                     1,
                                     False,
                                     new_size=config['new_size_a'],
                                     crop=False)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except Exception:  # fall back to converting a PyTorch 0.3 checkpoint
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint),
                                        opts.trainer)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
Code example #24
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()
Code example #25
def main(argv):
    (opts, args) = parser.parse_args(argv)
    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']

    # Setup logger and output folders
    output_subfolders = prepare_logging_folders(config['output_root'],
                                                config['experiment_name'])
    logger = create_logger(
        os.path.join(output_subfolders['logs'], 'train_log.log'))
    shutil.copy(opts.config,
                os.path.join(
                    output_subfolders['logs'],
                    'config.yaml'))  # copy config file to output folder

    tb_logger = tensorboard_logger.Logger(output_subfolders['logs'])

    logger.info('============ Initialized logger ============')
    logger.info('Config File: {}'.format(opts.config))

    # Setup model and data loader
    trainer = MUNIT_Trainer(config, opts)
    trainer.cuda()
    loaders = get_all_data_loaders(config)
    val_display_images = next(iter(loaders['val']))
    logger.info('Test images: {}'.format(val_display_images['A_paths']))

    # Start training
    iterations = trainer.resume(opts.model_path,
                                hyperparameters=config) if opts.resume else 0

    while True:
        for it, images in enumerate(loaders['train']):
            trainer.update_learning_rate()
            images_a = images['A']
            images_b = images['B']

            images_a, images_b = Variable(images_a.cuda()), Variable(
                images_b.cuda())

            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                for tag, value in trainer.loss.items():
                    tb_logger.scalar_summary(tag, value, iterations)

                val_output_imgs = trainer.sample(
                    Variable(val_display_images['A'].cuda()),
                    Variable(val_display_images['B'].cuda()))

                tb_imgs = []
                for imgs in val_output_imgs.values():
                    tb_imgs.append(torch.cat(torch.unbind(imgs, 0), dim=2))

                tb_logger.image_summary(list(val_output_imgs.keys()), tb_imgs,
                                        iterations)

            if (iterations + 1) % config['print_iter'] == 0:
                logger.info(
                    "Iteration: {:08}/{:08} Discriminator Loss: {:.4f} Generator Loss: {:.4f}"
                    .format(iterations + 1, max_iter, trainer.loss['D/total'],
                            trainer.loss['G/total']))

            # Write images
            # if (iterations + 1) % config['image_save_iter'] == 0:
            #     val_output_imgs = trainer.sample(
            #         Variable(val_display_images['A'].cuda()),
            #         Variable(val_display_images['B'].cuda()))
            #
            #     for key, imgs in val_output_imgs.items():
            #         key = key.replace('/', '_')
            #         write_images(imgs, config['display_size'], '{}/{}_{:08}.jpg'.format(output_subfolders['images'], key, iterations+1))
            #
            #     logger.info('Saved images to: {}'.format(output_subfolders['images']))

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(output_subfolders['models'], iterations)

            iterations += 1
            if iterations >= max_iter:
                return