Example #1
def make_data(batch_size):
    print('Preparing data...', flush=True)

    if is_server():
        datadir = './.data/vision/imagenet'
    else:  # local settings
        datadir = '/fastwork/data/ilsvrc2012'

    # Setup the input pipeline
    _, crop = bit_hyperrule.get_resolution_from_dataset('imagenet2012')
    input_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # valid_set = tv.datasets.ImageFolder(os.path.join(datadir, 'val'), input_tx)
    valid_set = tv.datasets.ImageNet(datadir, split='val', transform=input_tx)

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=8,
                                               pin_memory=True,
                                               drop_last=False)
    return valid_set, valid_loader
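
A hypothetical call site for the helper above (a sketch; it assumes the same imports the function uses — torch, torchvision as tv, bit_hyperrule — plus the is_server() helper, and the batch size of 64 is arbitrary):

# Sketch: pull one batch from the validation loader and inspect its shape.
valid_set, valid_loader = make_data(batch_size=64)
images, labels = next(iter(valid_loader))
print(images.shape)    # torch.Size([64, 3, crop, crop]) with the imagenet2012 hyperrule crop
print(len(valid_set))  # 50000 images in the ILSVRC2012 validation split
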
def mkval(args):
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    valid_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    path = args.datadir
    validate_csv_file = pjoin(path, 'metadata', 'validate_labels.csv')

    valid_set = SnakeDataset(path,
                             is_train=False,
                             transform=valid_tx,
                             target_transform=None,
                             csv_file=validate_csv_file)

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=False)

    return valid_set, valid_loader, valid_set.classes
Example #3
def mktrainval():
    precrop, crop = bit_hyperrule.get_resolution_from_dataset("cifar10")
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = tv.datasets.ImageFolder(root=r'train', transform=train_tx)
    valid_set = tv.datasets.ImageFolder(root=r'test', transform=val_tx)

    # if args.examples_per_class is not None:
    #
    #   indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
    #   train_set = torch.utils.data.Subset(train_set, indices=indices)

    batch_size = 600

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=4,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True,
                                               drop_last=False)

    if batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set,
                                                   replacement=True,
                                                   num_samples=512))

    return train_set, valid_set, train_loader, valid_loader
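
The commented-out block in the function above refers to few-shot subsetting via fs.find_fewshot_indices. A self-contained sketch of the same idea (this mirrors the behaviour only in spirit; the helper name pick_fewshot_indices and the reliance on a targets attribute are assumptions):

import collections
import torch

def pick_fewshot_indices(dataset, examples_per_class):
    # Keep the first `examples_per_class` indices of every class label.
    labels = dataset.targets if hasattr(dataset, 'targets') else [y for _, y in dataset]
    per_class = collections.defaultdict(list)
    for idx, label in enumerate(labels):
        if len(per_class[label]) < examples_per_class:
            per_class[label].append(idx)
    return sorted(i for idxs in per_class.values() for i in idxs)

# few_shot_train = torch.utils.data.Subset(train_set, pick_fewshot_indices(train_set, 5))
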
Example #4
def mkval(args):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # NOTE: The few-shot subsetting block found in mktrainval was dropped here:
    # it operates on a train_set that this validation-only helper never builds.

    micro_batch_size = args.batch_size // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    return valid_set, valid_loader
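
A hypothetical invocation of this mkval, assuming an argparse-style namespace that carries the fields it reads (dataset, datadir, batch_size, batch_split, workers); the values below are placeholders:

from types import SimpleNamespace

args = SimpleNamespace(dataset='cifar10', datadir='./data',
                       batch_size=128, batch_split=1, workers=4)
valid_set, valid_loader = mkval(args)
print(f'{len(valid_set)} validation images in {len(valid_loader)} batches')
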
def main(args):
    tf.io.gfile.makedirs(args.logdir)
    logger = bit_common.setup_logger(args)

    logger.info(f'Available devices: {tf.config.list_physical_devices()}')

    tf.io.gfile.makedirs(args.bit_pretrained_dir)
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.h5')
    if not tf.io.gfile.exists(bit_model_file):
        model_url = models.KNOWN_MODELS[args.model]
        logger.info(f'Downloading the model from {model_url}...')
        tf.io.gfile.copy(model_url, bit_model_file)

    # Set up input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    # Distribute training
    strategy = tf.distribute.MirroredStrategy()
    num_devices = strategy.num_replicas_in_sync
    logger.info('Number of devices: {}'.format(num_devices))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)
    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=num_devices,
        tfds_manual_dir=args.tfds_manual_dir)
    data_test = input_pipeline.get_data(dataset=args.dataset,
                                        mode='test',
                                        repeats=1,
                                        batch_size=args.batch,
                                        resize_size=resize_size,
                                        crop_size=crop_size,
                                        examples_per_class=1,
                                        examples_per_class_seed=0,
                                        mixup_alpha=None,
                                        num_devices=num_devices,
                                        tfds_manual_dir=args.tfds_manual_dir)

    data_train = data_train.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))
    data_test = data_test.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))

    with strategy.scope():
        filters_factor = int(args.model[-1]) * 4
        model = models.ResnetV2(num_units=models.NUM_UNITS[args.model],
                                num_outputs=21843,
                                filters_factor=filters_factor,
                                name="resnet",
                                trainable=True,
                                dtype=tf.float32)

        model.build((None, None, None, 3))
        logger.info(f'Loading weights...')
        model.load_weights(bit_model_file)
        logger.info(f'Weights loaded into model!')

        model._head = tf.keras.layers.Dense(units=dataset_info['num_classes'],
                                            use_bias=True,
                                            kernel_initializer="zeros",
                                            trainable=True,
                                            name="head/dense")

        lr_supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])

        schedule_length = lr_supports[-1]
        # NOTE: Let's not do that unless verified necessary and we do the same
        # across all three codebases.
        # schedule_length = schedule_length * 512 / args.batch

        optimizer = tf.keras.optimizers.SGD(momentum=0.9)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

    logger.info(f'Fine-tuning the model...')
    steps_per_epoch = args.eval_every or schedule_length
    history = model.fit(
        data_train,
        steps_per_epoch=steps_per_epoch,
        epochs=schedule_length // steps_per_epoch,
        # data_test is used here only to evaluate performance during training.
        validation_data=data_test,
        callbacks=[BiTLRSched(args.base_lr, dataset_info['num_examples'])],
    )

    for epoch, accu in enumerate(history.history['val_accuracy']):
        logger.info(f'Step: {epoch * args.eval_every}, '
                    f'Test accuracy: {accu:0.3f}')
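
reshape_for_keras is used above but not shown in this snippet. A sketch of what such a helper could look like, assuming the input pipeline yields {'image': ..., 'label': ...} dictionaries that Keras' model.fit wants as flat (images, labels) tuples:

import tensorflow as tf

def reshape_for_keras(features, batch_size, crop_size):
    # Collapse any leading device dimension into a plain (batch, H, W, C) tensor
    # and pair it with the (one-hot) labels, the format model.fit expects.
    images = tf.reshape(features['image'], (batch_size, crop_size, crop_size, 3))
    labels = tf.reshape(features['label'], (batch_size, -1))
    return images, labels
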
Example #6
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=train_tx,
                                        train=True,
                                        download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=train_tx,
                                         train=True,
                                         download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"),
                                            train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
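
micro_batch_size = args.batch // args.batch_split only pays off if the training loop (not shown here) accumulates gradients over args.batch_split micro-batches before each optimizer step, so the effective batch size stays at args.batch. A minimal sketch of that pattern; model, optimizer, criterion and train_loader are assumed to exist:

# Gradient accumulation over `args.batch_split` micro-batches (sketch; names are assumed).
accumulated = 0
optimizer.zero_grad()
for x, y in train_loader:                       # each iteration yields one micro-batch
    loss = criterion(model(x), y) / args.batch_split
    loss.backward()                             # gradients add up across micro-batches
    accumulated += 1
    if accumulated == args.batch_split:         # one full logical batch accumulated
        optimizer.step()
        optimizer.zero_grad()
        accumulated = 0
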
Example #7
def run():
    aicrowd_helpers.execution_start()

    #MAGIC HAPPENS BELOW
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0")
    assert torch.cuda.is_available()
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(
        'snakes_dataset')  # verify
    valid_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    given_df = pd.read_csv(AICROWD_TEST_METADATA_PATH)

    valid_set = SnakeDataset(AICROWD_TEST_IMAGES_PATH,
                             is_train=False,
                             transform=valid_tx,
                             target_transform=None,
                             csv_file=AICROWD_TEST_METADATA_PATH)  # verify
    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=32,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=False)
    model = models.KNOWN_MODELS['BiT-M-R50x1'](
        head_size=len(VALID_SNAKE_SPECIES), zero_head=True)
    model = torch.nn.DataParallel(model)
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
    model_loc = pjoin('models', 'initial.pth.tar')
    checkpoint = torch.load(model_loc, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    model.eval()
    results = np.empty((0, 783), float)
    for b, (x, y) in enumerate(
            valid_loader):  #add name to dataset, y must be some random label
        with torch.no_grad():
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            softmax_op = torch.nn.Softmax(dim=1)
            probs = softmax_op(logits)
            data_to_save = probs.data.cpu().numpy()
            results = np.concatenate((results, data_to_save), axis=0)
    filenames = given_df['hashed_id'].tolist()
    country_prob = pd.read_csv(
        pjoin('metadata', 'probability_of_species_per_country.csv'))
    country_name = country_prob[['Species/Country']]
    country_dict = {name[0]: i for i, name in enumerate(country_name.values)}
    given_country = given_df[['country']]
    country_list = []
    for country in given_country.values:
        country_list.append(str(country[0]).lower().replace(
            ' ', '-'))  # has to be a better way
    adjusted_results = []

    for i, result in enumerate(results):
        probs = result
        assert len(probs) == 783
        try:
            country_now = country_list[i]
            country_location = country_dict[country_now]
            country_prob_per_this_country = country_prob.loc[[
                country_location
            ]].values[0][1:]
            adjusted = country_prob_per_this_country * probs
            adjusted_results.append(adjusted)  # verify, we need list of list
        except KeyError:  # country missing from the prior table; keep the raw probabilities
            adjusted_results.append(probs)
    assert len(adjusted_results) == len(results)
    # Normalize each row so the adjusted scores again sum to 1.
    adjusted_results = np.array(adjusted_results)
    normalized_results = adjusted_results / adjusted_results.sum(axis=1)[:, None]

    df = pd.DataFrame(data=normalized_results,
                      index=filenames,
                      columns=VALID_SNAKE_SPECIES)
    df.index.name = 'hashed_id'
    df.to_csv(AICROWD_PREDICTIONS_OUTPUT_PATH, index=True)

    aicrowd_helpers.execution_success(
        {"predictions_output_path": AICROWD_PREDICTIONS_OUTPUT_PATH})
Example #8
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=train_tx,
                                        train=True,
                                        download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=train_tx,
                                         train=True,
                                         download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"),
                                            train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    # TODO: Define custom dataloading logic here for custom datasets
    elif args.dataset == "logo_2k":
        train_set = GetLoader(data_root='logo2k/Logo-2K+',
                              data_list='logo2k/train.txt',
                              label_dict='logo2k/logo2k_labeldict.pkl',
                              transform=train_tx)

        valid_set = GetLoader(data_root='logo2k/Logo-2K+',
                              data_list='logo2k/test.txt',
                              label_dict='logo2k/logo2k_labeldict.pkl',
                              transform=val_tx)

    elif args.dataset == "targetlist":
        train_set = GetLoader(data_root='../../phishpedia/expand_targetlist',
                              data_list='../train_targets.txt',
                              label_dict='../target_dict.json',
                              transform=train_tx)

        valid_set = GetLoader(data_root='../../phishpedia/expand_targetlist',
                              data_list='../test_targets.txt',
                              label_dict='../target_dict.json',
                              transform=val_tx)

    logger.info("Using a training set with {} images.".format(len(train_set)))
    logger.info("Using a validation set with {} images.".format(
        len(valid_set)))
    logger.info("Num of classes: {}".format(len(valid_set.classes)))

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
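
GetLoader is the custom dataset behind the logo_2k and targetlist branches and is not shown in this snippet. A minimal sketch of a dataset with the same constructor signature (data_root, data_list, label_dict, transform), under the assumptions that the list file holds one 'relative/path class_name' pair per line and the label dict is a pickle mapping class names to integer ids (the targetlist branch apparently passes a JSON label dict, which this sketch ignores):

import os
import pickle
from PIL import Image
import torch

class GetLoader(torch.utils.data.Dataset):
    """Sketch only; the real GetLoader may use different file formats."""

    def __init__(self, data_root, data_list, label_dict, transform=None):
        self.data_root, self.transform = data_root, transform
        with open(label_dict, 'rb') as f:
            self.label_dict = pickle.load(f)   # assumed: {class_name: int}
        with open(data_list) as f:
            self.samples = [line.strip().rsplit(' ', 1) for line in f if line.strip()]
        self.classes = sorted(self.label_dict)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        rel_path, class_name = self.samples[idx]
        img = Image.open(os.path.join(self.data_root, rel_path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label_dict[class_name]
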
def get_data_loader(args):

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    torch.backends.cudnn.benchmark = True

    if args.dataset == "imagenet":

        train_dir = os.path.join(args.data_path, 'train')
        val_dir = os.path.join(args.data_path, 'val')
        dataset, dataset_test, train_sampler, test_sampler = load_data(
            train_dir, val_dir, args.cache_dataset, args.distributed)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=args.batch_size,
                                                  sampler=train_sampler,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=args.batch_size,
            sampler=test_sampler,
            num_workers=args.workers,
            pin_memory=True)

    elif args.dataset == "cifar10":
        if args.model != "big_transfer":
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2023, 0.1994, 0.2010]

            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
            dataset = CIFAR10(root=args.data_path,
                              train=True,
                              transform=transform_train)
            data_loader = DataLoader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)

            transform_val = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(mean, std)])
            dataset = CIFAR10(root=args.data_path,
                              train=False,
                              transform=transform_val)
            data_loader_test = DataLoader(dataset,
                                          batch_size=args.batch_size,
                                          num_workers=4,
                                          pin_memory=True)
        else:
            precrop, crop = bit_hyperrule.get_resolution_from_dataset(
                args.dataset)
            train_tx = transforms.Compose([
                transforms.Resize((precrop, precrop)),
                transforms.RandomCrop((crop, crop)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            val_tx = transforms.Compose([
                transforms.Resize((crop, crop)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            dataset = CIFAR10(root=args.data_path,
                              train=True,
                              transform=train_tx)
            data_loader = DataLoader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)
            dataset = CIFAR10(root=args.data_path,
                              train=False,
                              transform=val_tx)
            data_loader_test = DataLoader(dataset,
                                          batch_size=args.batch_size,
                                          num_workers=4,
                                          pin_memory=True)
    return data_loader, data_loader_test
Example #10
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.RandomRotation(90),
        tv.transforms.ColorJitter(),
        tv.transforms.RandomAffine(0, scale=(1.0, 2.0), shear=20),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    path = args.datadir

    train_csv_file = pjoin(path, 'metadata', 'train_labels.csv')
    validate_csv_file = pjoin(path, 'metadata', 'validate_labels.csv')

    train_set = SnakeDataset(path,
                             is_train=True,
                             transform=train_tx,
                             target_transform=None,
                             csv_file=train_csv_file)
    valid_set = SnakeDataset(path,
                             is_train=False,
                             transform=val_tx,
                             target_transform=None,
                             csv_file=validate_csv_file)
    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
Example #11
def main(args):
    logger = bit_common.setup_logger(args)

    logger.info(f'Available devices: {jax.devices()}')

    model = models.KNOWN_MODELS[args.model]

    # Load weights of a BiT model
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.npz')
    if not os.path.exists(bit_model_file):
        raise FileNotFoundError(
            f'Model file is not found in "{args.bit_pretrained_dir}" directory.'
        )
    with open(bit_model_file, 'rb') as f:
        params_tf = np.load(f)
        params_tf = dict(zip(params_tf.keys(), params_tf.values()))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)

    # Setup input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_train)
    data_test = input_pipeline.get_data(dataset=args.dataset,
                                        mode='test',
                                        repeats=1,
                                        batch_size=args.batch_eval,
                                        resize_size=resize_size,
                                        crop_size=crop_size,
                                        examples_per_class=None,
                                        examples_per_class_seed=0,
                                        mixup_alpha=None,
                                        num_devices=jax.local_device_count(),
                                        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_test)

    # Build ResNet architecture
    ResNet = model.partial(num_classes=dataset_info['num_classes'])
    _, params = ResNet.init_by_shape(
        jax.random.PRNGKey(0), [([1, crop_size, crop_size, 3], jnp.float32)])
    resnet_fn = ResNet.call

    # pmap replicates the models over all GPUs
    resnet_fn_repl = jax.pmap(ResNet.call)

    def cross_entropy_loss(*, logits, labels):
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
        logits = resnet_fn(params, images)
        return cross_entropy_loss(logits=logits, labels=labels)

    # Update step, replicated over all GPUs
    @partial(jax.pmap, axis_name='batch')
    def update_fn(opt, lr, batch):
        l, g = jax.value_and_grad(loss_fn)(opt.target, batch['image'],
                                           batch['label'])
        g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        return opt

    # In-place update of the randomly initialized weights with the BiT weights
    tf2jax.transform_params(params,
                            params_tf,
                            num_classes=dataset_info['num_classes'])

    # Create optimizer and replicate it over all GPUs
    opt = optim.Momentum(beta=0.9).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to objects that are no longer needed
    del opt
    del params

    total_steps = bit_hyperrule.get_schedule(dataset_info['num_examples'])[-1]

    # Run training loop
    for step, batch in zip(range(1, total_steps + 1),
                           data_train.as_numpy_iterator()):
        lr = bit_hyperrule.get_lr(step - 1, dataset_info['num_examples'],
                                  args.base_lr)
        opt_repl = update_fn(opt_repl, flax_utils.replicate(lr), batch)

        # Run eval step
        if ((args.eval_every and step % args.eval_every == 0)
                or (step == total_steps)):

            accuracy_test = np.mean([
                c for batch in data_test.as_numpy_iterator()
                for c in (np.argmax(
                    resnet_fn_repl(opt_repl.target, batch['image']), axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])

            logger.info(f'Step: {step}, '
                        f'learning rate: {lr:.07f}, '
                        f'Test accuracy: {accuracy_test:0.3f}')
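
bit_hyperrule.get_lr(step, num_examples, base_lr) follows the BiT fine-tuning rule: a linear warmup up to the first schedule support, then the learning rate is divided by 10 at each later support, and training stops at the last one. A sketch of a schedule with that shape (the concrete boundaries come from bit_hyperrule.get_schedule and are not reproduced here):

def staircase_lr(step, supports, base_lr=0.003):
    # Linear warmup up to the first support.
    if step < supports[0]:
        return base_lr * step / supports[0]
    # Past the last support the schedule is over.
    if step >= supports[-1]:
        return None
    # Divide by 10 at every support boundary already passed.
    for boundary in supports[1:]:
        if boundary < step:
            base_lr /= 10
    return base_lr

# e.g. supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])
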
Example #12
def select_worst_images(args, model, full_train_loader, device):
    print("Selecting images for next epoch training...")
    model.eval()

    gts = []
    paths = []
    losses = []

    micro_batch_size = args.batch_size // args.batch_split
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    if args.input_channels == 3:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    elif args.input_channels == 2:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5), (0.5, 0.5)),
        ])

    elif args.input_channels == 1:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5,), (0.5,)),
        ])

    pbar = enumerate(full_train_loader)
    pbar = tqdm.tqdm(pbar, total=len(full_train_loader))

    for b, (path, x, y) in pbar:
        with torch.no_grad():
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # compute output, measure accuracy and record loss.
            logits = model(x)

            paths.extend(path)
            gts.extend(y.cpu().numpy())

            c = torch.nn.CrossEntropyLoss(reduction='none')(logits, y)

            losses.extend(
                c.cpu().numpy().tolist())  # Also ensures a sync point.

        # measure elapsed time
        end = time.time()

    gts = np.array(gts)
    losses = np.array(losses)
    # Zero out the highest-loss fraction (args.noise), treated as label noise,
    # so those images cannot be selected below.
    losses[np.argsort(losses)[int(losses.shape[0] * (1.0 - args.noise)):]] = 0.0

    #paths_ = np.array(paths)[np.where(losses > np.median(losses))[0]]
    #gts_   = gts[np.where(losses > np.median(losses))[0]]

    selection_idx = int(args.data_fraction * losses.shape[0])
    paths_ = np.array(paths)[np.argsort(losses)[-selection_idx:]]
    gts_ = gts[np.argsort(losses)[-selection_idx:]]

    smart_train_set = ImageFolder(paths_, gts_, train_tx, crop)

    smart_train_loader = torch.utils.data.DataLoader(
        smart_train_set,
        batch_size=micro_batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)

    return smart_train_set, smart_train_loader
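
The selection above first zeroes the highest-loss tail (treated as label noise) and then keeps the data_fraction of images with the largest remaining loss. The same index arithmetic on a toy array (np.argsort is ascending, so the last selection_idx positions are the hardest surviving examples):

import numpy as np

losses = np.array([0.2, 1.5, 0.9, 3.0, 0.1, 0.7, 2.6, 0.4])
noise, data_fraction = 0.25, 0.5                                      # toy settings

losses[np.argsort(losses)[int(len(losses) * (1.0 - noise)):]] = 0.0   # zero the two largest losses
selection_idx = int(data_fraction * len(losses))                      # keep 4 images
selected = np.argsort(losses)[-selection_idx:]
print(selected)   # [7 5 2 1] -> losses 0.4, 0.7, 0.9, 1.5, the hardest non-flagged images
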
Example #13
def _mktrainval(args, logger):
  """Returns train and validation datasets."""
  precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
  if args.test_run: # save memory
    precrop, crop = 64, 56

  train_tx = tv.transforms.Compose([
      tv.transforms.Resize((precrop, precrop)),
      tv.transforms.RandomCrop((crop, crop)),
      tv.transforms.RandomHorizontalFlip(),
      tv.transforms.ToTensor(),
      tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])
  val_tx = tv.transforms.Compose([
      tv.transforms.Resize((crop, crop)),
      tv.transforms.ToTensor(),
      tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  collate_fn = None
  n_train = None
  micro_batch_size = args.batch // args.batch_split
  if args.dataset == "cifar10":
    train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx, train=True, download=True)
    valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx, train=False, download=True)
  elif args.dataset == "cifar100":
    train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx, train=True, download=True)
    valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx, train=False, download=True)
  elif args.dataset == "imagenet2012":
    train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), transform=train_tx)
    valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), transform=val_tx)
  elif args.dataset.startswith('objectnet') or args.dataset.startswith('imageneta'): # objectnet and objectnet_bbox and objectnet_no_bbox
    identifier = 'objectnet' if args.dataset.startswith('objectnet') else 'imageneta'
    valid_set = tv.datasets.ImageFolder(f"../datasets/{identifier}/", transform=val_tx)

    if args.inpaint == 'none':
      if args.dataset == 'objectnet' or args.dataset == 'imageneta':
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, f"train_{args.dataset}"),
                                            transform=train_tx)
      else: # For only images with or w/o bounding box
        train_bbox_file = '../datasets/imagenet/LOC_train_solution_size.csv'
        df = pd.read_csv(train_bbox_file)
        filenames = set(df[df.bbox_ratio <= args.bbox_max_ratio].ImageId)
        if args.dataset == f"{identifier}_no_bbox":
          is_valid_file = lambda path: os.path.basename(path).split('.')[0] not in filenames
        elif args.dataset == f"{identifier}_bbox":
          is_valid_file = lambda path: os.path.basename(path).split('.')[0] in filenames
        else:
          raise NotImplementedError()

        train_set = tv.datasets.ImageFolder(
          pjoin(args.datadir, f"train_{identifier}"),
          is_valid_file=is_valid_file,
          transform=train_tx)
    else: # do inpainting
      train_tx = tv.transforms.Compose([
        data_utils.Resize((precrop, precrop)),
        data_utils.RandomCrop((crop, crop)),
        data_utils.RandomHorizontalFlip(),
        data_utils.ToTensor(),
        data_utils.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ])

      train_set = ImagenetBoundingBoxFolder(
        root=f"../datasets/imagenet/train_{identifier}",
        bbox_file='../datasets/imagenet/LOC_train_solution.csv',
        transform=train_tx)
      collate_fn = bbox_collate
      n_train = len(train_set) * 2
      micro_batch_size //= 2

  else:
    raise ValueError(f"Sorry, we have not spent time implementing the "
                     f"{args.dataset} dataset in the PyTorch codebase. "
                     f"In principle, it should be easy to add :)")

  if args.examples_per_class is not None:
    logger.info(f"Looking for {args.examples_per_class} images per class...")
    indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
    train_set = torch.utils.data.Subset(train_set, indices=indices)

  logger.info(f"Using a training set with {len(train_set)} images.")
  logger.info(f"Using a validation set with {len(valid_set)} images.")

  valid_loader = torch.utils.data.DataLoader(
      valid_set, batch_size=micro_batch_size, shuffle=False,
      num_workers=args.workers, pin_memory=True, drop_last=False)

  if micro_batch_size <= len(train_set):
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=micro_batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=False,
        collate_fn=collate_fn)
  else:
    # In the few-shot cases, the total dataset size might be smaller than the batch-size.
    # In these cases, the default sampler doesn't repeat, so we need to make it do that
    # if we want to match the behaviour from the paper.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=micro_batch_size, num_workers=args.workers, pin_memory=True,
        sampler=torch.utils.data.RandomSampler(train_set, replacement=True, num_samples=micro_batch_size),
        collate_fn=collate_fn)

  if n_train is None:
    n_train = len(train_set)
  return n_train, len(valid_set.classes), train_loader, valid_loader
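
A hypothetical call, assuming an argparse-style namespace that carries the fields this helper reads for the cifar10 branch; the paths and values are placeholders:

import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('bit')

args = SimpleNamespace(dataset='cifar10', datadir='./data', test_run=True,
                       batch=128, batch_split=2, examples_per_class=None, workers=4)
n_train, n_classes, train_loader, valid_loader = _mktrainval(args, logger)
print(n_train, 'training images,', n_classes, 'classes')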