def compute_attribute_consistency(g, sub_g, n_sample, batch_size):
    """Measure how often two generators agree on predicted attributes.

    Both generators are called with identical keyword arguments (latents,
    noise, truncation settings). The pretrained attribute predictor then
    classifies both outputs, and the fraction of samples on which the two
    predictions agree is accumulated for each attribute.

    Args:
        g: reference generator; also supplies the noise tensors and the mean
            style used for truncation.
        sub_g: generator under comparison, called with the same kwargs as ``g``.
        n_sample: total number of samples to evaluate across all Horovod
            workers.
        batch_size: number of samples generated per step on this worker.

    Returns:
        The output of ``hvd.allreduce`` applied to this worker's mean
        per-attribute agreement (one entry per attribute; the predictor output
        is viewed as 40 attributes with 2 scores each).
    """
    attr_pred = models.get_pretrained('attribute-predictor').to(device)
    attr_pred.eval()

    # Each worker processes its share of the batches, rounded up.
    n_batch = math.ceil(n_sample * 1. / batch_size / hvd.size())

    accs = 0
    mean_style = g.mean_style(10000)
    with torch.no_grad():
        for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
            noise = g.make_noise()

            # Use the batch_size argument (not the global CLI args) so the
            # number of generated samples matches the n_batch computed above.
            latent = torch.randn(batch_size, 1, 512, device=device)
            kwargs = {
                'styles': latent,
                'truncation': 0.5,
                'truncation_style': mean_style,
                'noise': noise
            }
            # Same inputs for both generators so only the models differ.
            img = g(**kwargs)[0].clamp(min=-1., max=1.)
            sub_img = sub_g(**kwargs)[0].clamp(min=-1., max=1.)
            img = adaptive_resize(img, 256)
            sub_img = adaptive_resize(sub_img, 256)

            # Reshape to (batch, 40 attributes, 2 scores) and take the
            # higher-scoring class for every attribute.
            attr = attr_pred(img).view(-1, 40, 2)
            sub_attr = attr_pred(sub_img).view(-1, 40, 2)

            attr = torch.argmax(attr, dim=2)
            sub_attr = torch.argmax(sub_attr, dim=2)
            # Per-attribute agreement rate over the batch dimension.
            this_acc = (attr == sub_attr).float().mean(0)
            # Running mean over this worker's batches.
            accs = accs + this_acc / n_batch
    # Combine the per-worker results.
    accs = hvd.allreduce(accs)
    return accs
Exemplo n.º 2
0
    # name of the pretrained generator configuration to load
    parser.add_argument("--config",
                        type=str,
                        help='config name of the pretrained generator')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--n_sample', type=int, default=50000)
    # mandatory option: with required=True the None default is never used
    parser.add_argument('--inception', type=str, default=None, required=True)

    # optional sub-generator settings; both default to None (disabled)
    parser.add_argument('--channel_ratio', type=float, default=None)
    parser.add_argument('--target_res', type=int, default=None)

    args = parser.parse_args()

    # start Horovod and select the GPU indexed by this process's local rank
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    generator = models.get_pretrained('generator', args.config).to(device)
    generator.eval()

    # set sub-generator
    # (a falsy channel_ratio, i.e. None or 0, leaves the generator untouched)
    if args.channel_ratio:
        # imported here so the module is only needed when a ratio is requested
        from models.dynamic_channel import set_uniform_channel_ratio, CHANNEL_CONFIGS

        # only ratios listed in CHANNEL_CONFIGS are accepted
        assert args.channel_ratio in CHANNEL_CONFIGS
        set_uniform_channel_ratio(generator, args.channel_ratio)

    # override the generator's target resolution only when one was given
    if args.target_res is not None:
        generator.target_res = args.target_res

    # compute the flops of the generator (if possible)
    if hvd.rank() == 0:
        try:
Exemplo n.º 3
0
                        help="batch size for the models (per gpu)")
    # number of worker processes (argument is named for data loading)
    parser.add_argument('-j', '--workers', default=4, type=int)

    # network used for LPIPS; restricted to the two listed choices
    parser.add_argument("--lpips_net",
                        type=str,
                        default='vgg',
                        choices=['vgg', 'alex'])

    # opt-in flag: off unless --calc_consist is passed
    parser.add_argument('--calc_consist',
                        action='store_true',
                        default=False,
                        help='compute the consistency loss')
    args = parser.parse_args()

    # build models
    # generator and encoder are both loaded from the same config and put in
    # eval mode
    generator = models.get_pretrained('generator', args.config).to(device)
    generator.eval()

    encoder = models.get_pretrained('encoder', args.config).to(device)
    encoder.eval()

    # build test dataset
    # images are resized to the generator's resolution, converted to tensors
    # and normalized per channel with mean 0.5 / std 0.5 (in place)
    val_transform = transforms.Compose([
        transforms.Resize(generator.resolution),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    test_dataset = NativeDataset(args.data_path, transform=val_transform)
    # the split helper returns three index sets; only the test indices are
    # used here
    train_idx, val_idx, test_idx = get_celeba_hq_split()
    test_dataset = torch.utils.data.Subset(test_dataset, test_idx)
    # size passed to transforms.Resize below
    parser.add_argument('--resolution', type=int, default=256)
    parser.add_argument('--batch_size',
                        default=64,
                        type=int,
                        help='batch size')
    parser.add_argument('--n_sample', type=int, default=50000)
    # enables random horizontal flipping in the transform below
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('-j', '--workers', default=32, type=int)
    parser.add_argument('--save_name', default=None, type=str)
    # positional argument: dataset location
    parser.add_argument('path',
                        metavar='PATH',
                        help='path to dataset (image version)')

    args = parser.parse_args()

    # pretrained inception network in eval mode, wrapped in DataParallel
    inception = models.get_pretrained('inception').to(device).eval()
    inception = nn.DataParallel(inception)

    # resize, optionally flip (probability 0.5 with --flip, otherwise 0),
    # convert to tensor and normalize per channel with mean 0.5 / std 0.5
    transform = transforms.Compose([
        transforms.Resize(args.resolution),
        transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = NativeDataset(args.path, transform=transform)
    # batches are drawn in shuffled order
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=args.workers,
                        shuffle=True)
Exemplo n.º 5
0
def get_style_attribute_pairs():
    """Sample styles, predict their image attributes, and save both to disk.

    Written with Horovod to accelerate the extraction (by n_gpu times): each
    worker generates its own share of random samples with a rank-dependent
    seed, and the results are gathered across workers at the end.

    Reads the module-level names ``config`` (pretrained generator config),
    ``space`` (one of ``'w'``, ``'w+'``, ``'z'``) and ``device``.

    Side effects:
        On rank 0, prints the gathered shapes and writes
        ``attributes_<config>.pt`` and ``styles_<config>.pt`` to the current
        working directory.
    """
    import horovod.torch as hvd
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    # A different seed per rank so workers do not draw identical samples.
    torch.manual_seed(hvd.rank() * 999 + 1)
    if hvd.rank() == 0:
        print(' * Extracting style-attribute pairs...')
    # build and load the pre-trained attribute predictor on CelebA-HQ
    predictor = models.get_pretrained('attribute-predictor').to(device)
    # build and load the pre-trained anycost generator
    generator = models.get_pretrained('generator', config).to(device)

    predictor.eval()
    generator.eval()

    # randomly generate images and feed them to the predictor
    # configs from https://github.com/genforce/interfacegan
    randomized_noise = False
    truncation_psi = 0.7
    batch_size = 16
    # 500000 samples in total, split evenly across workers (floor division).
    n_batch = 500000 // (batch_size * hvd.size())

    styles = []
    attributes = []

    # Pure inference: disable autograd so the tensors collected below do not
    # keep a computation graph alive for every batch.
    with torch.no_grad():
        mean_style = generator.mean_style(100000).view(1, 1, -1)
        assert space in ['w', 'w+', 'z']
        for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
            if space in ['w', 'z']:
                # one latent per sample
                z = torch.randn(batch_size,
                                1,
                                generator.style_dim,
                                device=device)
            else:
                # 'w+': an independent latent for each of the n_style slots
                z = torch.randn(batch_size,
                                generator.n_style,
                                generator.style_dim,
                                device=device)
            images, w = generator(z,
                                  return_styles=True,
                                  truncation=truncation_psi,
                                  truncation_style=mean_style,
                                  input_is_style=False,
                                  randomize_noise=randomized_noise)
            # Clamp, then resize to 256 before running the predictor.
            images = F.interpolate(images.clamp(-1, 1),
                                   size=256,
                                   mode='bilinear',
                                   align_corners=True)
            attr = predictor(images)
            # move to cpu to save memory
            if space == 'w+':
                styles.append(w.to('cpu'))
            elif space == 'w':
                styles.append(w.mean(
                    1, keepdim=True).to('cpu'))  # originally duplicated
            else:
                styles.append(z.to('cpu'))
            attributes.append(attr.to('cpu'))

    styles = torch.cat(styles, dim=0)
    attributes = torch.cat(attributes, dim=0)

    # Collect every worker's results; only rank 0 writes them out.
    styles = hvd.allgather(styles, name='styles')
    attributes = hvd.allgather(attributes, name='attributes')
    if hvd.rank() == 0:
        print(styles.shape, attributes.shape)
        torch.save(attributes, 'attributes_{}.pt'.format(config))
        torch.save(styles, 'styles_{}.pt'.format(config))
    parser.add_argument("--n_sample",
                        type=int,
                        default=10000,
                        help="number of the samples for calculating PPL")
    parser.add_argument("--batch_size",
                        type=int,
                        default=16,
                        help="batch size for the models (per gpu)")

    args = parser.parse_args()

    # start Horovod and select the GPU indexed by this process's local rank
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    # reference generator, left unmodified
    generator = models.get_pretrained('generator',
                                      args.config).to(device).eval()

    # second copy loaded from the same config; only this one is altered by the
    # channel-ratio / target-resolution options below
    sub_generator = models.get_pretrained('generator',
                                          args.config).to(device).eval()
    if args.channel_ratio:
        from models.dynamic_channel import set_uniform_channel_ratio
        set_uniform_channel_ratio(sub_generator, args.channel_ratio)

    if args.target_res is not None:
        sub_generator.target_res = args.target_res

    # compare the two generators
    acc_list = compute_attribute_consistency(generator,
                                             sub_generator,
                                             n_sample=args.n_sample,
                                             batch_size=args.batch_size)
    # move the result to the CPU and turn it into a plain Python list
    acc_list = list(acc_list.to('cpu').numpy())