# Assumed imports for this snippet: hydra/omegaconf for config handling plus
# timm's auto-augment helpers (in newer timm releases _pil_interp has moved).
from hydra.utils import instantiate
from omegaconf import DictConfig
from timm.data.auto_augment import (augment_and_mix_transform,
                                    auto_augment_transform,
                                    rand_augment_transform)
from timm.data.transforms import _pil_interp


def instantiate_transforms(cfg: DictConfig, global_config: DictConfig = None):
    """Loads an individual transformation from its config."""
    if cfg._target_ == "aa":
        img_size_min = global_config.input.input_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple(
                [min(255, round(255 * x)) for x in global_config.input.mean]),
        )

        if (global_config.input.interpolation
                and global_config.input.interpolation != "random"):
            aa_params["interpolation"] = _pil_interp(
                global_config.input.interpolation)

        # Load autoaugment transformations
        if cfg.policy.startswith("rand"):
            return rand_augment_transform(cfg.policy, aa_params)
        elif cfg.policy.startswith("augmix"):
            aa_params["translate_pct"] = 0.3
            return augment_and_mix_transform(cfg.policy, aa_params)
        else:
            return auto_augment_transform(cfg.policy, aa_params)

    else:
        return instantiate(cfg)
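
For context, a minimal usage sketch, assuming the function above is importable and timm is installed; the `_target_: aa` marker and the `input` block mirror the fields the function reads, while the exact keys and the `rand-m9-mstd0.5` policy string are illustrative assumptions.

from omegaconf import OmegaConf

# Hypothetical config fragments; only the keys read above are populated.
cfg = OmegaConf.create({"_target_": "aa", "policy": "rand-m9-mstd0.5"})
global_cfg = OmegaConf.create({
    "input": {
        "input_size": 224,
        "mean": [0.485, 0.456, 0.406],
        "interpolation": "bicubic",
    }
})

tfm = instantiate_transforms(cfg, global_cfg)  # -> a RandAugment transform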
Example #2
# Assumed imports: torch/torchvision plus the timm data utilities used below.
# RandomRotation is taken to be a custom 8-way rotation transform defined
# elsewhere in this module (torchvision's RandomRotation needs a degrees arg).
import torch
from torchvision import transforms
from timm.data.auto_augment import (augment_and_mix_transform,
                                    auto_augment_transform,
                                    rand_augment_transform)
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.random_erasing import RandomErasing
from timm.data.transforms import (RandomResizedCropAndInterpolation, ToNumpy,
                                  _pil_interp)


def transforms_imagenet_train(
    img_size=224,
    scale=(0.08, 1.0),
    color_jitter=0.4,
    auto_augment=None,
    interpolation='random',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    re_prob=0.,
    re_mode='const',
    re_count=1,
    re_num_splits=0,
    separate=False,
    squish=False,
    do_8_rotations=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    if squish:
        if not isinstance(img_size, tuple):
            img_size = (img_size, img_size)
        resize = transforms.Resize(img_size, _pil_interp('bilinear'))
    else:
        resize = RandomResizedCropAndInterpolation(img_size,
                                                   scale=scale,
                                                   interpolation=interpolation)

    if do_8_rotations:
        primary_tfl = [resize, RandomRotation()]
    else:
        primary_tfl = [resize, transforms.RandomHorizontalFlip()]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [
                augment_and_mix_transform(auto_augment, aa_params)
            ]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter), ) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std))
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob,
                              mode=re_mode,
                              max_count=re_count,
                              num_splits=re_num_splits,
                              device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(
            secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
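
A quick usage sketch, assuming the helpers imported above are available; the policy strings are examples, not defaults.

# Single composed pipeline with RandAugment enabled:
train_tfm = transforms_imagenet_train(img_size=224,
                                      auto_augment='rand-m9-mstd0.5',
                                      interpolation='bicubic')

# separate=True yields (primary, secondary, final) for mixing datasets:
primary, secondary, final = transforms_imagenet_train(
    img_size=224, auto_augment='augmix-m3-w3', separate=True)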
Example #3
# Assumed imports; parse_args, model_to_arch, JankAttention, and ValidationSet
# are project-specific helpers defined elsewhere in this module.
import pathlib

import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision
from timm.data.auto_augment import augment_and_mix_transform
from torchvision import transforms


def main():
    args = parse_args()
    # Create a pytorch dataset
    data_dir = pathlib.Path('./tiny-imagenet-200/')
    image_count = len(list(data_dir.glob('**/*.JPEG')))
    CLASS_NAMES = np.array(
        [item.name for item in (data_dir / 'train').glob('*')])
    print('Discovered {} images'.format(image_count))

    # Create the training data generator
    batch_size = 32

    # Tiny-ImageNet images are natively 64x64, but the pretrained models
    # expect larger inputs; pick the resolution the chosen model was built for.
    if args.model == "cait_m48_448":
        im_height = 448
        im_width = 448
    else:
        im_height = 224
        im_width = 224

    basic_transforms = [
        transforms.Resize((im_height, im_width)),
        transforms.RandomCrop(im_height, padding=8)
    ]
    augmix = []
    if args.augmix:
        augmix = [augment_and_mix_transform("augmix-m3-w3", {})]
    other_transforms = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    data_transforms = transforms.Compose(basic_transforms + augmix +
                                         other_transforms)

    transform_test = transforms.Compose([
        transforms.Resize((im_height, im_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_set = torchvision.datasets.ImageFolder(data_dir / 'train',
                                                 data_transforms)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_epochs = args.num_epochs

    if args.model in model_to_arch:
        model = timm.create_model(model_to_arch[args.model], pretrained=True)
    else:
        raise ValueError("model does not exist: {}".format(args.model))

    # Freeze the first num_tune_layers parameter tensors for fine-tuning
    for param in list(model.parameters())[:args.num_tune_layers]:
        param.requires_grad = False

    # Parameters of newly constructed modules have requires_grad=True by default
    if args.model == "inception_resnet_v2":
        num_ftrs = model.classif.in_features
        model.classif = nn.Sequential(nn.Dropout(0.4),
                                      nn.Linear(num_ftrs, 1024), nn.ReLU(),
                                      nn.Linear(1024, 256), nn.ReLU(),
                                      nn.Linear(256, 200))
        optim = torch.optim.Adam(
            [{
                "params": list(
                    model.parameters())[-1 * args.num_tune_layers:-6],
                "lr": 1e-4
            }, {
                "params": model.classif.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)
    elif args.model == "pit":
        num_ftrs = model.head.in_features
        if args.sparse_attn_k:
            for transformer in model.transformers:
                for block in transformer.blocks:
                    block.attn = JankAttention(block.attn, args.sparse_attn_k)

        # if args.residual_attn:
        #     for transformer in mode
        model.head = nn.Sequential(nn.Dropout(0.4), nn.Linear(num_ftrs, 1024),
                                   nn.ReLU(), nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 200))
        model.head_dist = nn.Sequential(nn.Dropout(0.4),
                                        nn.Linear(num_ftrs, 1024), nn.ReLU(),
                                        nn.Linear(1024, 256), nn.ReLU(),
                                        nn.Linear(256, 200))
        optim = torch.optim.Adam(
            [{
                "params": list(
                    model.parameters())[-1 * args.num_tune_layers:-15],
                "lr": 1e-4
            }, {
                "params": model.head.parameters(),
                "lr": 1e-3
            }, {
                "params": model.head_dist.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)
    elif args.model == "vit":
        num_ftrs = model.head.in_features
        if args.sparse_attn_k:
            for block in model.blocks:
                block.attn = JankAttention(block.attn, args.sparse_attn_k)
        model.head = nn.Sequential(nn.Dropout(0.4), nn.Linear(num_ftrs, 1024),
                                   nn.ReLU(), nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 200))

        optim = torch.optim.Adam(
            [{
                "params": list(
                    model.parameters())[-1 * args.num_tune_layers:-8],
                "lr": 1e-4
            }, {
                "params": model.head.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)
    else:
        # No head/optimizer setup exists for other models (e.g. cait_m48_448);
        # fail fast instead of hitting a NameError on `optim` below.
        raise ValueError("No fine-tuning setup for model: {}".format(args.model))
    if args.checkpoint:
        checkpoint = torch.load(args.output_dir +
                                "/epoch{}".format(args.start_epoch - 1))
        model.load_state_dict(checkpoint['net'])
    print("num params: {}".format(len(list(model.parameters()))))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, num_epochs)
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)
    for i in range(args.start_epoch, num_epochs):
        train_total, train_correct = 0, 0
        model.train()
        print("training epoch {}".format(i + 1))
        for idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optim.zero_grad()
            outputs = model(inputs)
            # Distilled models (e.g. PiT) return (class_out, dist_out) in train
            # mode; len() on a plain tensor is the batch size, so check type.
            if isinstance(outputs, (tuple, list)) and len(outputs) == 2:
                loss = criterion(outputs[1], targets)
                loss.backward(retain_graph=True)
                outputs = outputs[0]
            loss = criterion(outputs, targets)
            loss.backward()
            optim.step()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
            if idx % 100 == 0:
                print("\r", end='')
                print(
                    f'training {100 * idx / len(train_loader):.2f}%: {train_correct / train_total:.3f}',
                    end='')
        scheduler.step()
        torch.save({
            'net': model.state_dict(),
        }, args.output_dir + "/epoch{}".format(i))

        validation_set = ValidationSet(data_dir / 'val', transform_test)
        val_loader = torch.utils.data.DataLoader(validation_set,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=True)

        model.eval()
        all_preds = []
        all_labels = []
        all_losses = []
        with torch.no_grad():
            print("\r evaluating validation set after epoch: {}".format(i))
            for batch in val_loader:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                preds = model(inputs)
                loss = nn.CrossEntropyLoss()(preds, targets)
                all_losses.append(loss.cpu())
                all_preds.append(preds.cpu())
                all_labels.append(targets.cpu())
        top_preds = [x.argmax(dim=-1) for x in all_preds]
        correct = 0
        for idx, batch_preds in enumerate(top_preds):
            correct += torch.eq(all_labels[idx], batch_preds).sum()
        # Divide by the true sample count; the last batch may be smaller than 32.
        num_samples = sum(len(labels) for labels in all_labels)
        accuracy = correct.item() / num_samples
        print(f"Epoch {i} Top 1 Validation Accuracy: {accuracy}")

        top_preds = [x.argsort(dim=-1)[:, -3:] for x in all_preds]
        correct = 0
        for idx, batch_preds in enumerate(top_preds):
            # A sample counts if its label is any of the 3 highest-scoring classes.
            correct += torch.eq(all_labels[idx].unsqueeze(-1),
                                batch_preds).any(dim=-1).sum()
        accuracy = correct.item() / num_samples
        print(f"Epoch {i} Top 3 Validation Accuracy: {accuracy}")