Exemplo n.º 1
0
def prepare_transformations_train(dataset_name,
                                  colour_transformations,
                                  other_transformations,
                                  chns_transformation,
                                  normalize,
                                  target_size,
                                  random_labels=False):
    """Compose the training-time transformation pipeline for a dataset.

    Args:
        dataset_name: Name of the dataset; selects which pipeline is built.
        colour_transformations: Iterable of colour transforms, unpacked into
            the pipeline before ``other_transformations`` (not used for
            ``wcs_lms`` datasets).
        other_transformations: Iterable of further transforms applied before
            the horizontal flip.
        chns_transformation: Iterable of transforms applied after the tensor
            conversion and before ``normalize``.
        normalize: Final transform of every pipeline.
        target_size: Output size passed to the resize/crop transform (only
            used for cifar and ``folder_dbs`` datasets).
        random_labels: If true, a plain ``Resize`` replaces the random crop
            and the flip probability is set to -1.

    Returns:
        A ``torch_transforms.Compose`` holding the assembled pipeline.

    Exits the process through ``sys.exit`` when the dataset is not supported.
    """
    is_cifar = 'cifar' in dataset_name
    if is_cifar or dataset_name in folder_dbs:
        horizontal_flip_prob = 0.5
        if random_labels:
            spatial_op = cv2_transforms.Resize(target_size)
            # presumably a negative probability switches flipping off --
            # TODO confirm against cv2_transforms.RandomHorizontalFlip
            horizontal_flip_prob = -1
        elif is_cifar:
            spatial_op = cv2_transforms.RandomCrop(target_size, padding=4)
        else:
            # imagenet/ecoset crops may cover as little as 8% of the image,
            # every other folder dataset keeps at least half of it
            wide_range = ('imagenet' in dataset_name
                          or 'ecoset' in dataset_name)
            crop_scale = (0.08, 1.0) if wide_range else (0.50, 1.0)
            spatial_op = cv2_transforms.RandomResizedCrop(target_size,
                                                          scale=crop_scale)
        pipeline = [spatial_op]
        pipeline.extend(colour_transformations)
        pipeline.extend(other_transformations)
        pipeline.append(
            cv2_transforms.RandomHorizontalFlip(p=horizontal_flip_prob))
        pipeline.append(cv2_transforms.ToTensor())
        pipeline.extend(chns_transformation)
        pipeline.append(normalize)
        return torch_transforms.Compose(pipeline)

    if 'wcs_lms' in dataset_name:
        # FIXME: colour transformation in lms is different from rgb or lab
        pipeline = list(other_transformations)
        pipeline.append(RandomHorizontalFlip())
        pipeline.append(Numpy2Tensor())
        pipeline.extend(chns_transformation)
        pipeline.append(normalize)
        return torch_transforms.Compose(pipeline)

    if 'wcs_jpg' in dataset_name:
        pipeline = list(colour_transformations)
        pipeline.extend(other_transformations)
        pipeline.append(cv2_transforms.RandomHorizontalFlip())
        pipeline.append(cv2_transforms.ToTensor())
        pipeline.extend(chns_transformation)
        pipeline.append(normalize)
        return torch_transforms.Compose(pipeline)

    sys.exit('Transformations for dataset %s is not supported.' %
             dataset_name)
Exemplo n.º 2
0
def train_set(db, target_size, mean, std, extra_transformation=None, **kwargs):
    """Create the training dataset for ``db``.

    Args:
        db: Dataset name; natural datasets (``NATURAL_DATASETS``) and
            ``'gratings'`` are handled.
        target_size: Size passed to the crop/resize transforms or to the
            grating dataset.
        mean: Passed to ``_get_shared_post_transforms``.
        std: Passed to ``_get_shared_post_transforms``.
        extra_transformation: Optional list of transforms placed at the start
            of the shared pre-transforms.
        **kwargs: Must contain ``'train_params'``; forwarded unchanged to the
            dataset constructors.

    Returns:
        The dataset object, or ``None`` when ``db`` is not recognised.
    """
    extras = [] if extra_transformation is None else extra_transformation
    # if train params are passed don't use any random processes
    randomised = kwargs['train_params'] is None

    common_pre = [*extras]
    if randomised:
        common_pre.append(cv2_transforms.RandomHorizontalFlip())
    common_post = _get_shared_post_transforms(mean, std)

    if db in NATURAL_DATASETS:
        if randomised:
            sizing = [
                cv2_transforms.RandomResizedCrop(target_size,
                                                 scale=(0.08, 1.0)),
            ]
        else:
            sizing = [
                cv2_transforms.Resize(target_size),
                cv2_transforms.CenterCrop(target_size),
            ]
        return _natural_dataset(db, 'train', sizing + common_pre,
                                list(common_post), **kwargs)

    if db in ['gratings']:
        return _get_grating_dataset(common_pre, common_post, target_size,
                                    **kwargs)

    return None
Exemplo n.º 3
0
def get_train_dataset(train_dir, target_size, preprocess):
    """Build the training dataset with random crop and flip augmentation.

    Args:
        train_dir: Root directory handed to ``ImageFolder``.
        target_size: Output size of the random resized crop.
        preprocess: ``(mean, std)`` pair used for normalisation.

    Returns:
        The ``ImageFolder`` dataset created for ``train_dir``.
    """
    channel_mean, channel_std = preprocess
    to_normalised = cv2_transforms.Normalize(mean=channel_mean,
                                             std=channel_std)
    random_crop = cv2_transforms.RandomResizedCrop(target_size,
                                                   scale=(0.08, 1.0))
    pipeline = torch_transforms.Compose([
        random_crop,
        cv2_transforms.RandomHorizontalFlip(),
        cv2_transforms.ToTensor(),
        to_normalised,
    ])
    # NOTE(review): the settings are passed as one positional dict rather
    # than as keyword arguments -- confirm this matches the ImageFolder class
    # imported by this file.
    folder_settings = {'root': train_dir, 'transform': pipeline}
    return ImageFolder(folder_settings)
def _create_model(args, in_chns):
    """Instantiate the network selected by ``args``.

    Args:
        args: Parsed arguments; ``transfer_weights``, ``custom_arch``,
            ``pretrained`` and ``network_name`` select the model, the
            remaining custom-architecture options are forwarded to it.
        in_chns: Number of input channels for custom architectures.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: If ``args.custom_arch`` is set but ``args.network_name``
            is not a supported custom architecture.
    """
    if args.transfer_weights is not None:
        print('Transferred model!')
        return contrast_utils.AFCModel(args.network_name,
                                       args.transfer_weights)
    if args.custom_arch:
        print('Custom model!')
        supported_customs = ['resnet_basic_custom', 'resnet_bottleneck_custom']
        if args.network_name not in supported_customs:
            # Fail here with a clear message instead of leaving the model
            # undefined and crashing later on its first use.
            raise ValueError(
                "Custom architecture '{}' is not supported, choose one of "
                "{}.".format(args.network_name, supported_customs))
        return custom_models.__dict__[args.network_name](
            args.blocks,
            pooling_type=args.pooling_type,
            in_chns=in_chns,
            num_classes=args.num_classes,
            inplanes=args.num_kernels,
            kernel_size=args.kernel_size)
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.network_name))
        return models.__dict__[args.network_name](pretrained=True)
    print("=> creating model '{}'".format(args.network_name))
    return models.__dict__[args.network_name]()


def _load_model_progress(progress_path):
    """Read the per-epoch log stored at ``progress_path`` as a list of rows."""
    # ndmin=2 keeps a file holding a single epoch as one row; without it
    # loadtxt returns a 1-D array whose tolist() is a flat list of numbers,
    # and appending the next epoch's row would make the log ragged.
    return np.loadtxt(progress_path, delimiter=',', ndmin=2).tolist()


def main_worker(ngpus_per_node, args):
    """Train and validate a network on the BAPPS 2AFC dataset.

    The function creates the model, places it on the requested device(s),
    optionally resumes from ``args.resume``, builds the train/validation
    loaders and then runs the epoch loop, saving a checkpoint and the
    ``model_progress.csv`` log into ``args.out_dir`` after every epoch.

    Args:
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split the batch size and the number of workers in
            multiprocessing distributed training.
        args: Parsed command-line arguments. ``rank``, ``batch_size``,
            ``workers`` and ``initial_epoch`` may be overwritten in place.

    Raises:
        ValueError: If a custom architecture is requested with an
            unsupported ``args.network_name``.
    """
    mean, std = model_utils.get_preprocessing_function(args.colour_space,
                                                       args.vision_type)

    # preparing the output folder
    create_dir(args.out_dir)

    if args.gpus is not None:
        print("Use GPU: {} for training".format(args.gpus))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + args.gpus
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    model = _create_model(args, len(mean))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpus is not None:
            torch.cuda.set_device(args.gpus)
            model.cuda(args.gpus)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpus])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpus is not None:
        torch.cuda.set_device(args.gpus)
        model = model.cuda(args.gpus)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if (args.network_name.startswith('alexnet')
                or args.network_name.startswith('vgg')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = soft_cross_entropy

    # optimiser
    if args.transfer_weights is None:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        # with transferred weights only the parameters that still require
        # gradients are handed to the optimiser
        params_to_optimize = [
            {
                'params': [p for p in model.parameters() if p.requires_grad]
            },
        ]
        optimizer = torch.optim.SGD(params_to_optimize,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model_progress = []
    model_progress_path = os.path.join(args.out_dir, 'model_progress.csv')
    # optionally resume from a checkpoint
    # TODO: it would be best if resume load the architecture from this file
    # TODO: merge with which_architecture
    best_acc1 = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.initial_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.gpus is not None:
                # best_acc1 may be from a checkpoint from a different GPU;
                # it can also be a plain number, which has no device to move
                if torch.is_tensor(best_acc1):
                    best_acc1 = best_acc1.to(args.gpus)
                model = model.cuda(args.gpus)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if os.path.exists(model_progress_path):
                model_progress = _load_model_progress(model_progress_path)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # transformations applied to the training set only, the validation set
    # only, and to both of them
    train_trans = []
    valid_trans = []
    both_trans = []
    if args.mosaic_pattern is not None:
        mosaic_trans = preprocessing.mosaic_transformation(args.mosaic_pattern)
        both_trans.append(mosaic_trans)

    if args.num_augmentations != 0:
        augmentations = preprocessing.random_augmentation(
            args.augmentation_settings, args.num_augmentations)
        train_trans.append(augmentations)

    target_size = default_configs.get_default_target_size(
        args.dataset, args.target_size)

    final_trans = [
        cv2_transforms.ToTensor(),
        cv2_transforms.Normalize(mean, std),
    ]

    train_trans.append(
        cv2_transforms.RandomResizedCrop(target_size, scale=(0.08, 1.0)))

    # loading the training set
    train_trans = torch_transforms.Compose(
        [*both_trans, *train_trans, *final_trans])
    train_dataset = image_quality.BAPPS2afc(root=args.data_dir,
                                            split='train',
                                            transform=train_trans,
                                            concat=0.5)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffling is left to the sampler whenever one is used
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    valid_trans.extend([
        cv2_transforms.Resize(target_size),
        cv2_transforms.CenterCrop(target_size),
    ])

    # loading validation set
    valid_trans = torch_transforms.Compose(
        [*both_trans, *valid_trans, *final_trans])
    validation_dataset = image_quality.BAPPS2afc(root=args.data_dir,
                                                 split='val',
                                                 transform=valid_trans,
                                                 concat=0)

    val_loader = torch.utils.data.DataLoader(validation_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # training on epoch
    for epoch in range(args.initial_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        misc_utils.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_log = train_on_data(train_loader, model, criterion, optimizer,
                                  epoch, args)

        # evaluate on validation set
        validation_log = validate_on_data(val_loader, model, criterion, args)

        model_progress.append([*train_log, *validation_log])

        # remember best acc@1 and save checkpoint
        acc1 = validation_log[2]
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if misc_utils.is_saving_node(args.multiprocessing_distributed,
                                     args.rank, ngpus_per_node):
            misc_utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.network_name,
                    'customs': {
                        'pooling_type': args.pooling_type,
                        'in_chns': len(mean),
                        'num_classes': args.num_classes,
                        'blocks': args.blocks,
                        'num_kernels': args.num_kernels,
                        'kernel_size': args.kernel_size
                    },
                    'transfer_weights': args.transfer_weights,
                    'preprocessing': {
                        'mean': mean,
                        'std': std
                    },
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'target_size': target_size,
                },
                is_best,
                out_folder=args.out_dir)
            # TODO: get this header directly as a dictionary keys
            header = 'epoch,t_time,t_loss,t_top5,v_time,v_loss,v_top1'
            np.savetxt(model_progress_path,
                       np.array(model_progress),
                       delimiter=',',
                       header=header)