Example #1
def train_set(db, target_size, mean, std, extra_transformation=None, **kwargs):
    if extra_transformation is None:
        extra_transformation = []
    if kwargs.get('train_params') is None:
        shared_pre_transforms = [
            *extra_transformation,
            cv2_transforms.RandomHorizontalFlip(),
        ]
    else:
        shared_pre_transforms = [*extra_transformation]
    shared_post_transforms = _get_shared_post_transforms(mean, std)
    if db in NATURAL_DATASETS:
        # if train params are passed, don't use any random processes
        if kwargs.get('train_params') is None:
            scale = (0.08, 1.0)
            size_transform = cv2_transforms.RandomResizedCrop(target_size,
                                                              scale=scale)
            pre_transforms = [size_transform, *shared_pre_transforms]
        else:
            pre_transforms = [
                cv2_transforms.Resize(target_size),
                cv2_transforms.CenterCrop(target_size), *shared_pre_transforms
            ]
        post_transforms = [*shared_post_transforms]
        return _natural_dataset(db, 'train', pre_transforms, post_transforms,
                                **kwargs)
    elif db in ['gratings']:
        return _get_grating_dataset(shared_pre_transforms,
                                    shared_post_transforms, target_size,
                                    **kwargs)
    return None
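A minimal usage sketch (not in the source): it assumes 'imagenet' appears in NATURAL_DATASETS and that _natural_dataset accepts a data_dir keyword; the statistics are the usual ImageNet values.

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
train_db = train_set('imagenet', 224, mean, std,
                     train_params=None, data_dir='/path/to/imagenet')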
Example #2
def get_val_dataset(val_dir, target_size, preprocess):
    mean, std = preprocess
    normalise = cv2_transforms.Normalize(mean=mean, std=std)
    transform = torch_transforms.Compose([
        cv2_transforms.Resize(target_size),
        cv2_transforms.CenterCrop(target_size),
        cv2_transforms.ToTensor(),
        normalise,
    ])
    val_dataset = ImageFolder(root=val_dir, transform=transform)
    return val_dataset
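A sketch of wiring the returned dataset into a DataLoader; the directory and preprocessing statistics are placeholders.

import torch

preprocess = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
val_dataset = get_val_dataset('/path/to/validation', 224, preprocess)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64,
                                         shuffle=False, num_workers=4)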
Example #3
def prepare_transformations_train(dataset_name,
                                  colour_transformations,
                                  other_transformations,
                                  chns_transformation,
                                  normalize,
                                  target_size,
                                  random_labels=False):
    if 'cifar' in dataset_name or dataset_name in folder_dbs:
        flip_p = 0.5
        if random_labels:
            size_transform = cv2_transforms.Resize(target_size)
            flip_p = -1
        elif 'cifar' in dataset_name:
            size_transform = cv2_transforms.RandomCrop(target_size, padding=4)
        elif 'imagenet' in dataset_name or 'ecoset' in dataset_name:
            scale = (0.08, 1.0)
            size_transform = cv2_transforms.RandomResizedCrop(target_size,
                                                              scale=scale)
        else:
            scale = (0.50, 1.0)
            size_transform = cv2_transforms.RandomResizedCrop(target_size,
                                                              scale=scale)
        transformations = torch_transforms.Compose([
            size_transform,
            *colour_transformations,
            *other_transformations,
            cv2_transforms.RandomHorizontalFlip(p=flip_p),
            cv2_transforms.ToTensor(),
            *chns_transformation,
            normalize,
        ])
    elif 'wcs_lms' in dataset_name:
        # FIXME: colour transformation in lms is different from rgb or lab
        transformations = torch_transforms.Compose([
            *other_transformations,
            RandomHorizontalFlip(),
            Numpy2Tensor(),
            *chns_transformation,
            normalize,
        ])
    elif 'wcs_jpg' in dataset_name:
        transformations = torch_transforms.Compose([
            *colour_transformations,
            *other_transformations,
            cv2_transforms.RandomHorizontalFlip(),
            cv2_transforms.ToTensor(),
            *chns_transformation,
            normalize,
        ])
    else:
        sys.exit('Transformations for dataset %s are not supported.' %
                 dataset_name)
    return transformations
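A hedged usage sketch: empty augmentation lists are assumed, so only the size, flip, tensor and normalisation steps remain.

normalize = cv2_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
train_transformations = prepare_transformations_train(
    'imagenet', colour_transformations=[], other_transformations=[],
    chns_transformation=[], normalize=normalize, target_size=224)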
Example #4
def prepare_transformations_test(dataset_name,
                                 colour_transformations,
                                 other_transformations,
                                 chns_transformation,
                                 normalize,
                                 target_size,
                                 task=None):
    if 'cifar' in dataset_name or dataset_name in folder_dbs:
        transformations = torch_transforms.Compose([
            cv2_transforms.Resize(target_size),
            cv2_transforms.CenterCrop(target_size),
            *colour_transformations,
            *other_transformations,
            cv2_transforms.ToTensor(),
            *chns_transformation,
            normalize,
        ])
    elif 'wcs_lms' in dataset_name:
        # FIXME: colour transformation in lms is different from rgb or lab
        transformations = torch_transforms.Compose([
            *other_transformations,
            Numpy2Tensor(),
            *chns_transformation,
            normalize,
        ])
    elif 'wcs_jpg' in dataset_name:
        transformations = torch_transforms.Compose([
            *colour_transformations,
            *other_transformations,
            cv2_transforms.ToTensor(),
            *chns_transformation,
            normalize,
        ])
    elif 'voc' in dataset_name or task == 'segmentation':
        transformations = []
    else:
        sys.exit('Transformations for dataset %s are not supported.' %
                 dataset_name)
    return transformations
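The deterministic test-time counterpart, under the same assumptions as the previous sketch:

test_transformations = prepare_transformations_test(
    'imagenet', colour_transformations=[], other_transformations=[],
    chns_transformation=[], normalize=normalize, target_size=224)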
Example #5
def validation_set(db,
                   target_size,
                   mean,
                   std,
                   extra_transformation=None,
                   **kwargs):
    if extra_transformation is None:
        extra_transformation = []
    shared_pre_transforms = [*extra_transformation]
    shared_post_transforms = _get_shared_post_transforms(mean, std)
    if db in NATURAL_DATASETS:
        pre_transforms = [
            cv2_transforms.Resize(target_size),
            cv2_transforms.CenterCrop(target_size), *shared_pre_transforms
        ]
        post_transforms = [*shared_post_transforms]
        return _natural_dataset(db, 'validation', pre_transforms,
                                post_transforms, **kwargs)
    elif db in ['gratings']:
        return _get_grating_dataset(shared_pre_transforms,
                                    shared_post_transforms, target_size,
                                    **kwargs)
    return None
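A sketch pairing validation_set with a DataLoader; as in the sketch after Example #1, data_dir is a hypothetical keyword forwarded to _natural_dataset.

import torch

val_db = validation_set('imagenet', 224, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                        data_dir='/path/to/imagenet')
val_loader = torch.utils.data.DataLoader(val_db, batch_size=64, shuffle=False)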
Example #6
def main(args):
    args = parse_arguments(args)
    if args.random_seed >= 0:
        os.environ['PYTHONHASHSEED'] = str(args.random_seed)
        torch.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.cuda.manual_seed(args.random_seed)
        np.random.seed(args.random_seed)
        random.seed(args.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    weights_rgb = torch.load(args.model_path, map_location='cpu')
    network = vqmodel.VQ_CVAE(128,
                              k=args.k,
                              kl=args.kl,
                              in_chns=3,
                              cos_distance=args.cos_dis)
    network.load_state_dict(weights_rgb)
    if args.exclude > 0:
        which_vec = [args.exclude - 1]
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    elif args.exclude < 0:
        which_vec = [*range(8)]
        which_vec.remove(abs(args.exclude) - 1)
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    network.cuda()
    network.eval()

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    args.in_colour_space = args.colour_space[:3]
    args.out_colour_space = args.colour_space[4:7]

    (imagenet_model, target_size) = model_utils.which_network(
        args.imagenet_model,
        'classification',
        num_classes=args.num_classes,
    )
    imagenet_model.cuda()
    imagenet_model.eval()

    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)

    if 'vggface2' in args.validation_dir:
        transform_funcs = transforms.Compose([
            cv2_transforms.Resize(256),
            cv2_transforms.CenterCrop(224),
            cv2_transforms.ToTensor(),
            cv2_transforms.Normalize(mean, std)
        ])

        imagenet_transformations = transforms.Compose([
            cv2_transforms.ToTensor(),
            cv2_transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
        ])
    else:
        transform_funcs = transforms.Compose([
            cv2_transforms.Resize(512),
            cv2_transforms.CenterCrop(512),
            cv2_transforms.ToTensor(),
            cv2_transforms.Normalize(mean, std)
        ])

        imagenet_transformations = transforms.Compose([
            cv2_transforms.Resize(256),
            cv2_transforms.CenterCrop(224),
            cv2_transforms.ToTensor(),
            cv2_transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
        ])

    intransform_funs = []
    if args.in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.VisionTypeTransformation(None,
                                                       args.in_colour_space))
    if args.noise[0] is not None:
        value = float(args.noise[1])
        noise_name = args.noise[0]
        if noise_name == 'sp':
            noise_fun = imutils.s_p_noise
            kwargs = {'amount': value, 'seed': args.random_seed}
        elif noise_name == 'gaussian':
            noise_fun = imutils.gaussian_noise
            kwargs = {'amount': value, 'seed': args.random_seed}
        elif noise_name == 'speckle':
            noise_fun = imutils.speckle_noise
            kwargs = {'amount': value, 'seed': args.random_seed}
        elif noise_name == 'blur':
            noise_fun = imutils.gaussian_blur
            kwargs = {'sigmax': value, 'seed': args.random_seed}
        else:
            raise ValueError('unsupported noise type: %s' % noise_name)

        if noise_name != 'blur':
            kwargs['eq_chns'] = True
        intransform_funs.append(
            cv2_preprocessing.UniqueTransformation(noise_fun, **kwargs))
    intransform = transforms.Compose(intransform_funs)

    test_loader = torch.utils.data.DataLoader(ImageFolder(
        root=args.validation_dir,
        intransform=intransform,
        outtransform=None,
        transform=transform_funcs),
                                              batch_size=args.batch_size,
                                              shuffle=False)
    top1, top5, prediction_output = export(test_loader, network, mean, std,
                                           imagenet_model,
                                           imagenet_transformations, args)
    output_file = '%s/%s.csv' % (args.out_dir, args.colour_space)
    np.savetxt(output_file, prediction_output, delimiter=',', fmt='%i')
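These scripts receive the raw argument list, so a conventional entry point (not shown in the source) would be:

if __name__ == '__main__':
    import sys
    main(sys.argv[1:])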
Example #7
def main(args):
    args = parse_arguments(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.random_seed is not None:
        os.environ['PYTHONHASHSEED'] = str(args.random_seed)
        torch.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.cuda.manual_seed(args.random_seed)
        np.random.seed(args.random_seed)
        random.seed(args.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if args.cuda:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True

    args.mean = (0.5, 0.5, 0.5)
    args.std = (0.5, 0.5, 0.5)
    if 'labhue' in args.colour_space:
        args.mean = (0.5, 0.5, 0.5, 0.5)
        args.std = (0.5, 0.5, 0.5, 0.5)
    normalise = transforms.Normalize(args.mean, args.std)

    lr = args.lr or default_hyperparams[args.dataset]['lr']
    k = args.k or default_hyperparams[args.dataset]['k']
    hidden = args.hidden or default_hyperparams[args.dataset]['hidden']
    num_channels = args.num_channels or dataset_n_channels[args.dataset]
    target_size = args.target_size or dataset_target_size[args.dataset]

    dataset_transforms = {
        'custom': transforms.Compose(
            [transforms.Resize(256), transforms.CenterCrop(224),
             transforms.ToTensor(), normalise]),
        'coco': transforms.Compose([normalise]),
        'voc': transforms.Compose(
            [cv2_transforms.RandomResizedCropSegmentation(target_size,
                                                          scale=(0.50, 1.0)),
             cv2_transforms.ToTensorSegmentation(),
             cv2_transforms.NormalizeSegmentation(args.mean, args.std)]),
        'bsds': transforms.Compose(
            [cv2_transforms.RandomResizedCropSegmentation(target_size,
                                                          scale=(0.50, 1.0)),
             cv2_transforms.ToTensorSegmentation(),
             cv2_transforms.NormalizeSegmentation(args.mean, args.std)]),
        'imagenet': transforms.Compose(
            [cv2_transforms.Resize(target_size + 32),
             cv2_transforms.CenterCrop(target_size),
             cv2_transforms.ToTensor(),
             cv2_transforms.Normalize(args.mean, args.std)]),
        'celeba': transforms.Compose(
            [cv2_transforms.Resize(target_size + 32),
             cv2_transforms.CenterCrop(target_size),
             cv2_transforms.ToTensor(),
             cv2_transforms.Normalize(args.mean, args.std)]),
        'cifar10': transforms.Compose(
            [transforms.ToTensor(), normalise]),
        'mnist': transforms.ToTensor()
    }

    save_path = vae_util.setup_logging_from_args(args)
    writer = SummaryWriter(save_path)

    in_colour_space = args.colour_space[:3]
    out_colour_space = args.colour_space[4:]
    args.colour_space = out_colour_space

    if args.model == 'wavenet':
        # model = wavenet_vae.wavenet_bottleneck(
        #     latent_dim=k, in_channels=num_channels
        # )
        task = None
        out_chns = 3
        if 'voc' in args.dataset:
            task = 'segmentation'
            out_chns = 21
        from torchvision.models import resnet
        backbone = resnet.__dict__['resnet50'](
            pretrained=True,
            replace_stride_with_dilation=[False, True, True]
        )
        from torchvision.models._utils import IntermediateLayerGetter
        return_layers = {'layer4': 'out'}
        resnet_body = IntermediateLayerGetter(
            backbone, return_layers=return_layers
        )
        model = vae_model.ResNet_VQ_CVAE(
            hidden, k=k, resnet=resnet_body, num_channels=num_channels,
            colour_space=args.colour_space, task=task,
            out_chns=out_chns
        )
    elif args.model == 'vae':
        model = vanilla_vae.VanillaVAE(latent_dim=args.k, in_channels=3)
    else:
        task = None
        out_chns = 3
        if 'voc' in args.dataset:
            task = 'segmentation'
            out_chns = 21
        elif 'bsds' in args.dataset:
            task = 'segmentation'
            out_chns = 1
        elif args.colour_space == 'labhue':
            out_chns = 4
        backbone = None
        if args.backbone is not None:
            backbone = {
                'arch_name': args.backbone[0],
                'layer_name': args.backbone[1]
            }
            if len(args.backbone) > 2:
                backbone['weights_path'] = args.backbone[2]
            models[args.dataset][args.model] = vae_model.Backbone_VQ_VAE
        model = models[args.dataset][args.model](
            hidden, k=k, kl=args.kl, num_channels=num_channels,
            colour_space=args.colour_space, task=task,
            out_chns=out_chns, cos_distance=args.cos_dis,
            use_decor_loss=args.decor, backbone=backbone
        )
    if args.cuda:
        model.cuda()

    if args.load_encoder is not None:
        params_to_optimize = [
            {'params': [p for p in model.decoder.parameters() if
                        p.requires_grad]},
            {'params': [p for p in model.fc.parameters() if
                        p.requires_grad]},
        ]
        optimizer = optim.Adam(params_to_optimize, lr=lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, int(args.epochs / 3), 0.5
    )

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        args.start_epoch = checkpoint['epoch'] + 1
        scheduler.load_state_dict(checkpoint['scheduler'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    elif args.fine_tune is not None:
        weights = torch.load(args.fine_tune, map_location='cpu')
        model.load_state_dict(weights, strict=False)
        model.cuda()
    elif args.load_encoder is not None:
        weights = torch.load(args.load_encoder, map_location='cpu')
        weights = removekey(weights, 'decode')
        model.load_state_dict(weights, strict=False)
        model.cuda()

    intransform_funs = []
    if args.gamma is not None:
        kwargs = {'amount': args.gamma}
        augmentation_settings = [
            {'function': imutils.adjust_gamma, 'kwargs': kwargs}
        ]
        intransform_funs.append(
            cv2_preprocessing.RandomAugmentationTransformation(
                augmentation_settings, num_augmentations=1
            )
        )
    if args.mosaic_pattern is not None:
        intransform_funs.append(
            cv2_preprocessing.MosaicTransformation(args.mosaic_pattern)
        )
    if in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(in_colour_space)
        )
    intransform = transforms.Compose(intransform_funs)
    outtransform_funs = []
    args.inv_func = None
    if args.colour_space is not None:
        outtransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.colour_space)
        )
        if args.vis_rgb:
            args.inv_func = lambda x: generic_inv_fun(x, args.colour_space)
    outtransform = transforms.Compose(outtransform_funs)

    if args.data_dir is not None:
        args.train_dir = os.path.join(args.data_dir, 'train')
        args.validation_dir = os.path.join(args.data_dir, 'validation')
    kwargs = {'num_workers': args.workers,
              'pin_memory': True} if args.cuda else {}
    args.vis_func = vae_util.grid_save_reconstructed_images
    if args.colour_space == 'labhue':
        args.vis_func = vae_util.grid_save_reconstructed_labhue
    if args.dataset == 'coco':
        train_loader = panoptic_utils.get_coco_train(
            args.batch_size, args.opts, args.cfg_file
        )
        test_loader = panoptic_utils.get_coco_test(
            args.batch_size, args.opts, args.cfg_file
        )
    elif 'voc' in args.dataset:
        train_loader = torch.utils.data.DataLoader(
            data_loaders.VOCSegmentation(
                root=args.data_dir,
                image_set='train',
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            data_loaders.VOCSegmentation(
                root=args.data_dir,
                image_set='val',
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.category is not None:
        train_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.train_dir,
                category=args.category,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                category=args.category,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.dataset == 'bsds':
        args.vis_func = vae_util.grid_save_reconstructed_bsds
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    elif args.dataset == 'celeba':
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                split='train',
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.data_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                split='test',
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.train_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_train_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=True, **kwargs
        )
        test_loader = torch.utils.data.DataLoader(
            datasets_classes[args.dataset](
                root=args.validation_dir,
                intransform=intransform,
                outtransform=outtransform,
                transform=dataset_transforms[args.dataset],
                **dataset_test_args[args.dataset]
            ),
            batch_size=args.batch_size, shuffle=False, **kwargs
        )

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    for epoch in range(args.start_epoch, args.epochs):
        train_losses = train(
            epoch, model, train_loader, optimizer, args.cuda, args.log_interval,
            save_path, args
        )
        test_losses = test_net(epoch, model, test_loader, args.cuda, save_path,
                               args)
        for k in train_losses.keys():
            name = k.replace('_train', '')
            train_name = k
            test_name = k.replace('train', 'test')
            writer.add_scalars(
                name, {'train': train_losses[train_name],
                       'test': test_losses[test_name]}, epoch
            )
        scheduler.step()
        vae_util.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'arch': {'k': args.k, 'hidden': args.hidden}
            },
            save_path
        )
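The RNG seeding at the top of this main() duplicates Example #6; a sketch of a shared helper (torch.cuda.manual_seed is dropped because manual_seed_all already seeds every device):

import os
import random

import numpy as np
import torch


def set_reproducible_seed(seed):
    # Consolidates the seeding block repeated in Examples #6 and #7.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # subsumes torch.cuda.manual_seed
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False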
Example #8
def main(args):
    args = parse_arguments(args)
    weights_net = torch.load(args.model_path, map_location='cpu')

    args.in_colour_space = args.colour_space[:3]
    args.out_colour_space = args.colour_space[4:]

    args.outs_dict = dict()
    args.outs_dict[args.out_colour_space] = {'shape': [1, 1, 3]}
    from segmentation_models import unet
    network = unet.model.Unet(in_channels=3,
                              encoder_weights=None,
                              outs_dict=args.outs_dict,
                              classes=3)

    network.load_state_dict(weights_net)
    if args.exclude > 0:
        which_vec = [args.exclude - 1]
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    elif args.exclude < 0:
        which_vec = [*range(8)]
        which_vec.remove(abs(args.exclude) - 1)
        print(which_vec)
        network.state_dict()['emb.weight'][:, which_vec] = 0
    network.cuda()
    network.eval()

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    transform_funcs = transforms.Compose([
        cv2_transforms.Resize(args.target_size + 32),
        cv2_transforms.CenterCrop(args.target_size),
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std)
    ])

    intransform_funs = []
    if args.in_colour_space != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.ColourSpaceTransformation(args.in_colour_space))
    if args.corr_noise:
        parameters = dict()
        parameters['function'] = imutils.s_p_noise
        parameters['kwargs'] = {'amount': args.corr_noise, 'eq_chns': True}
        intransform_funs.append(
            cv2_preprocessing.PredictionTransformation(parameters))
    intransform = transforms.Compose(intransform_funs)

    if args.dataset == 'imagenet':
        test_loader = torch.utils.data.DataLoader(data_loaders.ImageFolder(
            root=args.validation_dir,
            intransform=intransform,
            outtransform=None,
            transform=transform_funcs),
                                                  batch_size=args.batch_size,
                                                  shuffle=False)
    elif args.dataset == 'celeba':
        test_loader = torch.utils.data.DataLoader(data_loaders.CelebA(
            root=args.validation_dir,
            intransform=intransform,
            outtransform=None,
            transform=transform_funcs,
            split='test'),
                                                  batch_size=args.batch_size,
                                                  shuffle=False)
    else:
        test_loader = torch.utils.data.DataLoader(
            data_loaders.CategoryImages(
                root=args.validation_dir,
                # FIXME
                category=args.category,
                intransform=intransform,
                outtransform=None,
                transform=transform_funcs),
            batch_size=args.batch_size,
            shuffle=False)
    export(test_loader, network, mean, std, args)
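Both this example and Example #6 ablate codebook columns through network.state_dict(), which works because state_dict returns references to the live tensors; an equivalent, more explicit sketch (the parameter name 'emb.weight' is taken from those examples):

import torch


def ablate_embedding_dims(network, which_vec):
    # Equivalent to the state_dict() assignments above, but explicit.
    with torch.no_grad():  # keep the ablation out of autograd
        emb_weight = dict(network.named_parameters())['emb.weight']
        emb_weight[:, which_vec] = 0  # silence the selected dimensions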
Example #9
def main_worker(ngpus_per_node, args):
    mean, std = model_utils.get_preprocessing_function(args.colour_space,
                                                       args.vision_type)

    # preparing the output folder
    create_dir(args.out_dir)

    if args.gpus is not None:
        print("Use GPU: {} for training".format(args.gpus))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + args.gpus
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.transfer_weights is not None:
        print('Transferred model!')
        model = contrast_utils.AFCModel(args.network_name,
                                        args.transfer_weights)
    elif args.custom_arch:
        print('Custom model!')
        supported_customs = ['resnet_basic_custom', 'resnet_bottleneck_custom']
        if args.network_name in supported_customs:
            model = custom_models.__dict__[args.network_name](
                args.blocks,
                pooling_type=args.pooling_type,
                in_chns=len(mean),
                num_classes=args.num_classes,
                inplanes=args.num_kernels,
                kernel_size=args.kernel_size)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.network_name))
        model = models.__dict__[args.network_name](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.network_name))
        model = models.__dict__[args.network_name]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpus is not None:
            torch.cuda.set_device(args.gpus)
            model.cuda(args.gpus)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpus])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpus is not None:
        torch.cuda.set_device(args.gpus)
        model = model.cuda(args.gpus)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if (args.network_name.startswith('alexnet')
                or args.network_name.startswith('vgg')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = soft_cross_entropy

    # optimiser
    if args.transfer_weights is None:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        params_to_optimize = [
            {
                'params': [p for p in model.parameters() if p.requires_grad]
            },
        ]
        optimizer = torch.optim.SGD(params_to_optimize,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model_progress = []
    model_progress_path = os.path.join(args.out_dir, 'model_progress.csv')
    # optionally resume from a checkpoint
    # TODO: it would be best if resume load the architecture from this file
    # TODO: merge with which_architecture
    best_acc1 = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.initial_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.gpus is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpus)
                model = model.cuda(args.gpus)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if os.path.exists(model_progress_path):
                model_progress = np.loadtxt(model_progress_path, delimiter=',')
                model_progress = model_progress.tolist()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    train_trans = []
    valid_trans = []
    both_trans = []
    if args.mosaic_pattern is not None:
        mosaic_trans = preprocessing.mosaic_transformation(args.mosaic_pattern)
        both_trans.append(mosaic_trans)

    if args.num_augmentations != 0:
        augmentations = preprocessing.random_augmentation(
            args.augmentation_settings, args.num_augmentations)
        train_trans.append(augmentations)

    target_size = default_configs.get_default_target_size(
        args.dataset, args.target_size)

    final_trans = [
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std),
    ]

    train_trans.append(
        cv2_transforms.RandomResizedCrop(target_size, scale=(0.08, 1.0)))

    # loading the training set
    train_trans = torch_transforms.Compose(
        [*both_trans, *train_trans, *final_trans])
    train_dataset = image_quality.BAPPS2afc(root=args.data_dir,
                                            split='train',
                                            transform=train_trans,
                                            concat=0.5)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    valid_trans.extend([
        cv2_transforms.Resize(target_size),
        cv2_transforms.CenterCrop(target_size),
    ])

    # loading validation set
    valid_trans = torch_transforms.Compose(
        [*both_trans, *valid_trans, *final_trans])
    validation_dataset = image_quality.BAPPS2afc(root=args.data_dir,
                                                 split='val',
                                                 transform=valid_trans,
                                                 concat=0)

    val_loader = torch.utils.data.DataLoader(validation_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # training on epoch
    for epoch in range(args.initial_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        misc_utils.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_log = train_on_data(train_loader, model, criterion, optimizer,
                                  epoch, args)

        # evaluate on validation set
        validation_log = validate_on_data(val_loader, model, criterion, args)

        model_progress.append([*train_log, *validation_log])

        # remember best acc@1 and save checkpoint
        acc1 = validation_log[2]
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if misc_utils.is_saving_node(args.multiprocessing_distributed,
                                     args.rank, ngpus_per_node):
            misc_utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.network_name,
                    'customs': {
                        'pooling_type': args.pooling_type,
                        'in_chns': len(mean),
                        'num_classes': args.num_classes,
                        'blocks': args.blocks,
                        'num_kernels': args.num_kernels,
                        'kernel_size': args.kernel_size
                    },
                    'transfer_weights': args.transfer_weights,
                    'preprocessing': {
                        'mean': mean,
                        'std': std
                    },
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'target_size': target_size,
                },
                is_best,
                out_folder=args.out_dir)
            # TODO: get this header directly as a dictionary keys
            header = 'epoch,t_time,t_loss,t_top5,v_time,v_loss,v_top1'
            np.savetxt(model_progress_path,
                       np.array(model_progress),
                       delimiter=',',
                       header=header)
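main_worker follows the canonical PyTorch distributed recipe but, unlike the reference ImageNet script, takes ngpus_per_node rather than a GPU index as its first argument; a hedged launcher sketch (the wrapper and the world_size arithmetic are assumptions, not in the source):

import torch
import torch.multiprocessing as mp


def _spawned_worker(gpu_index, ngpus_per_node, args):
    # mp.spawn passes the process index first; map it onto args.gpus.
    args.gpus = gpu_index
    main_worker(ngpus_per_node, args)


def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per GPU; total world size scales accordingly
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(_spawned_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        main_worker(ngpus_per_node, args)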
Example #10
def main(args):
    args = arguments.parse_arguments(args)

    # determining the number of input channels
    args.in_chns = 3

    out_chns = 3
    args.out_chns = out_chns

    args.mean = 0.5
    args.std = 0.5
    target_size = args.target_size or dataset_target_size[args.dataset]

    if args.dataset == 'ccvr':
        pre_shared_transforms = [
            cv2_transforms.Resize(target_size + 32),
            cv2_transforms.RandomCrop(target_size),
        ]
    else:
        pre_shared_transforms = [
            cv2_transforms.Resize(target_size + 32),
            cv2_transforms.CenterCrop(target_size),
        ]
    post_shared_transforms = [
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(args.mean, args.std)
    ]

    pre_dataset_transforms = dict()
    post_dataset_transforms = dict()
    for key in datasets_classes.keys():
        pre_dataset_transforms[key] = transforms.Compose(pre_shared_transforms)
        post_dataset_transforms[key] = transforms.Compose(
            post_shared_transforms)

    save_path = vae_util.setup_logging_from_args(args)
    writer = SummaryWriter(save_path)

    torch.manual_seed(args.seed)
    cudnn.benchmark = True
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.pred is not None:
        checkpoint = torch.load(args.pred, map_location='cpu')
        model_vae = model_vqvae.DecomposeNet(**checkpoint['arch_params'])
        model_vae.load_state_dict(checkpoint['state_dict'])
    else:
        # FIXME: archs_param should be added to resume and fine_tune
        arch_params = {'k': args.k, 'd': args.d, 'hidden': args.hidden}
        model_vae = model_vqvae.DecomposeNet(hidden=args.hidden,
                                             k=args.k,
                                             d=args.d,
                                             in_chns=args.in_chns,
                                             out_chns=args.out_chns)
    model_vae = model_vae.cuda()

    # FIXME: make this handle a single output only
    if args.lab_init:
        distortion = [
            116.0 / 500, 16.0 / 500, 500.0 / 500, 200.0 / 500, 0.2068966
        ]
        trans_mat = [[0.412453, 0.357580, 0.180423],
                     [0.212671, 0.715160, 0.072169],
                     [0.019334, 0.119193, 0.950227]]

        ref_white = (0.95047, 1., 1.08883)

        tmat = colour_spaces.dkl_from_rgb.T
        tmat = np.expand_dims(tmat, [2, 3])
        cst_lr = args.lr * 0.1
    else:
        trans_mat = None
        ref_white = None
        distortion = None
        tmat = None
        cst_lr = args.lr
    # model_cst = ColourTransformer.LabTransformer(
    #     trans_mat=trans_mat, ref_white=ref_white,
    #     distortion=distortion, linear=args.linear
    # )
    model_cst = ColourTransformer.ResNetTransformer(layers=args.cst_layers)
    model_cst = model_cst.cuda()

    vae_params = [
        {
            'params': [p for p in model_vae.parameters() if p.requires_grad]
        },
    ]
    cst_params = [
        {
            'params': [p for p in model_cst.parameters() if p.requires_grad]
        },
    ]
    optimizer_vae = optim.Adam(vae_params, lr=args.lr)
    optimizer_cst = optim.Adam(cst_params, lr=cst_lr)
    scheduler_vae = optim.lr_scheduler.StepLR(optimizer_vae,
                                              int(args.epochs / 3), 0.5)
    scheduler_cst = optim.lr_scheduler.StepLR(optimizer_cst,
                                              int(args.epochs / 3), 0.5)

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_vae.load_state_dict(checkpoint['state_dict'])
        model_vae = model_vae.cuda()
        args.start_epoch = checkpoint['epoch'] + 1
        scheduler_vae.load_state_dict(checkpoint['scheduler_vae'])
        optimizer_vae.load_state_dict(checkpoint['optimizer_vae'])
        scheduler_cst.load_state_dict(checkpoint['scheduler_cst'])
        optimizer_cst.load_state_dict(checkpoint['optimizer_cst'])
    elif args.fine_tune is not None:
        weights = torch.load(args.fine_tune, map_location='cpu')
        model_vae.load_state_dict(weights, strict=False)
        model_vae = model_vae.cuda()

    intransform_funs = []
    if args.in_space.lower() == 'cgi':
        augmentation_settings = [{
            'function': random_imutils.adjust_contrast,
            'kwargs': {
                'amount': np.array([0.2, 1.0]),
                'channel_wise': True
            }
        }, {
            'function': random_imutils.adjust_gamma,
            'kwargs': {
                'amount': np.array([0.2, 5.0]),
                'channel_wise': True
            }
        }, {
            'function': random_imutils.adjust_illuminant,
            'kwargs': {
                'illuminant': np.array([0.0, 1.0])
            }
        }]
        intransform_funs.append(
            cv2_preprocessing.RandomAugmentationTransformation(
                augmentation_settings, num_augmentations=1))
    elif args.in_space.lower() != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.DecompositionTransformation(
                args.in_space.lower()))
    intransform = transforms.Compose(intransform_funs)

    outtransform = None

    args.outs_dict = {'rgb': {'vis_fun': None}}

    # preparing the dataset
    transforms_kwargs = {
        'intransform': intransform,
        'outtransform': outtransform,
        'pre_transform': pre_dataset_transforms[args.dataset],
        'post_transform': post_dataset_transforms[args.dataset]
    }
    if args.dataset in ['celeba', 'touch', 'ccvr']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       split='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      split='test',
                                                      **transforms_kwargs)
    elif args.dataset in ['coco']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       split='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      split='val',
                                                      **transforms_kwargs)
    elif args.dataset in ['voc']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       image_set='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      image_set='val',
                                                      **transforms_kwargs)
    else:
        train_dataset = datasets_classes[args.dataset](root=os.path.join(
            args.data_dir, 'train'),
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=os.path.join(
            args.data_dir, 'validation'),
                                                      **transforms_kwargs)

    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.workers,
        'pin_memory': True
    }
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              **loader_kwargs)

    if args.pred is not None:
        predict(model_vae, test_loader, save_path, args)
        return

    # starting to train
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    for epoch in range(args.start_epoch, args.epochs):
        train_losses = train(epoch, model_vae, model_cst, train_loader,
                             (optimizer_vae, optimizer_cst), save_path, args)
        test_losses = test_net(epoch, model_vae, model_cst, test_loader,
                               save_path, args)
        for k in train_losses.keys():
            name = k.replace('_trn', '')
            train_name = k
            test_name = k.replace('_trn', '_val')
            writer.add_scalars(name, {
                'train': train_losses[train_name],
                'test': test_losses[test_name]
            }, epoch)
        scheduler_vae.step()
        scheduler_cst.step()
        vae_util.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_vae.state_dict(),
                'colour_transformer': model_cst.state_dict(),
                'optimizer_vae': optimizer_vae.state_dict(),
                'scheduler_vae': scheduler_vae.state_dict(),
                'optimizer_cst': optimizer_cst.state_dict(),
                'scheduler_cst': scheduler_cst.state_dict(),
                'arch': args.model,
                'arch_params': {
                    **arch_params,
                    'in_chns': args.in_chns,
                    'out_chns': args.out_chns,
                },
                'transformer_params': {
                    'linear': args.linear
                }
            }, save_path)
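Both schedulers above use StepLR with step_size = epochs // 3 and gamma = 0.5, so the learning rate halves twice over a run; a minimal check of that schedule (90 epochs and lr=2e-4 are illustrative values):

import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(params, lr=2e-4)
sched = optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.5)  # 90 // 3
for epoch in range(90):
    # ... one epoch of training would go here ...
    if epoch in (0, 30, 60):
        print(epoch, opt.param_groups[0]['lr'])  # 2e-04, 1e-04, 5e-05
    sched.step()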
Example #11
def main(args):
    args = arguments.parse_arguments(args)

    # determining the number of input channels
    in_size = 1
    if args.in_space == 'gry':
        args.in_chns = 1
    elif args.in_space == 'db1':
        args.in_chns = 4
        in_size = 0.5
    else:
        args.in_chns = 3

    # FIXME
    # preparing the output dictionary
    args.outs_dict = dict()
    for out_type in args.outputs:
        if out_type == 'input':
            out_shape = [1 / in_size, 1 / in_size, args.in_chns]
        elif out_type == 'gry':
            out_shape = [1 / in_size, 1 / in_size, 1]
        elif out_type == 'db1':
            # TODO: assumes a dyadic (factor-of-two) wavelet decomposition
            out_shape = [0.5 / in_size, 0.5 / in_size, 4]
        else:
            out_shape = [1 / in_size, 1 / in_size, 3]
        args.outs_dict[out_type] = {'shape': out_shape}

    args.mean = 0.5
    args.std = 0.5
    target_size = args.target_size or dataset_target_size[args.dataset]

    if args.dataset == 'ccvr':
        pre_shared_transforms = [
            cv2_transforms.Resize(target_size + 32),
            cv2_transforms.RandomCrop(target_size),
        ]
    else:
        pre_shared_transforms = [
            cv2_transforms.Resize(target_size + 32),
            cv2_transforms.CenterCrop(target_size),
        ]
    post_shared_transforms = [
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(args.mean, args.std)
    ]

    pre_dataset_transforms = dict()
    post_dataset_transforms = dict()
    for key in datasets_classes.keys():
        pre_dataset_transforms[key] = transforms.Compose(pre_shared_transforms)
        post_dataset_transforms[key] = transforms.Compose(
            post_shared_transforms)

    save_path = vae_util.setup_logging_from_args(args)
    writer = SummaryWriter(save_path)

    torch.manual_seed(args.seed)
    cudnn.benchmark = True
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.pred is not None:
        checkpoint = torch.load(args.pred, map_location='cpu')
        vae_model = (model_single
                     if checkpoint['model'] == 'single' else model_multi)
        model = vae_model.DecomposeNet(**checkpoint['arch_params'])
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # FIXME: add to prediction, right now it's hard-coded for resnet18
        if args.model == 'deeplabv3':
            backbone = {
                'arch': 'resnet_bottleneck_custom',
                'customs': {
                    'pooling_type': 'max',
                    'in_chns': args.in_chns,
                    'blocks': [2, 2, 2, 2],
                    'num_kernels': 64,
                    'num_classes': 1000
                }
            }
            # FIXME out_shape is defined far above, this is just a hack
            arch_params = {'backbone': backbone, 'num_classes': out_shape[-1]}
            model = model_segmentation.deeplabv3_resnet(
                backbone, num_classes=out_shape[-1], outs_dict=args.outs_dict)
        elif 'unet' in args.model:
            from segmentation_models import unet
            encoder_name = args.model.split('_')[-1]
            model = unet.model.Unet(in_channels=args.in_chns,
                                    encoder_name=encoder_name,
                                    encoder_weights=None,
                                    outs_dict=args.outs_dict,
                                    classes=out_shape[-1])
            arch_params = {'encoder_name': encoder_name}
        elif args.model == 'category':
            # FIXME: archs_param should be added to resume and fine_tune
            arch_params = {'k': args.k, 'd': args.d, 'hidden': args.hidden}
            vae_model = model_category
            model = vae_model.DecomposeNet(hidden=args.hidden,
                                           k=args.k,
                                           d=args.d,
                                           in_chns=args.in_chns,
                                           outs_dict=args.outs_dict)
        else:
            # FIXME: archs_param should be added to resume and fine_tune
            arch_params = {'k': args.k, 'd': args.d, 'hidden': args.hidden}
            vae_model = model_single if args.model == 'single' else model_multi
            model = vae_model.DecomposeNet(hidden=args.hidden,
                                           k=args.k,
                                           d=args.d,
                                           in_chns=args.in_chns,
                                           outs_dict=args.outs_dict)
    model = model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, int(args.epochs / 3), 0.5)

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        args.start_epoch = checkpoint['epoch'] + 1
        scheduler.load_state_dict(checkpoint['scheduler'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    elif args.fine_tune is not None:
        weights = torch.load(args.fine_tune, map_location='cpu')
        model.load_state_dict(weights, strict=False)
        model = model.cuda()

    intransform_funs = []
    if args.in_space.lower() == 'cgi':
        augmentation_settings = [{
            'function': random_imutils.adjust_contrast,
            'kwargs': {
                'amount': np.array([0.2, 1.0]),
                'channel_wise': True
            }
        }, {
            'function': random_imutils.adjust_gamma,
            'kwargs': {
                'amount': np.array([0.2, 5.0]),
                'channel_wise': True
            }
        }, {
            'function': random_imutils.adjust_illuminant,
            'kwargs': {
                'illuminant': np.array([0.0, 1.0])
            }
        }]
        intransform_funs.append(
            cv2_preprocessing.RandomAugmentationTransformation(
                augmentation_settings, num_augmentations=1))
    elif args.in_space.lower() != 'rgb':
        intransform_funs.append(
            cv2_preprocessing.DecompositionTransformation(
                args.in_space.lower()))
    intransform = transforms.Compose(intransform_funs)

    outtransform_funs = [
        cv2_preprocessing.MultipleOutputTransformation(args.outputs)
    ]
    outtransform = transforms.Compose(outtransform_funs)

    # FIXME
    for out_type in args.outputs:
        if out_type == 'input':
            vis_fun = None
        elif out_type == 'gry':
            vis_fun = None
        elif out_type == 'db1':
            vis_fun = vae_util.wavelet_visualise
        elif args.vis_rgb:
            vis_fun = partial(all2rgb, src_space=out_type)
        else:
            vis_fun = None
        args.outs_dict[out_type]['vis_fun'] = vis_fun

    # preparing the dataset
    transforms_kwargs = {
        'intransform': intransform,
        'outtransform': outtransform,
        'pre_transform': pre_dataset_transforms[args.dataset],
        'post_transform': post_dataset_transforms[args.dataset]
    }
    if args.dataset in ['celeba', 'touch', 'ccvr']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       split='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      split='test',
                                                      **transforms_kwargs)
    elif args.dataset in ['coco']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       split='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      split='val',
                                                      **transforms_kwargs)
    elif args.dataset in ['voc']:
        train_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                       image_set='train',
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=args.data_dir,
                                                      image_set='val',
                                                      **transforms_kwargs)
    else:
        train_dataset = datasets_classes[args.dataset](root=os.path.join(
            args.data_dir, 'train'),
                                                       **transforms_kwargs)
        test_dataset = datasets_classes[args.dataset](root=os.path.join(
            args.data_dir, 'validation'),
                                                      **transforms_kwargs)

    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.workers,
        'pin_memory': True
    }
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              **loader_kwargs)

    if args.pred is not None:
        predict(model, test_loader, save_path, args)
        return

    # starting to train
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    for epoch in range(args.start_epoch, args.epochs):
        train_losses = train(epoch, model, train_loader, optimizer, save_path,
                             args)
        test_losses = test_net(epoch, model, test_loader, save_path, args)
        for k in train_losses.keys():
            name = k.replace('_train', '')
            train_name = k
            test_name = k.replace('train', 'test')
            writer.add_scalars(name, {
                'train': train_losses[train_name],
                'test': test_losses[test_name]
            }, epoch)
        scheduler.step()
        vae_util.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'arch': args.model,
                'arch_params': {
                    **arch_params, 'in_chns': args.in_chns,
                    'outs_dict': args.outs_dict
                }
            }, save_path)
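A worked sketch of the out_shape arithmetic at the top of this example for a 'db1' input (in_size = 0.5): every output shape is expressed relative to the wavelet-downsampled input.

in_size = 0.5  # 'db1' halves the spatial resolution, with in_chns = 4
print([1 / in_size, 1 / in_size, 4])      # 'input' -> [2.0, 2.0, 4]
print([1 / in_size, 1 / in_size, 1])      # 'gry'   -> [2.0, 2.0, 1]
print([0.5 / in_size, 0.5 / in_size, 4])  # 'db1'   -> [1.0, 1.0, 4]
print([1 / in_size, 1 / in_size, 3])      # colour  -> [2.0, 2.0, 3]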