from data import CelebA_HQ test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path, args.img_size, 'test', args.attrs) os.makedirs(output_path, exist_ok=True) test_dataloader = data.DataLoader( test_dataset, batch_size=1, num_workers=args.num_workers, shuffle=False, drop_last=False ) if args.num_test is None: print('Testing images:', len(test_dataset)) else: print('Testing images:', min(len(test_dataset), args.num_test)) attgan = AttGAN(args) attgan.load(find_model(join('output', args.experiment_name, 'checkpoint'), args.load_epoch)) progressbar = Progressbar() attgan.eval() for idx, (img_a, att_a) in enumerate(test_dataloader): if args.num_test is not None and idx == args.num_test: break img_a = img_a.cuda() if args.gpu else img_a att_a = att_a.cuda() if args.gpu else att_a att_a = att_a.type(torch.float) att_b_list = [att_a] if args.by_levels: for i in range(args.n_attrs):
args.attrs) train_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, drop_last=True) valid_dataloader = data.DataLoader(valid_dataset, batch_size=args.n_samples, num_workers=args.num_workers, shuffle=False, drop_last=False) print('Training images:', len(train_dataset), '/', 'Validating images:', len(valid_dataset)) attgan = AttGAN(args) progressbar = Progressbar() writer = SummaryWriter(join('output', args.experiment_name, 'summary')) fixed_img_a, fixed_att_a = next(iter(valid_dataloader)) fixed_img_a = fixed_img_a.cuda() if args.gpu else fixed_img_a fixed_att_a = fixed_att_a.cuda() if args.gpu else fixed_att_a fixed_att_a = fixed_att_a.type(torch.float) sample_att_b_list = [fixed_att_a] for i in range(args.n_attrs): tmp = fixed_att_a.clone() tmp[:, i] = 1 - tmp[:, i] tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs) sample_att_b_list.append(tmp) it = 0 it_per_epoch = len(train_dataset) // args.batch_size