Code example #1
def main(args):
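    # Load a trained VQ-CVAE, optionally zero out codebook dimensions and
    # apply a colour manipulation to the inputs, then export reconstructions
    # of the validation set.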
    args = parse_arguments(args)
    weights_rgb = torch.load(args.model_path, map_location='cpu')
    network = vqmodel.VQ_CVAE(128, k=args.k, kl=args.kl, in_chns=3,
                              cos_distance=args.cos_dis)
    network.load_state_dict(weights_rgb)
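    # args.exclude controls zeroing of codebook embedding dimensions: a
    # positive value zeroes dimension (exclude - 1) alone; a negative value
    # keeps only dimension (|exclude| - 1) and zeroes the rest (the code
    # assumes an 8-dimensional embedding). Tensors returned by state_dict()
    # share storage with the model's parameters, so these in-place writes
    # modify the network directly.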
    if args.exclude > 0:
        which_vec = [args.exclude - 1]
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    elif args.exclude < 0:
        which_vec = [*range(8)]
        which_vec.remove(abs(args.exclude) - 1)
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    network.cuda()
    network.eval()

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    intransform_funs = []
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
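    # With mean = std = 0.5, Normalize maps [0, 1] inputs to [-1, 1].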
    manipulation_func = []
    args.suffix = ''
    if args.manipulation is not None:
        man_funcs = {
            'contrast': imutils.adjust_contrast,
            'gamma': imutils.adjust_gamma,
            'luminance': imutils.reduce_lightness,
            'chromaticity': imutils.reduce_chromaticity,
            'red_green': imutils.reduce_red_green,
            'yellow_blue': imutils.reduce_yellow_blue,
            'illuminant': imutils.adjust_illuminant,
        }
        man_func = man_funcs.get(args.manipulation[0])
        if man_func is None:
            sys.exit('Unsupported function %s' % args.manipulation[0])
        if args.manipulation[0] == 'illuminant':
            parameters = {'illuminant': [
                float(args.manipulation[1]),
                float(args.manipulation[2]),
                float(args.manipulation[3])
            ]}
        else:
            parameters = {'amount': float(args.manipulation[1])}
        args.suffix = '_' + args.manipulation[0] + '_' + ''.join(
            args.manipulation[1:]
        )
        manipulation_func.append(cv2_preprocessing.UniqueTransformation(
            man_func, **parameters
        ))
        intransform_funs.extend(manipulation_func)

    args.in_colour_space = args.colour_space[:3]
    args.out_colour_space = args.colour_space[4:7]
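    # args.colour_space presumably encodes both spaces as '<in>2<out>'
    # (e.g. 'rgb2lab'): the first three characters name the input space and
    # the characters after the separator name the output space.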

    if args.in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.in_colour_space)
        )
    intransform = transforms.Compose(intransform_funs)
    transform_funcs = transforms.Compose([
        # cv2_transforms.Resize(256), cv2_transforms.CenterCrop(224),
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std)
    ])

    if args.dataset == 'imagenet':
        test_loader = torch.utils.data.DataLoader(
            data_loaders.ImageFolder(
                root=args.validation_dir,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs
            ),
            batch_size=args.batch_size, shuffle=False
        )
    elif args.dataset == 'celeba':
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CelebA(
                root=args.validation_dir,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs,
                split='test'
            ),
            batch_size=args.batch_size, shuffle=False
        )
    else:
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                # FIXME
                category=args.category,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs
            ),
            batch_size=args.batch_size, shuffle=False
        )
    export(test_loader, network, mean, std, args)
Code example #2
def main(args):
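    # Same export pipeline as example #1, but the model is a U-Net from
    # segmentation_models rather than a VQ-CVAE, and the inputs can
    # optionally be corrupted with noise.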
    args = parse_arguments(args)
    weights_net = torch.load(args.model_path, map_location='cpu')

    args.in_colour_space = args.colour_space[:3]
    args.out_colour_space = args.colour_space[4:]

    args.outs_dict = dict()
    args.outs_dict[args.out_colour_space] = {'shape': [1, 1, 3]}
    from segmentation_models import unet
    network = unet.model.Unet(in_channels=3,
                              encoder_weights=None,
                              outs_dict=args.outs_dict,
                              classes=3)

    network.load_state_dict(weights_net)
    if args.exclude > 0:
        which_vec = [args.exclude - 1]
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    elif args.exclude < 0:
        which_vec = [*range(8)]
        which_vec.remove(abs(args.exclude) - 1)
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    network.cuda()
    network.eval()

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    transform_funcs = transforms.Compose([
        cv2_transforms.Resize(args.target_size + 32),
        cv2_transforms.CenterCrop(args.target_size),
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std)
    ])

    intransform_funs = []
    if args.in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.in_colour_space))
    if args.corr_noise:
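        # imutils.s_p_noise presumably adds salt-and-pepper noise of the
        # given amount, applied equally across channels (eq_chns=True).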
        parameters = dict()
        parameters['function'] = imutils.s_p_noise
        parameters['kwargs'] = {'amount': args.corr_noise, 'eq_chns': True}
        intransform_funs.append(
            cv2_preprocessing.PredictionTransformation(parameters))
    intransform = transforms.Compose(intransform_funs)

    if args.dataset == 'imagenet':
        test_loader = torch.utils.data.DataLoader(
            data_loaders.ImageFolder(
                root=args.validation_dir,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs
            ),
            batch_size=args.batch_size, shuffle=False
        )
    elif args.dataset == 'celeba':
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CelebA(
                root=args.validation_dir,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs,
                split='test'
            ),
            batch_size=args.batch_size, shuffle=False
        )
    else:
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                # FIXME
                category=args.category,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs),
            batch_size=args.batch_size,
            shuffle=False)
    export(test_loader, network, mean, std, args)
Code example #3
def main(args):
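    # Training entry point: configures seeding and devices, builds the model,
    # transforms and data loaders, then runs the train/test loop with
    # TensorBoard logging and per-epoch checkpoints.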
    args = parse_arguments(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.random_seed is not None:
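        # Seed every source of randomness and switch cuDNN to its
        # deterministic mode so runs are reproducible.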
        os.environ['PYTHONHASHSEED'] = str(args.random_seed)
        torch.manual_seed(args.random_seed)
        # manual_seed_all already covers the current device
        torch.cuda.manual_seed_all(args.random_seed)
        np.random.seed(args.random_seed)
        random.seed(args.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if args.cuda:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
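        # NB: enabling the cuDNN autotuner below can undo the determinism
        # requested via --random_seed above.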
        cudnn.benchmark = True

    args.mean = (0.5, 0.5, 0.5)
    args.std = (0.5, 0.5, 0.5)
    if 'labhue' in args.colour_space:
        args.mean = (0.5, 0.5, 0.5, 0.5)
        args.std = (0.5, 0.5, 0.5, 0.5)
    normalise = transforms.Normalize(args.mean, args.std)

    lr = args.lr or default_hyperparams[args.dataset]['lr']
    k = args.k or default_hyperparams[args.dataset]['k']
    hidden = args.hidden or default_hyperparams[args.dataset]['hidden']
    num_channels = args.num_channels or dataset_n_channels[args.dataset]
    target_size = args.target_size or dataset_target_size[args.dataset]

    dataset_transforms = {
        'custom': transforms.Compose(
            [transforms.Resize(256), transforms.CenterCrop(224),
             transforms.ToTensor(), normalise]),
        'coco': transforms.Compose([normalise]),
        'voc': transforms.Compose(
            [cv2_transforms.RandomResizedCropSegmentation(target_size,
                                                          scale=(0.50, 1.0)),
             cv2_transforms.ToTensorSegmentation(),
             cv2_transforms.NormalizeSegmentation(args.mean, args.std)]),
        'bsds': transforms.Compose(
            [cv2_transforms.RandomResizedCropSegmentation(target_size,
                                                          scale=(0.50, 1.0)),
             cv2_transforms.ToTensorSegmentation(),
             cv2_transforms.NormalizeSegmentation(args.mean, args.std)]),
        'imagenet': transforms.Compose(
            [cv2_transforms.Resize(target_size + 32),
             cv2_transforms.CenterCrop(target_size),
             cv2_transforms.ToTensor(),
             cv2_transforms.Normalize(args.mean, args.std)]),
        'celeba': transforms.Compose(
            [cv2_transforms.Resize(target_size + 32),
             cv2_transforms.CenterCrop(target_size),
             cv2_transforms.ToTensor(),
             cv2_transforms.Normalize(args.mean, args.std)]),
        'cifar10': transforms.Compose(
            [transforms.ToTensor(), normalise]),
        'mnist': transforms.ToTensor()
    }

    save_path = vae_util.setup_logging_from_args(args)
    writer = SummaryWriter(save_path)

    in_colour_space = args.colour_space[:3]
    out_colour_space = args.colour_space[4:]
    args.colour_space = out_colour_space

    if args.model == 'wavenet':
        # model = wavenet_vae.wavenet_bottleneck(
        #     latent_dim=k, in_channels=num_channels
        # )
        task = None
        out_chns = 3
        if 'voc' in args.dataset:
            task = 'segmentation'
            out_chns = 21
        from torchvision.models import resnet
        backbone = resnet.__dict__['resnet50'](
            pretrained=True,
            replace_stride_with_dilation=[False, True, True]
        )
        from torchvision.models._utils import IntermediateLayerGetter
        return_layers = {'layer4': 'out'}
        resnet_body = IntermediateLayerGetter(
            backbone, return_layers=return_layers
        )
        model = vae_model.ResNet_VQ_CVAE(
            hidden, k=k, resnet=resnet_body, num_channels=num_channels,
            colour_space=args.colour_space, task=task,
            out_chns=out_chns
        )
    elif args.model == 'vae':
        model = vanilla_vae.VanillaVAE(latent_dim=args.k, in_channels=3)
    else:
        task = None
        out_chns = 3
        if 'voc' in args.dataset:
            task = 'segmentation'
            out_chns = 21
        elif 'bsds' in args.dataset:
            task = 'segmentation'
            out_chns = 1
        elif args.colour_space == 'labhue':
            out_chns = 4
        backbone = None
        if args.backbone is not None:
            backbone = {
                'arch_name': args.backbone[0],
                'layer_name': args.backbone[1]
            }
            if len(args.backbone) > 2:
                backbone['weights_path'] = args.backbone[2]
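            # A custom backbone switches to the backbone-based VQ-VAE for
            # this dataset/model combination.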
            models[args.dataset][args.model] = vae_model.Backbone_VQ_VAE
        model = models[args.dataset][args.model](
            hidden, k=k, kl=args.kl, num_channels=num_channels,
            colour_space=args.colour_space, task=task,
            out_chns=out_chns, cos_distance=args.cos_dis,
            use_decor_loss=args.decor, backbone=backbone
        )
    if args.cuda:
        model.cuda()

    if args.load_encoder is not None:
        params_to_optimize = [
            {'params': [p for p in model.decoder.parameters() if
                        p.requires_grad]},
            {'params': [p for p in model.fc.parameters() if
                        p.requires_grad]},
        ]
        optimizer = optim.Adam(params_to_optimize, lr=lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr)
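    # Halve the learning rate every third of the training schedule.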
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, int(args.epochs / 3), 0.5
    )

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        args.start_epoch = checkpoint['epoch'] + 1
        scheduler.load_state_dict(checkpoint['scheduler'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    elif args.fine_tune is not None:
        weights = torch.load(args.fine_tune, map_location='cpu')
        model.load_state_dict(weights, strict=False)
        model.cuda()
    elif args.load_encoder is not None:
        weights = torch.load(args.load_encoder, map_location='cpu')
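        # Strip the decoder weights so only the encoder is restored; the
        # decoder and fc head are then optimised from scratch (see the
        # parameter groups above).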
        weights = removekey(weights, 'decode')
        model.load_state_dict(weights, strict=False)
        model.cuda()

    intransform_funs = []
    if args.gamma is not None:
        kwargs = {'amount': args.gamma}
        augmentation_settings = [
            {'function': imutils.adjust_gamma, 'kwargs': kwargs}
        ]
        intransform_funs.append(
            cv2_preprocessing.RandomAugmentationTransformation(
                augmentation_settings, num_augmentations=1
            )
        )
    if args.mosaic_pattern is not None:
        intransform_funs.append(
            cv2_preprocessing.MosaicTransformation(args.mosaic_pattern)
        )
    if in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(in_colour_space)
        )
    intransform = transforms.Compose(intransform_funs)
    outtransform_funs = []
    args.inv_func = None
    if args.colour_space is not None:
        outtransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.colour_space)
        )
        if args.vis_rgb:
            args.inv_func = lambda x: generic_inv_fun(x, args.colour_space)
    outtransform = transforms.Compose(outtransform_funs)

    if args.data_dir is not None:
        args.train_dir = os.path.join(args.data_dir, 'train')
        args.validation_dir = os.path.join(args.data_dir, 'validation')
    # otherwise args.train_dir and args.validation_dir are used as given
    kwargs = {'num_workers': args.workers,
              'pin_memory': True} if args.cuda else {}
    args.vis_func = vae_util.grid_save_reconstructed_images
    if args.colour_space == 'labhue':
        args.vis_func = vae_util.grid_save_reconstructed_labhue
    if args.dataset == 'coco':
        train_loader = panoptic_utils.get_coco_train(
            args.batch_size, args.opts, args.cfg_file
        )
        test_loader = panoptic_utils.get_coco_test(
            args.batch_size, args.opts, args.cfg_file
        )
    elif 'voc' in args.dataset:
        train_loader = torch.utils.data.DataLoader(
            data_loaders.VOCSegmentation(
                root=args.data_dir,
                image_set='train',
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            data_loaders.VOCSegmentation(
                root=args.data_dir,
                image_set='val',
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.category is not None:
        train_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.train_dir,
                category=args.category,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                category=args.category,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.dataset == 'bsds':
        args.vis_func = vae_util.grid_save_reconstructed_bsds
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.dataset == 'celeba':
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                split='train',
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                split='test',
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.train_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.validation_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
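    # Train and evaluate once per epoch, plot each train loss against its
    # test counterpart in TensorBoard, and checkpoint after every epoch.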
    for epoch in range(args.start_epoch, args.epochs):
        train_losses = train(
            epoch, model, train_loader, optimizer, args.cuda, args.log_interval,
            save_path, args
        )
        test_losses = test_net(epoch, model, test_loader, args.cuda, save_path,
                               args)
        for loss_name in train_losses.keys():
            name = loss_name.replace('_train', '')
            test_name = loss_name.replace('train', 'test')
            writer.add_scalars(
                name, {'train': train_losses[loss_name],
                       'test': test_losses[test_name]}, epoch
            )
        scheduler.step()
        vae_util.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'arch': {'k': args.k, 'hidden': args.hidden}
            },
            save_path
        )
Code example #4
def main(args):
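    # Minimal variant of example #1: the same VQ-CVAE export pipeline without
    # the colour-manipulation and noise options.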
    args = parse_arguments(args)
    weights_rgb = torch.load(args.model_path, map_location='cpu')
    network = vqmodel.VQ_CVAE(128,
                              k=args.k,
                              kl=args.kl,
                              in_chns=3,
                              cos_distance=args.cos_dis)
    network.load_state_dict(weights_rgb)
    if args.exclude > 0:
        which_vec = [args.exclude - 1]
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    elif args.exclude < 0:
        which_vec = [*range(8)]
        which_vec.remove(abs(args.exclude) - 1)
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    network.cuda()
    network.eval()

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    args.in_colour_space = args.colour_space[:3]
    args.out_colour_space = args.colour_space[4:]

    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    transform_funcs = transforms.Compose([
        # cv2_transforms.Resize(256), cv2_transforms.CenterCrop(224),
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std)
    ])

    intransform_funs = []
    if args.in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.in_colour_space))
    intransform = transforms.Compose(intransform_funs)

    if args.dataset == 'imagenet':
        test_loader = torch.utils.data.DataLoader(
            data_loaders.ImageFolder(
                root=args.validation_dir,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs
            ),
            batch_size=args.batch_size, shuffle=False
        )
    else:
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                # FIXME
                category=args.category,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs),
            batch_size=args.batch_size,
            shuffle=False)
    export(test_loader, network, args)
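
None of these examples ship the surrounding boilerplate. Below is a minimal sketch of how such an entry point is typically invoked, assuming parse_arguments accepts an argv-style list (as the calls above suggest):

import sys

if __name__ == '__main__':
    # Forward the command-line arguments (minus the program name) to main.
    main(sys.argv[1:])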