예제 #1
0
def get_model_and_loader():
    ''' Build the saccader model, its data loader and a grapher.

    Everything is configured from the module-level `args` namespace.

    Returns a list: [saccader, loader, grapher].
    '''

    def _upsample(x):
        # Convert to a tensor, add a leading dim for F.interpolate,
        # bilinearly resize to a square of side synthetic_upsample_size
        # and drop the leading dim again.
        batched = torchvision.transforms.ToTensor()(x).unsqueeze(0)
        target = (args.synthetic_upsample_size, args.synthetic_upsample_size)
        resized = F.interpolate(batched,
                                size=target,
                                mode='bilinear',
                                align_corners=True)
        return resized.squeeze(0)

    # the auxiliary transform is only used for the multi_image_folder task
    wants_upsample = (args.synthetic_upsample_size > 0
                      and args.task == "multi_image_folder")
    aux_transform = _upsample if wants_upsample else None

    loader = get_loader(args,
                        transform=None,
                        sequentially_merge_test=False,
                        aux_transform=aux_transform,
                        postfix="_large",
                        **vars(args))

    # record the image shape on the config before it is handed to the models
    args.img_shp = loader.img_shp
    vae = VRNN(loader.img_shp,
               n_layers=2,           # XXX: hard coded
               bidirectional=False,  # XXX: hard coded
               kwargs=vars(args))

    # wrap the VAE in the saccader, then lazily generate the modules
    # that were not constructed up front
    saccader = Saccader(vae, loader.output_size, kwargs=vars(args))
    lazy_generate_modules(saccader, loader.train_loader)

    # optional FP16 / CUDA / multi-GPU handling
    if args.half is True:
        saccader = saccader.fp16()

    if args.cuda is True:
        saccader = saccader.cuda()

    if args.ngpu > 1:
        # NOTE(review): the return value is intentionally not re-assigned,
        # matching prior behaviour; presumably parallel() works in place.
        saccader.parallel()

    # tensorboard is the fallback when no visdom url is configured
    if args.visdom_url is None:
        grapher = Grapher('tensorboard', comment=saccader.get_name())
    else:
        grapher = Grapher('visdom',
                          env=saccader.get_name(),
                          server=args.visdom_url,
                          port=args.visdom_port)

    # log the pretty-printed model config as text at step 0
    config_text = pprint.PrettyPrinter(indent=4).pformat(saccader.config)
    grapher.add_text('config', config_text, 0)

    return [saccader, loader, grapher]
예제 #2
0
def extract_patches_2D(img, size):
    ''' Cut a 4D tensor into patches and stack them along dim 0.

    The patch height / width are `size[0]` / `size[1]`, clamped to the
    extent of dims 2 / 3 of `img`. Patches are taken with a stride equal
    to the patch size; when a dim is not an exact multiple of the patch
    size, one extra patch aligned to the far edge of that dim is appended
    (it overlaps its neighbour).

    :param img: 4D tensor; dims 2 and 3 are the ones being tiled
    :param size: indexable pair (patch_height, patch_width)
    :returns: tensor of shape (-1, img.size(1), patch_H, patch_W), ordered
              by dim 0 of `img`, then patch row, then patch column
    '''
    patch_H = min(img.size(2), size[0])
    patch_W = min(img.size(3), size[1])

    # tile along dim 2; the window contents land in a new trailing dim
    patches_fold_H = img.unfold(2, patch_H, patch_H)
    if img.size(2) % patch_H != 0:
        # add the far-edge strip that the strided unfold skipped
        edge_rows = img[:, :, -patch_H:, ].permute(0, 1, 3, 2).unsqueeze(2)
        patches_fold_H = torch.cat((patches_fold_H, edge_rows), dim=2)

    # tile along dim 3. This must run unconditionally: it used to sit inside
    # the branch above and was skipped for evenly divisible heights.
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, patch_W)
    if img.size(3) % patch_W != 0:
        edge_cols = patches_fold_H[:, :, :, -patch_W:, :].permute(
            0, 1, 2, 4, 3).unsqueeze(3)
        patches_fold_HW = torch.cat((patches_fold_HW, edge_cols), dim=3)

    # Likewise unconditional: the result used to be built only when the
    # width had a remainder, leaving nothing to return otherwise.
    return patches_fold_HW.permute(0, 2, 3, 1, 4, 5).reshape(
        -1, img.size(1), patch_H, patch_W)


def get_model_and_loader():
    ''' Build the model, its data loader and a grapher.

    Everything is configured from the module-level `args` namespace.

    Returns a list: [model, loader, grapher].
    '''
    aux_transform = None
    if args.synthetic_upsample_size > 0:  #and args.task == "multi_image_folder":
        to_pil = torchvision.transforms.ToPILImage()
        to_tensor = torchvision.transforms.ToTensor()
        resizer = torchvision.transforms.Resize(
            size=(args.synthetic_upsample_size, args.synthetic_upsample_size),
            interpolation=2)

        def patch_extractor_lambda(crop):
            # extract_patches_2D indexes dims 2 and 3, so add a leading dim
            # when the input has fewer than 4 dims
            crop = crop.unsqueeze(0) if len(crop.shape) < 4 else crop
            return extract_patches_2D(crop, [224, 224])  # XXX: hard coded

        # resize to the synthetic upsample size, then cut into patches
        aux_transform = lambda x: patch_extractor_lambda(
            to_tensor(resizer(to_pil(to_tensor(x)))))

    loader = get_loader(args,
                        transform=None,
                        sequentially_merge_test=False,
                        aux_transform=aux_transform,
                        postfix="_large",
                        **vars(args))

    # append the image shape to the config & build the model
    args.img_shp = loader.img_shp
    model = MultiBatchModule(loader.output_size, checkpoint=args.checkpoint)

    # FP16-ize, cuda-ize and parallelize (if requested)
    model = model.half() if args.half is True else model
    model = model.cuda() if args.cuda is True else model
    model = nn.DataParallel(model) if args.ngpu > 1 else model

    # build the grapher object (tensorboard or visdom)
    # and log the pretty-printed config as text at step 0
    if args.visdom_url is not None:
        grapher = Grapher('visdom',
                          env=get_name(),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name())

    grapher.add_text('config',
                     pprint.PrettyPrinter(indent=4).pformat(vars(args)), 0)
    return [model, loader, grapher]