def instantiate_transforms(cfg: DictConfig, global_config: DictConfig = None):
    """Loads in an individual transformation from its config."""
    if cfg._target_ == "aa":
        img_size_min = global_config.input.input_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple(
                [min(255, round(255 * x)) for x in global_config.input.mean]),
        )
        if (global_config.input.interpolation
                and global_config.input.interpolation != "random"):
            aa_params["interpolation"] = _pil_interp(
                global_config.input.interpolation)
        # Load AutoAugment-family transformations (RandAugment / AugMix / AutoAugment)
        if cfg.policy.startswith("rand"):
            return rand_augment_transform(cfg.policy, aa_params)
        elif cfg.policy.startswith("augmix"):
            aa_params["translate_pct"] = 0.3
            return augment_and_mix_transform(cfg.policy, aa_params)
        else:
            return auto_augment_transform(cfg.policy, aa_params)
    else:
        return instantiate(cfg)
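
# Rough usage sketch (illustrative only, not project configuration): the keys
# below mirror what instantiate_transforms reads, but the concrete values and
# the "rand-m9-mstd0.5" policy string are assumptions for demonstration.
def _example_instantiate_transforms():
    from omegaconf import OmegaConf

    global_config = OmegaConf.create({
        "input": {
            "input_size": 224,
            "mean": [0.485, 0.456, 0.406],
            "interpolation": "bicubic",
        }
    })
    # An "aa" target routes through timm's RandAugment builder above.
    aa_cfg = OmegaConf.create({"_target_": "aa", "policy": "rand-m9-mstd0.5"})
    return instantiate_transforms(aa_cfg, global_config)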
def transforms_imagenet_train(
        img_size=224,
        scale=(0.08, 1.0),
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
        squish=False,
        do_8_rotations=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    if squish:
        if not isinstance(img_size, tuple):
            img_size = (img_size, img_size)
        resize = transforms.Resize(img_size, _pil_interp('bilinear'))
    else:
        resize = RandomResizedCropAndInterpolation(img_size,
                                                   scale=scale,
                                                   interpolation=interpolation)

    if do_8_rotations:
        primary_tfl = [resize, RandomRotation()]
    else:
        primary_tfl = [resize, transforms.RandomHorizontalFlip()]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [
                augment_and_mix_transform(auto_augment, aa_params)
            ]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter), ) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std))
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob,
                              mode=re_mode,
                              max_count=re_count,
                              num_splits=re_num_splits,
                              device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(
            secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
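
# Rough usage sketch (illustrative values, not settings taken from this
# project): build a RandAugment training pipeline, then the separate=True
# variant that returns the primary / secondary / final stages individually so
# a mixing dataset can route samples through them itself.
def _example_transforms_imagenet_train():
    train_tf = transforms_imagenet_train(
        img_size=224,
        auto_augment='rand-m9-mstd0.5',
        interpolation='bicubic',
        re_prob=0.25,
    )
    primary_tf, secondary_tf, final_tf = transforms_imagenet_train(
        img_size=224,
        auto_augment='rand-m9-mstd0.5',
        separate=True,
    )
    return train_tf, (primary_tf, secondary_tf, final_tf)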
def main():
    args = parse_args()

    # Create a pytorch dataset
    data_dir = pathlib.Path('./tiny-imagenet-200/')
    image_count = len(list(data_dir.glob('**/*.JPEG')))
    CLASS_NAMES = np.array(
        [item.name for item in (data_dir / 'train').glob('*')])
    print('Discovered {} images'.format(image_count))

    # Create the training data generator
    batch_size = 32
    im_height = 64
    im_width = 64
    if args.model == "cait_m48_448":
        im_height = 448
        im_width = 448
    else:
        im_height = 224
        im_width = 224

    basic_transforms = [
        transforms.Resize((im_height, im_width)),
        transforms.RandomCrop(im_height, padding=8)
    ]
    augmix = []
    if args.augmix:
        augmix = [augment_and_mix_transform("augmix-m3-w3", {})]
    other_transforms = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    data_transforms = transforms.Compose(basic_transforms + augmix +
                                         other_transforms)

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_set = torchvision.datasets.ImageFolder(data_dir / 'train',
                                                 data_transforms)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_epochs = args.num_epochs

    if args.model in model_to_arch:
        model = timm.create_model(model_to_arch[args.model], pretrained=True)
    else:
        raise ValueError("model does not exist: {}".format(args.model))

    # Create a simple model: freeze the first num_tune_layers parameters
    for param in list(model.parameters())[:args.num_tune_layers]:
        param.requires_grad = False

    # Parameters of newly constructed modules have requires_grad=True by default
    if args.model == "inception_resnet_v2":
        num_ftrs = model.classif.in_features
        model.classif = nn.Sequential(nn.Dropout(0.4),
                                      nn.Linear(num_ftrs, 1024), nn.ReLU(),
                                      nn.Linear(1024, 256), nn.ReLU(),
                                      nn.Linear(256, 200))
        optim = torch.optim.Adam(
            [{
                "params": list(model.parameters())[-1 * args.num_tune_layers:-6],
                "lr": 1e-4
            }, {
                "params": model.classif.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)
    elif args.model == "pit":
        num_ftrs = model.head.in_features
        if args.sparse_attn_k:
            for transformer in model.transformers:
                for block in transformer.blocks:
                    block.attn = JankAttention(block.attn, args.sparse_attn_k)
        # if args.residual_attn:
        #     for transformer in mode
        model.head = nn.Sequential(nn.Dropout(0.4), nn.Linear(num_ftrs, 1024),
                                   nn.ReLU(), nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 200))
        model.head_dist = nn.Sequential(nn.Dropout(0.4),
                                        nn.Linear(num_ftrs, 1024), nn.ReLU(),
                                        nn.Linear(1024, 256), nn.ReLU(),
                                        nn.Linear(256, 200))
        optim = torch.optim.Adam(
            [{
                "params": list(model.parameters())[-1 * args.num_tune_layers:-15],
                "lr": 1e-4
            }, {
                "params": model.head.parameters(),
                "lr": 1e-3
            }, {
                "params": model.head_dist.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)
    elif args.model == "vit":
        num_ftrs = model.head.in_features
        if args.sparse_attn_k:
            for block in model.blocks:
                block.attn = JankAttention(block.attn, args.sparse_attn_k)
        model.head = nn.Sequential(nn.Dropout(0.4), nn.Linear(num_ftrs, 1024),
                                   nn.ReLU(), nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 200))
        optim = torch.optim.Adam(
            [{
                "params": list(model.parameters())[-1 * args.num_tune_layers:-8],
                "lr": 1e-4
            }, {
                "params": model.head.parameters(),
                "lr": 1e-3
            }],
            weight_decay=1e-5)

    if args.checkpoint:
        checkpoint = torch.load(args.output_dir +
                                "/epoch{}".format(args.start_epoch - 1))
        model.load_state_dict(checkpoint['net'])

    print("num params: {}".format(len(list(model.parameters()))))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, num_epochs)
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)

    for i in range(args.start_epoch, num_epochs):
        train_total, train_correct = 0, 0
        model.train()
        print("training epoch {}".format(i + 1))
        for idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optim.zero_grad()
            outputs = model(inputs)
            if isinstance(outputs, (tuple, list)) and len(outputs) == 2:
                # Distilled models (e.g. PiT with head_dist) return
                # (class_logits, dist_logits): backprop through the
                # distillation head first, then the classification head.
                loss = criterion(outputs[1], targets)
                loss.backward(retain_graph=True)
                outputs = outputs[0]
            loss = criterion(outputs, targets)
            loss.backward()
            optim.step()

            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
            if idx % 100 == 0:
                print("\r", end='')
                print(
                    f'training {100 * idx / len(train_loader):.2f}%: {train_correct / train_total:.3f}',
                    end='')
        scheduler.step()
        torch.save({
            'net': model.state_dict(),
        }, args.output_dir + "/epoch{}".format(i))

        validation_set = ValidationSet(data_dir / 'val', transform_test)
        val_loader = torch.utils.data.DataLoader(validation_set,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=True)
        model.eval()

        all_preds = []
        all_labels = []
        all_losses = []
        with torch.no_grad():
            print("\r evaluating validation set after epoch: {}".format(i))
            for batch in val_loader:
                inputs, targets = batch[0].to(device), batch[1].to(device)
                preds = model(inputs)
                loss = nn.CrossEntropyLoss()(preds, targets)
                all_losses.append(loss.cpu())
                all_preds.append(preds.cpu())
                all_labels.append(targets.cpu())

        # Total number of validation samples (the last batch may be smaller
        # than batch_size).
        total = sum(labels.size(0) for labels in all_labels)

        # Top-1 validation accuracy
        top_preds = [x.argsort(dim=-1)[:, -1:].squeeze() for x in all_preds]
        correct = 0
        for idx, batch_preds in enumerate(top_preds):
            correct += torch.eq(all_labels[idx], batch_preds).sum()
        accuracy = correct.item() / total
        print(f"Epoch {i} Top 1 Validation Accuracy: {accuracy}")

        # Top-3 validation accuracy: check each of the three highest-scoring
        # classes against the labels.
        top_preds = [x.argsort(dim=-1)[:, -3:] for x in all_preds]
        correct = 0
        for idx, batch_preds in enumerate(top_preds):
            correct += torch.eq(all_labels[idx],
                                batch_preds[:, 0:1].squeeze()).sum()
            correct += torch.eq(all_labels[idx],
                                batch_preds[:, 1:2].squeeze()).sum()
            correct += torch.eq(all_labels[idx],
                                batch_preds[:, 2:3].squeeze()).sum()
        accuracy = correct.item() / total
        print(f"Epoch {i} top 3 Validation Accuracy: {accuracy}")
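
# The top-1/top-3 bookkeeping in main() walks the sorted logits column by
# column; an equivalent, more compact formulation uses torch.topk. This helper
# is a sketch for illustration, not part of the original script. With the lists
# built in the validation loop it could be called as
#   topk_accuracy(torch.cat(all_preds), torch.cat(all_labels), k=3)
def topk_accuracy(logits, labels, k=3):
    """Fraction of rows whose true label is among the k highest-scoring classes."""
    topk = logits.topk(k, dim=-1).indices              # (N, k) class indices
    hits = (topk == labels.unsqueeze(-1)).any(dim=-1)  # (N,) bool
    return hits.float().mean().item()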