コード例 #1
0
        test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)
    if args.data == 'CelebA-HQ':
        from data import CelebA_HQ
        test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path, args.img_size, 'test', args.attrs)
# Ensure the directory that receives generated test images exists.
os.makedirs(output_path, exist_ok=True)

# Feed the test split one image at a time, in a fixed order.
test_dataloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

# Report how many images will be processed: the whole test set, or at most
# args.num_test of them when a limit was given.
print('Testing images:',
      len(test_dataset) if args.num_test is None
      else min(len(test_dataset), args.num_test))

# Build the model, load the checkpoint that find_model() selects for
# args.load_epoch, and switch it to evaluation mode.
attgan = AttGAN(args)
attgan.load(
    find_model(join('output', args.experiment_name, 'checkpoint'),
               args.load_epoch))
progressbar = Progressbar()
attgan.eval()


# Iterate over the test loader one (img_a, att_a) batch at a time
# (presumably an image and its attribute labels — confirm against the dataset).
# NOTE(review): this excerpt looks truncated — att_b_list is created but the
# rest of the loop body is not shown here.
for idx, (img_a, att_a) in enumerate(test_dataloader):
    # Stop after args.num_test items when a limit was given.
    if args.num_test is not None and idx == args.num_test:
        break

    # Move tensors to the GPU only when --gpu was requested.
    img_a = img_a.cuda() if args.gpu else img_a
    att_a = att_a.cuda() if args.gpu else att_a
    # Attribute labels are cast to float before use.
    att_a = att_a.type(torch.float)

    # The list of target attribute vectors starts with the original one.
    att_b_list = [att_a]
コード例 #2
0
from os.path import join
from attgan import AttGAN


def parse(args=None):
    """Parse command-line options selecting the experiment, epoch and device.

    Args:
        args: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``experiment_name`` (str), ``load_epoch`` (int)
        and ``gpu`` (bool).
    """
    def _str2bool(value):
        # argparse's ``type=bool`` is a trap: bool('False') is True, so any
        # non-empty value used to enable the GPU. Parse the text explicitly.
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name',
                        dest='experiment_name',
                        type=str,
                        default='06-36AM on March 09, 2021')
    parser.add_argument('--load_epoch', dest='load_epoch', type=int, default=0)
    # ``--gpu True`` / ``--gpu False`` keep working; a bare ``--gpu`` means True.
    parser.add_argument('--gpu', dest='gpu', type=_str2bool, nargs='?',
                        const=True, default=False)
    return parser.parse_args(args)


args_ = parse()

# Reload the settings saved at training time; every JSON object is turned
# into an argparse.Namespace so its fields are reachable as attributes.
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))

# Command-line values take precedence over the stored ones.
vars(args).update(gpu=args_.gpu,
                  experiment_name=args_.experiment_name,
                  load_epoch=args_.load_epoch)
args.betas = (args.beta1, args.beta2)

# Load weights.<epoch>.pth and write it back out through
# saveG_D(..., flag='unzip') as weights_unzip.<epoch>.pth.
model = AttGAN(args)
checkpoint_dir = join('output', args.experiment_name, 'checkpoint')
model.load(join(checkpoint_dir, 'weights.{}.pth'.format(args.load_epoch)))
model.saveG_D(join(checkpoint_dir,
                   'weights_unzip.{:d}.pth'.format(args.load_epoch)),
              flag='unzip')
コード例 #3
0
                              args.image_list_path, args.img_size, 'valid',
                              args.attrs)
# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps a fixed order and every sample.
train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
)
valid_dataloader = data.DataLoader(
    valid_dataset,
    batch_size=args.n_samples,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
print('Training images:', len(train_dataset), '/', 'Validating images:',
      len(valid_dataset))

attgan = AttGAN(args)
progressbar = Progressbar()
writer = SummaryWriter(join('output', args.experiment_name, 'summary'))

# A fixed validation batch (presumably used for sample images during
# training — confirm in the training loop).
fixed_img_a, fixed_att_a = next(iter(valid_dataloader))
if args.gpu:
    fixed_img_a = fixed_img_a.cuda()
    fixed_att_a = fixed_att_a.cuda()
fixed_att_a = fixed_att_a.type(torch.float)

# Entry 0 is the unchanged attribute tensor. Entry i+1 sets column i to
# 1 - original (a flip for 0/1 labels) and then passes the result through
# check_attribute_conflict.
sample_att_b_list = [fixed_att_a]
for i in range(args.n_attrs):
    tmp = fixed_att_a.clone()
    tmp[:, i] = 1 - fixed_att_a[:, i]
    tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
    sample_att_b_list.append(tmp)

it = 0
コード例 #4
0
ファイル: mytest.py プロジェクト: Mmhmmmmm/cafe_gan_pytorch
# The attribute file given on the command line replaces the stored one.
args.attr_path = args_.attr_path

# The 'mytest' split, fed one image at a time in dataset order.
test_dataset = CelebA(args.data_path, args.attr_path, args.img_size,
                      'mytest', args.attrs)
test_dataloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
print('Testing images:', len(test_dataset))

# Results are written under output/<experiment_name>/attention_testing.
output_path = join('output', args.experiment_name, 'attention_testing')
os.makedirs(output_path, exist_ok=True)

# NOTE(review): the checkpoint path is hard-coded and resolved relative to the
# current working directory, not output/<experiment_name>/checkpoint.
attgan = AttGAN(args)
attgan.load(r'weights_unzip.17.pth')
attgan.eval()

for idx, (img_real, att_org) in enumerate(test_dataloader):
    # Move tensors to the GPU only when --gpu was requested; labels as float.
    img_real = img_real.cuda() if args.gpu else img_real
    att_org = att_org.cuda() if args.gpu else att_org
    att_org = att_org.type(torch.float)
    # img_real is 4-D: (batch, channels, dim2, dim3); the loader uses batch_size=1.
    _, mc, mw, mh = img_real.shape
    att_list = [att_org]
    # Turn the single image into a channel-last uint8 array.
    img_unit = img_real.view(3, mw, mh)
    # Map values from [-1, 1] (assumed input normalisation) to [0, 255].
    img_unit = ((img_unit * 0.5) + 0.5) * 255
    # NumPy cannot read a CUDA tensor directly, so np.uint8(img_unit) crashed
    # whenever --gpu was set; copy the data to host memory first.
    img_unit = np.uint8(img_unit.detach().cpu().numpy())
    # Reverse the channel axis (e.g. RGB -> BGR) and move it last.
    img_unit = img_unit[::-1, :, :].transpose(1, 2, 0)
    for i in range(args.n_attrs):
        tmp = att_org.clone()
コード例 #5
0
                        dest='load_epoch',
                        type=str,
                        default='latest')
    parser.add_argument('--gpu', action='store_true')
    return parser.parse_args(args)


args_ = parse()

# Restore the training-time settings; JSON objects become Namespaces.
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))

# Epoch and device come from the command line, not from setting.txt.
vars(args).update(load_epoch=args_.load_epoch, gpu=args_.gpu)

# Build the model, load the checkpoint that find_model() selects for
# args.load_epoch, then call eval() before testing.
attgan = AttGAN(args)
attgan.load(find_model(join('output', args.experiment_name, 'checkpoint'),
                       args.load_epoch))
attgan.eval()

progressbar = Progressbar()

# Preprocessing for input images: center-crop 170 px, resize to args.img_size,
# convert to a tensor, and normalise each channel with mean 0.5 / std 0.5
# (ToTensor's [0, 1] range becomes [-1, 1]).
tf = transforms.Compose([
    transforms.CenterCrop(170),
    transforms.Resize(args.img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Attribute names are the whitespace-separated tokens on the first line of the
# attribute file. The context manager closes the handle; readline() avoids
# loading the whole file just to get one line.
with open(args_.attr_path, 'r', encoding='utf-8') as f:
    att_list = f.readline().split()
コード例 #6
0
                              args.image_list_path, args.img_size, 'valid',
                              args.attrs)
# Shuffled training batches (incomplete last batch dropped); ordered, complete
# validation batches.
train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
)
valid_dataloader = data.DataLoader(
    valid_dataset,
    batch_size=args.n_samples,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
print('Training images:', len(train_dataset), '/', 'Validating images:',
      len(valid_dataset))

attgan = AttGAN(args)

# Iteration counter: 0 for a fresh run. When resuming, reload the epoch's
# weights and continue from the step of the last 'D/d_loss' scalar in the
# TensorBoard event file args.event_name.
it = 0
if args.is_resume:
    attgan.load(join('output', args.experiment_name, 'checkpoint',
                     'weights.{}.pth'.format(args.load_epoch)))
    ea = event_accumulator.EventAccumulator(
        join('output', args.experiment_name, 'summary', args.event_name))
    ea.Reload()
    d_loss = ea.scalars.Items('D/d_loss')
    # Index 1 of tensorboard's ScalarEvent(wall_time, step, value) is the step.
    it = d_loss[-1][1]

cudnn.benchmark = True
progressbar = Progressbar()
# writer = SummaryWriter(join('output', args.experiment_name, 'summary')) if not args.is_resume else \
# SummaryWriter(join('output', args.experiment_name, 'summary', args.writer_name))
コード例 #7
0
#from data import CelebA
#test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)

# Ensure the output directory for generated images exists.
os.makedirs(output_path, exist_ok=True)

# One test image per step, in a fixed order, keeping the last partial batch.
test_dataloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

# Number of images to process: the whole test set, or at most args.num_test
# of them when a limit was given.
if args.num_test is not None:
    print('Testing images:', min(len(test_dataset), args.num_test))
else:
    print('Testing images:', len(test_dataset))

# Build the model and load the checkpoint that find_model() selects for
# args.load_epoch.
# NOTE(review): attgan.eval() is not called in the lines shown here — confirm
# it happens before inference.
attgan = AttGAN(args)
attgan.load(find_model(join('output', args.experiment_name, 'checkpoint'),
                       args.load_epoch))
progressbar = Progressbar()

# Choose the target attribute vector att_b_ and its counterpart att_b_reverse.
# NOTE(review): the two optimumRandom(...) results below are overwritten
# immediately by the allRandom() assignments that follow, so only their side
# effects (if any, e.g. consuming random numbers) survive. Confirm which
# strategy is intended and delete the other.
att_b_ = optimumRandom(isMale)
att_b_reverse = optimumRandom(not isMale)

# Active choice: allRandom() and its element-wise negation.
att_b_ = allRandom()
att_b_reverse = att_b_ * -1

# Alternative strategies, currently disabled:
#att_b_ = spesificParams()
#att_b_ = singleAttArray(12, -1)

# goodArr comes from makeGoodAttr(isMale) (defined outside this excerpt).
goodArr = makeGoodAttr(isMale)
コード例 #8
0
            )  # 建立生成图像的路径 output/128_shortcut1_inject1_none/sample_testing
# Process the test set one image at a time, in dataset order.
test_dataloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

# With no --num_test limit the whole test set is translated; otherwise the
# count is capped at args.num_test.
print('Testing images:',
      min(len(test_dataset), args.num_test) if args.num_test is not None
      else len(test_dataset))

# Build the AttGAN model and load the checkpoint that find_model() picks for
# args.load_epoch from output/<experiment_name>/checkpoint
# (e.g. output/128_shortcut1_inject1_none/checkpoint/).
attgan = AttGAN(args)
attgan.load(find_model(join('output', args.experiment_name, 'checkpoint'),
                       args.load_epoch))
progressbar = Progressbar()
# Switch the model to evaluation mode before testing.
attgan.eval()
# 对图片的大循环
for idx, (img_a, att_a) in enumerate(test_dataloader):
    '''
    idx:            图像的索引
    img_a:          图像
    att_a:          标签
    原始标签        label_a