Exemple #1
0
def main_worker(ngpus_per_node, args):
    """Build the model, optimiser and data loaders, then run training.

    The function performs, in order: optional distributed initialisation,
    model creation (transferred / custom / pretrained / fresh), output
    directory preparation (``args.json`` is written there), device
    placement, optimiser creation, optional resume from a checkpoint,
    dataset/loader creation and finally the epoch loop, which saves a
    checkpoint and the ``model_progress.csv`` log on the saving node.

    Args:
        ngpus_per_node: Number of GPUs on this node. Used to derive the
            global rank and to split ``batch_size`` / ``workers`` when one
            process drives one GPU.
        args: Parsed command-line namespace. It is mutated in place
            (``rank``, ``network_name``, ``out_dir``, ``batch_size``,
            ``workers`` and ``initial_epoch`` may be overwritten).

    Raises:
        ValueError: If ``args.custom_arch`` is set but ``args.network_name``
            is neither an existing checkpoint file nor a supported custom
            architecture name.
    """
    mean, std = model_utils.get_preprocessing_function(args.colour_space,
                                                       args.vision_type)

    if args.gpus is not None:
        print("Use GPU: {} for training".format(args.gpus))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + args.gpus
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.transfer_weights is not None:
        print('Transferred model!')
        (model, _) = model_utils.which_network(args.transfer_weights[0],
                                               args.task_type,
                                               num_classes=args.old_classes)
        # optional second entry selects the layer to transfer from
        which_layer = -1
        if len(args.transfer_weights) == 2:
            which_layer = args.transfer_weights[1]
        model = model_utils.NewClassificationModel(model, which_layer,
                                                   args.num_classes)
    elif args.custom_arch:
        print('Custom model!')
        supported_customs = ['resnet_basic_custom', 'resnet_bottleneck_custom']
        if os.path.isfile(args.network_name):
            # network_name is a checkpoint path: rebuild the architecture
            # from the metadata stored alongside the weights
            checkpoint = torch.load(args.network_name, map_location='cpu')
            customs = None
            if 'customs' in checkpoint:
                customs = checkpoint['customs']
                # TODO: num_classes is just for backward compatibility
                if 'num_classes' not in customs:
                    customs['num_classes'] = 1000
            model = which_architecture(checkpoint['arch'], customs,
                                       args.contrast_head)
            args.network_name = checkpoint['arch']

            model.load_state_dict(checkpoint['state_dict'], strict=False)
        elif args.network_name in supported_customs:
            model = custom_models.__dict__[args.network_name](
                args.blocks,
                contrast_head=args.contrast_head,
                pooling_type=args.pooling_type,
                in_chns=len(mean),
                num_classes=args.num_classes,
                inplanes=args.num_kernels,
                kernel_size=args.kernel_size)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # ``model`` further down.
            raise ValueError(
                "Custom architecture '{}' is neither an existing checkpoint "
                "file nor one of {}".format(args.network_name,
                                            supported_customs))
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.network_name))
        model = models.__dict__[args.network_name](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.network_name))
        model = models.__dict__[args.network_name]()

    # TODO: why load weights is False?
    args.out_dir = prepare_training.prepare_output_directories(
        dataset_name='contrast',
        network_name=args.network_name,
        optimiser='sgd',
        load_weights=False,
        experiment_name=args.experiment_name,
        framework='pytorch')
    # preparing the output folder and storing the run configuration
    create_dir(args.out_dir)
    json_file_name = os.path.join(args.out_dir, 'args.json')
    with open(json_file_name, 'w') as fp:
        json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpus is not None:
            torch.cuda.set_device(args.gpus)
            model.cuda(args.gpus)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpus])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpus is not None:
        torch.cuda.set_device(args.gpus)
        model = model.cuda(args.gpus)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.network_name.startswith(('alexnet', 'vgg')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpus)

    # optimiser
    if args.transfer_weights is None:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        # The model may have been wrapped above; the parallel wrappers do not
        # forward attribute access, so reach through ``.module`` to get the
        # ``features`` / ``fc`` sub-modules. The parameters are the same
        # objects either way.
        parallel_wrappers = (torch.nn.DataParallel,
                             torch.nn.parallel.DistributedDataParallel)
        if isinstance(model, parallel_wrappers):
            bare_model = model.module
        else:
            bare_model = model
        # Transferred features are fine-tuned with a tiny learning rate,
        # while the new classifier uses the default ``args.lr``.
        params_to_optimize = [
            {
                'params': list(bare_model.features.parameters()),
                'lr': 1e-6
            },
            {
                'params': list(bare_model.fc.parameters())
            },
        ]
        optimizer = torch.optim.SGD(params_to_optimize,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model_progress = []
    model_progress_path = os.path.join(args.out_dir, 'model_progress.csv')
    # optionally resume from a checkpoint
    # TODO: it would be best if resume load the architecture from this file
    # TODO: merge with which_architecture
    best_acc1 = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.initial_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.gpus is not None:
                # best_acc1 may be from a checkpoint from a different GPU;
                # it can also be a plain number, which has no ``.to``.
                if torch.is_tensor(best_acc1):
                    best_acc1 = best_acc1.to(args.gpus)
                model = model.cuda(args.gpus)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if os.path.exists(model_progress_path):
                # ndmin=2 keeps a single-row log as a list of rows, so the
                # rows appended below stay consistent with it.
                model_progress = np.loadtxt(model_progress_path,
                                            delimiter=',',
                                            ndmin=2)
                model_progress = model_progress.tolist()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # transformations: ``both_trans`` are shared by training and validation
    train_trans = []
    valid_trans = []
    both_trans = []
    if args.mosaic_pattern is not None:
        mosaic_trans = preprocessing.mosaic_transformation(args.mosaic_pattern)
        both_trans.append(mosaic_trans)

    if args.num_augmentations != 0:
        augmentations = preprocessing.random_augmentation(
            args.augmentation_settings, args.num_augmentations)
        train_trans.append(augmentations)

    target_size = default_configs.get_default_target_size(
        args.dataset, args.target_size)

    # loading the training set
    train_trans = [*both_trans, *train_trans]
    db_params = {
        'colour_space': args.colour_space,
        'vision_type': args.vision_type,
        'mask_image': args.mask_image
    }
    # these datasets are read from a directory; the others are described by
    # ``args.train_samples``
    if args.dataset in ['imagenet', 'celeba', 'natural']:
        path_or_sample = args.data_dir
    else:
        path_or_sample = args.train_samples
    train_dataset = dataloader.train_set(args.dataset,
                                         target_size,
                                         mean,
                                         std,
                                         extra_transformation=train_trans,
                                         data_dir=path_or_sample,
                                         **db_params)
    if args.dataset == 'natural':
        # the requested batch size becomes the dataset's ``num_crops`` and
        # the loader then fetches one item at a time
        train_dataset.num_crops = args.batch_size
        args.batch_size = 1

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # loading validation set
    valid_trans = [*both_trans, *valid_trans]
    validation_dataset = dataloader.validation_set(
        args.dataset,
        target_size,
        mean,
        std,
        extra_transformation=valid_trans,
        data_dir=path_or_sample,
        **db_params)
    if args.dataset == 'natural':
        validation_dataset.num_crops = train_dataset.num_crops
        args.batch_size = 1

    val_loader = torch.utils.data.DataLoader(validation_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # training on epoch
    for epoch in range(args.initial_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        misc_utils.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_log = train_on_data(train_loader, model, criterion, optimizer,
                                  epoch, args)

        # evaluate on validation set
        validation_log = validate_on_data(val_loader, model, criterion, args)

        model_progress.append([*train_log, *validation_log])

        # remember best acc@1 and save checkpoint
        acc1 = validation_log[2]
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if misc_utils.is_saving_node(args.multiprocessing_distributed,
                                     args.rank, ngpus_per_node):
            misc_utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.network_name,
                    'customs': {
                        'pooling_type': args.pooling_type,
                        'in_chns': len(mean),
                        'num_classes': args.num_classes,
                        'blocks': args.blocks,
                        'num_kernels': args.num_kernels,
                        'kernel_size': args.kernel_size
                    },
                    'preprocessing': {
                        'mean': mean,
                        'std': std
                    },
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'target_size': target_size,
                },
                is_best,
                out_folder=args.out_dir)
            # TODO: get this header directly as a dictionary keys
            header = 'epoch,t_time,t_loss,t_top1,t_top5,v_time,v_loss,v_top1,v_top5'
            np.savetxt(model_progress_path,
                       np.array(model_progress),
                       delimiter=',',
                       header=header)
def main(args):
    """Search, per spatial frequency, for a contrast level on grating stimuli.

    The function parses ``args``, builds the list of spatial frequencies to
    test, loads the model onto the GPU in evaluation mode and, when
    ``args.db == 'gratings'``, iterates over the frequencies. For each
    frequency it repeatedly builds a validation set at the current contrast,
    runs the model on it and asks ``sensitivity_sf`` for the next contrast
    and search bounds, until a stopping condition is reached.

    Args:
        args: Raw command-line arguments; they are handed to
            ``parse_arguments`` and the parsed namespace is used afterwards.
    """
    args = parse_arguments(args)
    # NOTE(review): machine-specific fallback path; only used when the
    # option is not provided.
    if args.imagenet_dir is None:
        args.imagenet_dir = '/home/arash/Software/imagenet/raw-data/validation/'
    colour_space = args.colour_space
    target_size = args.target_size
    # with ``repeat`` the stimuli are generated at half the target size
    if args.repeat:
        target_size = int(args.target_size / 2)

    mean, std = model_utils.get_preprocessing_function(colour_space,
                                                       'trichromat')
    # optional transformations applied on top of the dataset's own ones
    extra_transformations = []
    if args.noise is not None:
        # only the second entry of ``args.noise`` is read here (the amount)
        noise_kwargs = {'amount': float(args.noise[1])}
        extra_transformations.append(
            cv2_preprocessing.UniqueTransformation(imutils.gaussian_noise,
                                                   **noise_kwargs))
    if args.mosaic_pattern is not None:
        mosaic_trans = cv2_preprocessing.MosaicTransformation(
            args.mosaic_pattern)
        extra_transformations.append(mosaic_trans)

    # testing setting
    freqs = args.freqs
    if freqs is None:
        if args.model_fest:
            # fixed list of ten values; this mode also overrides ``args.gabor``
            test_sfs = [
                107.558006181068, 60.2277887428434, 42.5666842683747,
                30.1176548488805, 21.2841900445655, 15.0589946730047,
                10.6611094334144, 7.5293662203794, 5.33055586478039,
                4.01568430264343
            ]
            args.gabor = 'model_fest'
        else:
            # extra divisors appended to the default list below
            if target_size == 256:
                t4s = [
                    target_size / 2,
                    target_size / 2.5,
                    target_size / 3,
                    target_size / 3.5,
                    target_size / 3.75,
                    target_size / 4,
                ]
            else:
                # assuming 128
                t4s = [64]

            # default values: ``sf_base`` divided by 1..20, then 21..56 in
            # steps of 5, then from 61 up to (excluding) the last ``t4s``
            # entry in steps of 25, then the ``t4s`` entries themselves
            sf_base = ((target_size / 2) / np.pi)
            test_sfs = [
                sf_base / e for e in [
                    *np.arange(1, 21), *np.arange(21, 61, 5),
                    *np.arange(61, t4s[-1], 25), *t4s
                ]
            ]
    else:
        # three values are read as (start, stop, number of samples);
        # any other length is used verbatim
        if len(freqs) == 3:
            test_sfs = np.linspace(freqs[0], freqs[1], int(freqs[2]))
        else:
            test_sfs = freqs
    # so the sfs gets sorted (np.unique also drops duplicated values)
    test_sfs = np.unique(test_sfs)
    contrasts = args.contrasts
    # NOTE(review): ``test_contrasts`` is not read anywhere in the visible
    # part of this function.
    if contrasts is None:
        test_contrasts = [0.5]
    else:
        test_contrasts = contrasts
    # stimulus parameters combined with every tested contrast / frequency
    test_thetas = np.linspace(0, np.pi, 7)
    test_rhos = np.linspace(0, np.pi, 4)
    test_ps = [0.0, 1.0]

    if args.pretrained:
        # both wrappers receive ``grey_width`` as a boolean (True only when
        # ``args.grey_width`` equals 40); the default scale factor is
        # relative to 256 (side by side) or 128 (separate)
        if args.side_by_side:
            if args.scale_factor is None:
                scale_factor = (args.target_size / 256)**2
            else:
                scale_factor = args.scale_factor
            model = pretrained_models.NewClassificationModel(
                args.model_path,
                grey_width=args.grey_width == 40,
                scale_factor=scale_factor)
        else:
            if args.scale_factor is None:
                scale_factor = (args.target_size / 128)**2
            else:
                scale_factor = args.scale_factor
            model = models_csf.ContrastDiscrimination(
                args.model_path,
                grey_width=args.grey_width == 40,
                scale_factor=scale_factor)
    else:
        model, _ = model_utils.which_network_classification(args.model_path, 2)
    model.eval()
    model.cuda()

    # the preprocessing statistics are only forwarded when visualising
    mean_std = None
    if args.visualise:
        mean_std = (mean, std)

    # upper bound of the contrast search: 1 - 2 * |avg_illuminant|
    if args.avg_illuminant < 0:
        max_high = 1 + 2 * args.avg_illuminant
    elif args.avg_illuminant > 0:
        max_high = 1 + -2 * args.avg_illuminant
    else:
        max_high = 1.0
    mid_contrast = (0 + max_high) / 2

    all_results = None
    # one entry per spatial frequency holding the contrast to test next;
    # an entry is set to None once the search for that frequency is over
    csf_flags = [mid_contrast for _ in test_sfs]

    if args.db == 'gratings':
        for i in range(len(csf_flags)):
            # search bounds, updated by ``sensitivity_sf`` on every iteration
            low = 0
            high = max_high
            j = 0
            while csf_flags[i] is not None:
                print('%.2d %.3d Doing %f - %f %f %f' %
                      (i, j, test_sfs[i], csf_flags[i], low, high))

                # a single contrast and frequency, combined with all
                # thetas, rhos and sides
                test_samples = {
                    'amp': [csf_flags[i]],
                    'lambda_wave': [test_sfs[i]],
                    'theta': test_thetas,
                    'rho': test_rhos,
                    'side': test_ps,
                    'avg_illuminant': args.avg_illuminant
                }
                db_params = {
                    'colour_space': colour_space,
                    'vision_type': args.vision_type,
                    'repeat': args.repeat,
                    'mask_image': args.gabor,
                    'grey_width': args.grey_width,
                    'side_by_side': args.side_by_side
                }

                db = dataloader.validation_set(args.db,
                                               target_size,
                                               mean,
                                               std,
                                               extra_transformations,
                                               data_dir=test_samples,
                                               **db_params)
                db.contrast_space = args.contrast_space

                db_loader = torch.utils.data.DataLoader(
                    db,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                )

                # ``all_results`` is threaded through the calls so results
                # accumulate across iterations and frequencies
                if args.side_by_side:
                    new_results, all_results = run_gratings(
                        db_loader,
                        model,
                        args.out_file,
                        args.print,
                        mean_std=mean_std,
                        old_results=all_results)
                else:
                    new_results, all_results = run_gratings_separate(
                        db_loader,
                        model,
                        args.out_file,
                        args.print,
                        mean_std=mean_std,
                        old_results=all_results)
                # presumably a bisection-like update around the 0.75
                # threshold - confirm against sensitivity_sf
                new_contrast, low, high = sensitivity_sf(new_results,
                                                         test_sfs[i],
                                                         varname='all',
                                                         th=0.75,
                                                         low=low,
                                                         high=high)
                # stop when the tested contrast was already at the upper
                # bound, when the proposed contrast did not change, or
                # after 21 evaluations (j runs from 0 to 20)
                if (abs(csf_flags[i] - max_high) < 1e-3
                        or new_contrast == csf_flags[i] or j == 20):
                    print('had to skip', csf_flags[i])
                    csf_flags[i] = None
                else:
                    csf_flags[i] = new_contrast
                j += 1