Code example #1
0
def main():
    """Sample cropped image patches from a (rewritten) GAN layer and
    report FID statistics.

    Relies on module-level globals: ganname, modelname, layernum, N,
    batch_size, and crop_size.

    Returns:
        (fake_fid, gt_fid, patch_size) — gt_fid is always 0 here; the
        ground-truth FID is not computed by this script.

    Raises:
        ValueError: if `ganname` is neither 'proggan' nor 'stylegan'.
    """
    expdir = 'results/pgw/%s/%s/layer%d' % (ganname, modelname, layernum)

    if ganname == 'proggan':
        model = setting.load_proggan(modelname).cuda()
        zds = zdataset.z_dataset_for_model(model, size=1000)
        writer = ganrewrite.ProgressiveGanRewriter
    elif ganname == 'stylegan':
        model = load_seq_stylegan(modelname, mconv='seq')
        zds = zdataset.z_dataset_for_model(model, size=1000)
        writer = ganrewrite.SeqStyleGanRewriter
    else:
        # Fail fast: previously an unknown ganname fell through and
        # crashed later with a NameError on `writer`.
        raise ValueError('unsupported ganname: %r' % (ganname,))

    model.eval()
    gw = writer(model, zds, layernum, cachedir=expdir)

    # Generate at least N patches in batches, then trim to exactly N.
    images = []
    with torch.no_grad():
        for _ in tqdm(range(N // batch_size + 1)):
            # Seeding with len(images) gives each batch distinct z samples.
            z = zdataset.z_sample_for_model(
                model, size=batch_size, seed=len(images)).cuda()
            samples = gw.sample_image_patch(z, crop_size)
            samples = [s.data.cpu() for s in samples]
            images.extend(samples)
        images = torch.stack(images[:N], dim=0)

    gt_fid = 0  # ground-truth FID not computed in this script
    fake_fid = compute_fid(images, f'{modelname}_cropped_{images.size(2)}_{ganname}')
    # Save a grid of the first 32 patches, rescaled from [-1, 1] to [0, 1].
    save_image(images[:32] * 0.5 + 0.5, f'patches_{layernum}_{ganname}_{modelname}_{crop_size}.png')

    return fake_fid, gt_fid, images.size(2)
Code example #2
0
def load_dataset(args, model=None):
    '''Loads an input dataset for testing.

    For a progan generator the dataset is a fixed set of random z
    vectors; for the places dataset it is the validation split, cropped
    to the input size the classifier expects.

    Raises:
        ValueError: if no dataset matches the given args.
    '''
    if args.model == 'progan':
        # The generator is probed with 10000 fixed-seed latent vectors.
        return zdataset.z_dataset_for_model(model, size=10000, seed=1)
    elif args.dataset in ['places']:
        # AlexNet expects 227x227 inputs; other classifiers use 224x224.
        crop_size = 227 if args.model == 'alexnet' else 224
        return setting.load_dataset(args.dataset, split='val', full=True,
                crop_size=crop_size, download=True)
    # `assert False` was used here before, but asserts are stripped
    # under -O; raise explicitly instead.
    raise ValueError('no dataset configured for model=%r dataset=%r'
                     % (args.model, args.dataset))
Code example #3
0
def load_dataset(args, model=None):
    '''Loads an input dataset for testing.

    For a progan generator the dataset is a fixed set of random z
    vectors; for the places/PACS datasets it is the validation split,
    cropped to the input size the classifier expects.

    Raises:
        ValueError: if no dataset matches the given args.
    '''
    print(args.dataset)
    if args.model == 'progan':
        # The generator is probed with 10000 fixed-seed latent vectors.
        return zdataset.z_dataset_for_model(model, size=10000, seed=1)
    elif args.dataset in ['places', 'pacs-p', 'pacs-a', 'pacs-c', 'pacs-s']:
        # Input crop size depends on the classifier architecture.
        if args.model == 'alexnet':
            crop_size = 227
        elif args.model == 'vgg16':
            crop_size = 224
        else:
            crop_size = 222  # for resnet18
        return setting.load_dataset(args.dataset, split='val', full=True,
                crop_size=crop_size, download=True)
    # `assert False` was used here before, but asserts are stripped
    # under -O; raise explicitly instead.
    raise ValueError('no dataset configured for model=%r dataset=%r'
                     % (args.model, args.dataset))
Code example #4
0
def tally_generated_objects(model, size=10000):
    """Segment `size` generated images and tally per-image object areas.

    Args:
        model: a GAN generator mapping latent z batches to image batches.
        size: number of images to generate and segment.

    Returns:
        numpy array of shape (size, NUM_OBJECTS) where entry [i, c] is
        the fraction of pixels in generated image i labeled with object
        class c.
    """
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # Return value unused here, but the call is kept in case it has
    # initialization side effects in the segmenter.
    labelnames, catnames = upp.get_label_and_category_names()
    # numpy.float was removed in NumPy 1.20; the builtin float is the
    # documented equivalent (float64).
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            for i in range(len(zbatch)):
                # Per-class pixel count, normalized by the image area.
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # Only the first len(zbatch) rows are valid for a partial
            # final batch, and only those rows are copied out.
            result[batch_index:batch_index +
                   len(zbatch)] = (batch_result.cpu().numpy())
            batch_index += len(zbatch)
    return result
Code example #5
0
def main():
    """Command-line entry point for the net dissection utility.

    Parses arguments, instruments the requested model, loads either a
    segmentation dataset (classifier case) or a random-z dataset
    (generator case), runs `dissect`, and marks the output directory
    done so the job is not repeated.
    """
    # Training settings
    def strpair(arg):
        """Parse 'a:b' into ('a', 'b'); a bare 'a' becomes ('a', 'a')."""
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p
    def intpair(arg):
        """Parse 'w,h' into (w, h); a bare 'n' becomes (n, n)."""
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(description='Net dissect utility',
            prog='python -m netdissect',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--unstrict', action='store_true', default=False,
                        help='ignore unexpected pth parameters')
    parser.add_argument('--submodule', type=str, default=None,
                        help='submodule to load from pthfile')
    parser.add_argument('--outdir', type=str, default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to dissect' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--segments', type=str, default='dataset/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--segmenter', type=str, default=None,
                        help='constructor for asegmenter class')
    parser.add_argument('--download', action='store_true', default=False,
                        help='downloads Broden dataset if needed')
    parser.add_argument('--imagedir', type=str, default=None,
                        help='directory containing image-only dataset')
    parser.add_argument('--imgsize', type=intpair, default=(227, 227),
                        help='input image size to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--meta', type=str, nargs='+',
                        help='json files of metadata to add to report')
    parser.add_argument('--merge', type=str,
                        help='json file of unit data to merge in report')
    parser.add_argument('--examples', type=int, default=20,
                        help='number of image examples per unit')
    parser.add_argument('--size', type=int, default=10000,
                        help='dataset subset size to use')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='batch size for forward pass')
    parser.add_argument('--num_workers', type=int, default=24,
                        help='number of DataLoader workers')
    # strfloat / FloatRange are presumably defined elsewhere in this module.
    parser.add_argument('--quantile_threshold', type=strfloat, default=None,
                        choices=[FloatRange(0.0, 1.0), 'iqr'],
                        help='quantile to use for masks')
    parser.add_argument('--no-labels', action='store_true', default=False,
                        help='disables labeling of units')
    parser.add_argument('--maxiou', action='store_true', default=False,
                        help='enables maxiou calculation')
    parser.add_argument('--covariance', action='store_true', default=False,
                        help='enables covariance calculation')
    parser.add_argument('--rank_all_labels', action='store_true', default=False,
                        help='include low-information labels in rankings')
    parser.add_argument('--no-images', action='store_true', default=False,
                        help='disables generation of unit images')
    parser.add_argument('--no-report', action='store_true', default=False,
                        help='disables generation report summary')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--gen', action='store_true', default=False,
                        help='test a generator model (e.g., a GAN)')
    parser.add_argument('--gan', action='store_true', default=False,
                        help='synonym for --gen')
    parser.add_argument('--perturbation', default=None,
                        help='filename of perturbation attack to apply')
    # NOTE(review): store_true with default=None (not False) — the None
    # default may be used downstream to mean "unspecified"; confirm.
    parser.add_argument('--add_scale_offset', action='store_true', default=None,
                        help='offsets masks according to stride and padding')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    parser.add_argument('--p2p', action='store_true', default=False,
                        help='for running pix2pix (input segments)')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Derive positive flags from the --no-* negations.
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels
    if args.gan:
        args.gen = args.gan

    # Set up console output
    verbose_progress(not args.quiet)

    # Exit right away if job is already done or being done.
    if args.outdir is not None:
        exit_if_job_done(args.outdir)

    # Speed up pytorch
    # NOTE(review): also set again below under "Set up CUDA"; harmless
    # but redundant.
    torch.backends.cudnn.benchmark = True

    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution, 1)
        from netdissect.segmenter import ensure_upp_segmenter_downloaded
        ensure_upp_segmenter_downloaded('dataset/segmodel')
        sys.exit(0)

    # Help if broden is not present
    if not args.gen and not args.imagedir and not os.path.isdir(args.segments):
        print_progress('Segmentation dataset not found at %s.' % args.segments)
        print_progress('Specify dataset directory using --segments [DIR]')
        print_progress('To download Broden, run: netdissect --download')
        sys.exit(1)

    # Default segmenter class
    if args.gen and args.segmenter is None:
        if args.p2p:
            args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                              "segsizes=[256], segdiv=None)")
        else:
            args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                              "segsizes=[256], segdiv='quad')")

    # Default threshold
    if args.quantile_threshold is None:
        if args.gen:
            args.quantile_threshold = 'iqr'
        else:
            args.quantile_threshold = 0.005

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Construct the network with specified layers instrumented
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)
    model = create_instrumented_model(args)

    # Update any metadata from files, if any
    meta = getattr(model, 'meta', {})
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))

    # Load any merge data from files
    mergedata = None
    if args.merge:
        with open(args.merge) as f:
            mergedata = json.load(f)

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        exit_if_job_done(args.outdir)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    if not args.gen:
        # Load dataset for classifier case.
        # Load perturbation
        perturbation = numpy.load(args.perturbation
                ) if args.perturbation else None
        segrunner = None

        # Load broden dataset
        if args.imagedir is not None:
            dataset = try_to_load_images(args.imagedir, args.imgsize,
                    perturbation, args.size)
            segrunner = ImageOnlySegRunner(dataset)
        else:
            # Fall through the loaders until one succeeds.
            dataset = try_to_load_broden(args.segments, args.imgsize, 1,
                perturbation, args.download, args.size)
        if dataset is None:
            dataset = try_to_load_multiseg(args.segments, args.imgsize,
                    perturbation, args.size)
        if dataset is None:
            # NOTE(review): '%s' here is passed as a second argument, not
            # %-formatted — verify print_progress supports printf-style
            # args, otherwise the path is never interpolated.
            print_progress('No segmentation dataset found in %s',
                    args.segments)
            print_progress('use --download to download Broden.')
            sys.exit(1)
    else:
        if not args.p2p:
            # For segmenter case the dataset is just a random z
            dataset = z_dataset_for_model(model, args.size)
            train_dataset = z_dataset_for_model(model, args.size, seed=2)
        else:
            dataset = get_segments_dataset('dataset/Adissect')
        segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter))

    # Run dissect
    dissect(args.outdir, model, dataset,
            train_dataset=train_dataset,
            segrunner=segrunner,
            examples_per_unit=args.examples,
            netname=args.netname,
            quantile_threshold=args.quantile_threshold,
            meta=meta,
            merge=mergedata,
            make_images=args.images,
            make_labels=args.labels,
            make_maxiou=args.maxiou,
            make_covariance=args.covariance,
            make_report=args.report,
            make_row_images=args.images,
            make_single_images=True,
            rank_all_labels=args.rank_all_labels,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            settings=vars(args))

    # Mark the directory so that it's not done again.
    mark_job_done(args.outdir)
Code example #6
0
def main():
    """Generate images from a GAN and save per-image segmentations.

    For each generated image this writes <index>_img.jpg, the raw
    segmentation as <index>_seg.mat, and a <index>_seg.png
    visualization into --outdir, plus a labels.txt listing and a copy
    of the lightbox viewer page.
    """
    parser = argparse.ArgumentParser(
        description='GAN output segmentation util')
    parser.add_argument('--model',
                        type=str,
                        default='netdissect.proggan.from_pth_file("' +
                        'models/karras/churchoutdoor_lsun.pth")',
                        help='constructor for the model to test')
    parser.add_argument('--outdir',
                        type=str,
                        default='images',
                        help='directory for image output')
    parser.add_argument('--size',
                        type=int,
                        default=100,
                        help='number of images to output')
    # NOTE(review): --seed is parsed but never used below — confirm
    # whether z_dataset_for_model should receive it.
    parser.add_argument('--seed', type=int, default=1, help='seed')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model from its constructor expression.
    model = autoimport_eval(args.model)

    # Make the standard z
    z_dataset = z_dataset_for_model(model, size=args.size)

    # Make the segmenter
    segmenter = UnifiedParsingSegmenter()

    # Ensure the output directory exists before writing anything into it
    # (the default 'images' directory may not exist yet).
    os.makedirs(args.outdir, exist_ok=True)

    # Write out text labels
    labels, cats = segmenter.get_label_and_category_names()
    with open(os.path.join(args.outdir, 'labels.txt'), 'w') as f:
        for label, cat in labels:
            f.write('%s %s\n' % (label, cat))

    # Move models to cuda
    model.cuda()

    batch_size = 10
    progress = default_progress()
    dirname = args.outdir

    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(z_dataset,
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        for batch_num, [z
                        ] in enumerate(progress(z_loader,
                                                desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            tensor_im = model(z)
            # Rescale generator output from [-1, 1] to uint8 HWC images.
            byte_im = ((tensor_im + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            seg = segmenter.segment_batch(tensor_im)
            for i in range(len(tensor_im)):
                index = i + start_index
                filename = os.path.join(dirname, '%d_img.jpg' % index)
                Image.fromarray(byte_im[i].numpy()).save(filename,
                                                         optimize=True,
                                                         quality=100)
                filename = os.path.join(dirname, '%d_seg.mat' % index)
                savemat(filename, dict(seg=seg[i].cpu().numpy()))
                filename = os.path.join(dirname, '%d_seg.png' % index)
                Image.fromarray(
                    segment_visualization(seg[i].cpu().numpy(),
                                          tensor_im.shape[2:])).save(filename)
    # Copy the viewer page that lives alongside this script.
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
                os.path.join(dirname, '+lightbox.html'))