Example #1
File: test_batch.py Project: zergey/MUNIT
def main(argv):
    (opts, args) = parser.parse_args(argv)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)
    if not os.path.exists(opts.output_folder):
        os.makedirs(opts.output_folder)

    # Load experiment setting
    config = get_config(opts.config)
    input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
    style_dim = config['gen']['style_dim']

    # Setup model and data loader
    data_loader = get_data_loader_folder(opts.input_folder,
                                         1,
                                         False,
                                         input_dim == 1,
                                         crop=False)
    trainer = MUNIT_Trainer(config)
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
    trainer.cuda()
    trainer.eval()
    encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode  # encode function
    decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode  # decode function

    # Start testing
    style_fixed = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(),
                           volatile=True)
    for i, images in enumerate(data_loader):
        images = Variable(images.cuda(), volatile=True)
        content, _ = encode(images)
        style = style_fixed if opts.synchronized else Variable(
            torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True)
        for j in range(opts.num_style):
            s = style[j].unsqueeze(0)
            outputs = decode(content, s)
            outputs = (outputs + 1) / 2.
            path = os.path.join(opts.output_folder,
                                'input{:03d}_output{:03d}.jpg'.format(i, j))
            vutils.save_image(outputs.data, path, padding=0, normalize=True)
        if not opts.output_only:
            # also save input images
            vutils.save_image(images.data,
                              os.path.join(opts.output_folder,
                                           'input{:03d}.jpg'.format(i)),
                              padding=0,
                              normalize=True)
Example #2
parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")

opts = parser.parse_args()

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder,
                                     1,
                                     False,
                                     new_size=config['crop_image_height'],
                                     crop=False)

# config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

last_model_name = get_model_list(opts.checkpoint, "gen")
state_dict = torch.load(last_model_name)
trainer.gen_a.load_state_dict(state_dict['a'])
Example #3
File: test.py Project: Cyril-JZ/v2v_trans
    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Load experiment setting
    config.num_style = 1 if config.style != '' else config.num_style
    input_dim = config.input_dim_a if config.a2b else config.input_dim_b
    style_dim = config.gen_style_dim

    # Setup model and data loader
    image_names = ImageFolder(config.input_folder, return_paths=True)
    data_loader = get_data_loader_folder(config.input_folder, 1, False)

    model = V2VModel(config).to(device)
    state_dict = torch.load(config.checkpoint)
    model.gen_a.load_state_dict(state_dict['a'])
    model.gen_b.load_state_dict(state_dict['b'])
    model.eval()
    encode = model.gen_a.encode if config.a2b else model.gen_b.encode  # encode function
    decode = model.gen_b.decode if config.a2b else model.gen_a.decode  # decode function

    # Start testing
    style_fixed = Variable(torch.randn(config.num_style, style_dim, 1,
                                       1).to(device),
                           volatile=True)
    for i, (images, names) in enumerate(zip(data_loader, image_names)):
        print(names[1])
Example #4
print(opts.dataset_path)
print(opts.output_path)
if not os.path.exists(opts.output_path):
    os.makedirs(opts.output_path)

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Setup model and data loader
image_names = ImageFolder(opts.dataset_path, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.dataset_path, 1, False, new_size=config['new_size'], crop=False)
# data_loader = get_data_loader_folder(opts.input_path, 1, False, crop=False)

config['vgg_model_path'] = opts.weight
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
Example #5
def runImageTransfer(preload_model, input_folder, user_key, a2b):
    #preload_model = [trainer, config, council_size, style_dim]
    trainer = preload_model[0]
    config = preload_model[1]
    council_size = preload_model[2]
    style_dim = preload_model[3]

    output_path = 'static'
    seed = 1
    num_of_images_to_test = 100
    
    # Setup model and data loader
    image_names = ImageFolder(input_folder, transform=None, return_paths=True)
    if 'new_size_a' not in config:
        config['new_size_a'] = config['new_size']
    is_data_A = a2b
    data_loader = get_data_loader_folder(input_folder, 1, False,
                                         new_size=config['new_size_a'],
                                         crop=False, config=config,
                                         is_data_A=is_data_A)
                                        
    encode_s = []
    decode_s = []
    if a2b:
        for i in range(council_size):
            encode_s.append(trainer.gen_a2b_s[i].encode)  # encode function
            decode_s.append(trainer.gen_a2b_s[i].decode)  # decode function
    else:
        for i in range(council_size):
            encode_s.append(trainer.gen_b2a_s[i].encode)  # encode function
            decode_s.append(trainer.gen_b2a_s[i].decode)  # decode function
    
    # create test images
    file_list = []
    seed = 1
    curr_image_num = -1
    
    for i, (images, names) in tqdm(enumerate(zip(data_loader, image_names)), total=num_of_images_to_test):

        if curr_image_num == num_of_images_to_test:
            break
        curr_image_num += 1
        k = np.random.randint(council_size)
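        # NB: style_fixed below is assigned but never used; a fresh style
        # batch is drawn per image after reseeding (see 'style = ...' below).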
        style_fixed = Variable(torch.randn(10, style_dim, 1, 1).cuda(), volatile=True)
        print(names[1])
        images = Variable(images.cuda(), volatile=True)

        content, _ = encode_s[k](images)
        seed += 1
        torch.random.manual_seed(seed)
        style = Variable(torch.randn(10, style_dim, 1, 1).cuda(), volatile=True)
        
        for j in range(10):
            s = style[j].unsqueeze(0)
            outputs = decode_s[k](content, s, images)
            basename = os.path.basename(names[1])
            output_folder = os.path.join(output_path, 'img') #output_folder = static/img
                
            path_all_in_one = os.path.join(output_folder, user_key, '_out_' + str(curr_image_num) + '_' + str(j) + '.jpg')
            file_list.append(path_all_in_one)
            if not os.path.exists(os.path.dirname(path_all_in_one)):
                os.makedirs(os.path.dirname(path_all_in_one))
            vutils.save_image(outputs.data, path_all_in_one, padding=0, normalize=True)
    return file_list
Example #6
    for param in inception.parameters():
        param.requires_grad = False
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
if 'new_size' in config:
    new_size = config['new_size']
else:
    if opts.a2b == 1:
        new_size = config['new_size_a']
    else:
        new_size = config['new_size_b']


data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=new_size, num_workers=1, crop=False)
config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

try:
    state_dict = torch.load(opts.checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except:
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
Example #7
trainer.eval()

# encode / decode functions
content_encode = trainer.gen_b.enc_content
style_encode = trainer.gen_b.enc_style
decode = trainer.gen_b.decode

new_size = config['new_size']

# Dataset loader
batch_size = 4
test_loader = get_data_loader_folder(opts.input,
                                     batch_size,
                                     False,
                                     new_size,
                                     config['crop_image_height'],
                                     config['crop_image_width'],
                                     config['num_workers'],
                                     True,
                                     return_path=True)

if not opts.use_avgsc:
    with torch.no_grad():
        transform = transforms.Compose([
            transforms.Resize(new_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        style_image = transform(Image.open(opts.style).convert(
            'RGB')).unsqueeze(0).cuda() if opts.style != '' else None
        style = style_encode(style_image)
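The snippet is truncated here; presumably the reference style encoded above is then paired with each test image's content code, along these lines (a sketch, not part of the original file):

# Sketch only: apply the reference style to every test batch, assuming the
# content_encode/decode split defined above and a loader that returns paths.
with torch.no_grad():
    for images, paths in test_loader:
        content = content_encode(images.cuda())
        outputs = decode(content, style)
        # save outputs with vutils.save_image(...) as in the other examples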
Example #8
if not os.path.exists(opts.output_folder):
    os.makedirs(opts.output_folder)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
new_size = config['new_size']
crop_image_height = config['crop_image_height']
crop_image_width = config['crop_image_width']

# Setup model and data loader

data_loader_a = get_data_loader_folder(opts.A,
                                       1,
                                       False,
                                       new_size=new_size,
                                       height=crop_image_height,
                                       width=crop_image_width,
                                       crop=True)
data_loader_b = get_data_loader_folder(opts.B,
                                       1,
                                       False,
                                       new_size=new_size,
                                       height=crop_image_height,
                                       width=crop_image_width,
                                       crop=True)
imagea_names = ImageFolder(opts.A, transform=None, return_paths=True)
imageb_names = ImageFolder(opts.B, transform=None, return_paths=True)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'ERGAN':
Example #9
File: test_batch.py Project: phonx/MUNIT
parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight")
parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")

opts = parser.parse_args()


torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size_a'], crop=False)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")


state_dict = torch.load(opts.checkpoint)
trainer.gen_a.load_state_dict(state_dict['a'])
trainer.gen_b.load_state_dict(state_dict['b'])
trainer.cuda()
Example #10
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
council_size = config['council']['council_size']

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
if 'new_size_a' not in config:
    config['new_size_a'] = config['new_size']
is_data_A = opts.a2b
data_loader = get_data_loader_folder(opts.input_folder, 1, False,
                                     new_size=config['new_size_a'],
                                     crop=False, config=config,
                                     is_data_A=is_data_A)

style_dim = config['gen']['style_dim']
trainer = Council_Trainer(config)
only_one = False
if 'gen_' in opts.checkpoint[-21:]:
    state_dict = torch.load(opts.checkpoint)
    try:
        if opts.a2b:
            trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b'])
        else:
            trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a'])
    except:
        print('opts.a2b should be set to ' + str(not opts.a2b) +
              ' , Or config file could be wrong')
Example #11
File: train.py Project: yzou2/DG-Net-PP
gpu_ids = []
for str_id in str_ids:
    gpu_ids.append(int(str_id))
num_gpu = len(gpu_ids)
if num_gpu > 1:
    raise Exception('Currently only single GPU training is supported!')

# Load experiment setting
config = get_config(opts.config)
set_seed(config['randseed'])
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# preparing sampling images
train_loader_a_sample = get_data_loader_folder(os.path.join(config['data_root_a'], 'train_all'), config['batch_size'], False,
                                        config['new_size'], config['crop_image_height'], config['crop_image_width'], config['num_workers'], False)
train_loader_b_sample = get_data_loader_folder(os.path.join(config['data_root_b'], 'train_all'), config['batch_size'], False,
                                             config['new_size'], config['crop_image_height'], config['crop_image_width'], config['num_workers'], False)

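# NB: 'random' here is presumably numpy.random (random.permutation draws a
# shuffled index array), not the standard-library random module.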
train_aba_rand = random.permutation(train_loader_a_sample.dataset.img_num)[0:display_size]
train_abb_rand = random.permutation(train_loader_b_sample.dataset.img_num)[0:display_size]
train_aab_rand = random.permutation(train_loader_a_sample.dataset.img_num)[0:display_size]
train_bbb_rand = random.permutation(train_loader_b_sample.dataset.img_num)[0:display_size]

train_display_images_aba = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aba_rand]).cuda()
train_display_images_abb = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_abb_rand]).cuda()
train_display_images_aaa = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aba_rand]).cuda()
train_display_images_aab = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aab_rand]).cuda()
train_display_images_bba = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_abb_rand]).cuda()
train_display_images_bbb = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_bbb_rand]).cuda()
Example #12
torch.cuda.manual_seed(opts.seed)
if not os.path.exists(opts.output_folder):
    os.makedirs(opts.output_folder)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
new_size = config['new_size']

# imagea_names = ImageFolder(opts.A, transform=None, return_paths=True)
# imageb_names = ImageFolder(opts.B, transform=None, return_paths=True)

data_loader_a = get_data_loader_folder(opts.A,
                                       1,
                                       False,
                                       new_size=new_size,
                                       height=224,
                                       width=224,
                                       crop=False)
data_loader_b = get_data_loader_folder(opts.B,
                                       1,
                                       False,
                                       new_size=new_size,
                                       height=224,
                                       width=224,
                                       crop=False)

# Setup model and data loader
config['vgg_model_path'] = opts.output_path
if opts.trainer == 'ERGAN':
    style_dim = config['gen']['style_dim']
Example #13
# Load the inception networks if we need to compute IS or CIS
if opts.compute_IS or opts.compute_CIS:
    inception = load_inception(
        opts.inception_b) if opts.a2b else load_inception(opts.inception_a)
    inception.cuda()
    # freeze the inception models and set eval mode
    inception.eval()
    for param in inception.parameters():
        param.requires_grad = False
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder,
                                     1,
                                     True,
                                     new_size=config['new_size'],
                                     crop=False)  # Shuffle False to True

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
    from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
    from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception
elif opts.trainer == 'CDUNIT':
    style_dim = config['gen']['style_dim']
    trainer = CDUNIT_Trainer(config)
    from cd_utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception
Example #14
# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Load the inception networks if we need to compute IS or CIS
if opts.compute_IS or opts.compute_CIS:
    inception = load_inception(opts.inception_b) if opts.a2b else load_inception(opts.inception_a)
    # freeze the inception models and set eval mode
    inception.eval()
    for param in inception.parameters():
        param.requires_grad = False
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size'], crop=False)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'aclgan':
    style_dim = config['gen']['style_dim']
    trainer = aclgan_Trainer(config)
else:
    sys.exit("Only support aclgan)

def focus_translation(x_fg, x_bg, x_focus):
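    # Blend foreground and background with the focus mask: map each tensor
    # from [-1, 1] to [0, 1], mix by the (3-channel) mask, then map back.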
    x_map = (x_focus + 1) / 2
    x_map = x_map.repeat(1, 3, 1, 1)
    return (torch.mul((x_fg + 1) / 2, x_map) +
            torch.mul((x_bg + 1) / 2, 1 - x_map)) * 2 - 1

if opts.trainer == 'aclgan':
    try:
Example #15
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except:
    state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), 'MUNIT')
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])

trainer.to(device)
trainer.eval()
encode1, encode2 = trainer.gen_a.encode, trainer.gen_b.encode  # encode function
decode1, decode2 = trainer.gen_a.decode, trainer.gen_b.decode  # decode function

loaderA = get_data_loader_folder(opts.input_folderA,
                                 1,
                                 False,
                                 new_size=config['new_size'],
                                 crop=True,
                                 height=config['new_size'],
                                 width=config['new_size'])
loaderB = get_data_loader_folder(opts.input_folderB,
                                 1,
                                 False,
                                 new_size=config['new_size'],
                                 crop=True,
                                 height=config['new_size'],
                                 width=config['new_size'])

content1, style1 = [], []
image1 = []
num_input1 = 0
for data in loaderA:
Example #16
# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a']

# Setup data directories
folder_A = os.path.join(opts.input_folder, 'testA')
folder_B = os.path.join(opts.input_folder, 'testB')

# Setup data loaders
names_A = ImageFolder(folder_A, transform=None, return_paths=True)
names_B = ImageFolder(folder_B, transform=None, return_paths=True)

data_loader_A = get_data_loader_folder(folder_A,
                                       1,
                                       False,
                                       new_size=config['new_size'],
                                       crop=False)
data_loader_B = get_data_loader_folder(folder_B,
                                       1,
                                       False,
                                       new_size=config['new_size'],
                                       crop=False)

# Setup model
config['vgg_model_path'] = opts.output_path

trainer = UNIT_Trainer(config)

state_dict = torch.load(opts.checkpoint)
trainer.gen_a.load_state_dict(state_dict['a'])
Example #17
opts = parser.parse_args()

torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment setting
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']

# Setup model and data loader
image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
#data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=72, crop=False)
data_loader = get_data_loader_folder(opts.input_folder,
                                     1,
                                     False,
                                     new_size=500,
                                     crop=False)
#data_loader = get_data_loader_folder(opts.input_folder, 1, False, None, crop=False)

config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

state_dict = torch.load(opts.checkpoint,
                        map_location=lambda storage, loc: storage)
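The map_location=lambda storage, loc: storage argument keeps all checkpoint tensors on the CPU regardless of the device they were saved from, so a GPU-trained checkpoint loads on a CPU-only machine. In current PyTorch the same effect is usually written as:

state_dict = torch.load(opts.checkpoint, map_location='cpu')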