def test_lukasz_marcin_approach_backward_pass(self):
    img = torch.randn(size=(2, 3, 10, 10))
    r1 = cond_resnet18(norm_layer=None)
    r2 = cond_resnet18(norm_layer=None)
    r2.load_state_dict(r1.state_dict())
    assert_equal(r1.fc.weight, r2.fc.weight)
    r1.eval()
    r2.eval()
    ks = list(range(1, 64, 6))
    loss_fn = CrossEntropyLoss()
    y_true = torch.tensor([1, 2])

    # Marcin's approach: a separate forward pass for each k.
    loss_1 = 0
    for k in ks:
        y_pred = r1(img, full_k=k)
        loss_1 += loss_fn(y_pred, y_true)

    # Lukasz's approach: one full forward pass, then mask the embedding
    # before the final FC layer for each k.
    loss_2 = 0
    _, intermediate_full = r2(img, return_intermediate=True)
    emb_full = intermediate_full["embedding"]
    for k in ks:
        mask = torch.ones_like(emb_full)
        k_out = k * 8
        mask[torch.arange(len(mask)), k_out:] = 0
        emb_mask = emb_full * mask
        y_pred = r2.fc(emb_mask)
        loss_2 += loss_fn(y_pred, y_true)

    assert_equal(loss_1, loss_2)

    # The losses are equal, so the gradients must match parameter by
    # parameter.
    loss_1.backward()
    loss_2.backward()
    for (n1, p1), (n2, p2) in zip(r1.named_parameters(),
                                  r2.named_parameters()):
        self.assertEqual(n1, n2)
        assert_equal(p1, p2, msg=n1)
        if p1.grad is None:
            self.assertIsNone(p2.grad)
        else:
            assert_close(p1.grad, p2.grad, msg=n1)
def test_main_fc_ks(self):
    img = torch.randn(size=(1, 3, 10, 10))
    main_fc_ks = [1, 8, 10, 24, 32]
    r = cond_resnet18()
    r.eval()
    _, intermediate_full = r(img, return_intermediate=True,
                             main_fc_ks=main_fc_ks)
    for n in main_fc_ks:
        out = r(img, full_k=n)
        assert_equal(out, intermediate_full[MAIN_FC_KS][n])
def test_fc_for_channels(self):
    img = torch.randn(size=(1, 3, 10, 10))
    fc_for_channels = [1, 8, 10, 24, 32]
    r = cond_resnet18(fc_for_channels=fc_for_channels)
    r.eval()
    n_channels = 64
    _, intermediate_full = r(img, return_intermediate=True)
    for n in range(1, n_channels + 1):
        _, intermediate_zero = r(img, full_k=n, return_intermediate=True)
        for k, v in intermediate_zero[FC_FOR_CHANNELS].items():
            # Heads are only evaluated for k <= n available channels, and
            # their outputs must match the full forward pass.
            self.assertLessEqual(k, n)
            assert_equal(intermediate_full[FC_FOR_CHANNELS][k], v)
def test_resnet18(self):
    img = torch.randn(size=(1, 3, 10, 10))
    r = cond_resnet18()
    r.eval()
    n_channels = 64
    _, intermediate_full = r(img, return_intermediate=True)
    for n in range(1, n_channels + 1):
        _, intermediate_zero = r(img, k=n, return_intermediate=True)
        for k, v in intermediate_zero.items():
            intermediate_n_channels = v.shape[1]
            # Each intermediate layer's width must be an integer multiple
            # of the base width.
            self.assertEqual(intermediate_n_channels // n_channels,
                             intermediate_n_channels / n_channels)
            chan_mult = intermediate_n_channels // n_channels
            n_out = n * chan_mult
            # Channels beyond n * chan_mult are zeroed out...
            assert_equal(v[0, n_out:], torch.zeros_like(v[0, n_out:]))
            # ...and the leading channels match the full forward pass.
            assert_equal(v[0, :n_out], intermediate_full[k][0, :n_out])
def test_lukasz_marcin_approach_forward_pass(self):
    """Test that passing an image through the ResNet with k channels is the
    same as zeroing out all but the first k * 8 channels of the embedding
    before the final FC layer."""
    img = torch.randn(size=(1, 3, 10, 10))
    r = cond_resnet18()
    r.eval()
    n_channels = 64
    _, intermediate_full = r(img, return_intermediate=True)
    emb_full = intermediate_full["embedding"]
    for k in range(1, n_channels + 1):
        k_cls, k_intermediate = r(img, k=k, return_intermediate=True)
        emb_k = k_intermediate["embedding"]
        mask = torch.ones_like(emb_full)
        k_out = k * 8
        mask[0, k_out:] = 0
        emb_mask = emb_full * mask
        assert_equal(emb_mask, emb_k)
        assert_equal(k_cls, r.fc(emb_k))
        assert_equal(k_cls, r.fc(emb_mask))
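# Why the masking trick in the two tests above works: for a fully connected
# layer, zeroing trailing input features is equivalent to using only the
# corresponding leading weight columns. A minimal standalone sketch of that
# fact (not part of the test suite; the helper name and the 512/64/10 sizes
# are made up for illustration, and torch.nn.Linear stands in for the
# model's final `fc`):
def _sketch_fc_masking_equivalence():
    emb = torch.randn(2, 512)  # stand-in for the pooled ResNet embedding
    fc = torch.nn.Linear(512, 10)
    k_out = 64  # keep only the first k_out features
    mask = torch.ones_like(emb)
    mask[:, k_out:] = 0
    # fc(masked input) == (sliced input) @ (sliced weight columns) + bias
    out_masked = fc(emb * mask)
    out_sliced = emb[:, :k_out] @ fc.weight[:, :k_out].T + fc.bias
    torch.testing.assert_close(out_masked, out_sliced)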
def main():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--epochs', default=25, type=int,
                        help='number of epochs')
    parser.add_argument('--bs', default=128, type=int, help='batch size')
    parser.add_argument('--lr', default=0.1, type=float,
                        help='learning rate')
    parser.add_argument('--data_root', default='./data',
                        help='directory where the dataset is stored')
    parser.add_argument('--save_dir', default='./results',
                        help='directory where results will be saved')
    parser.add_argument('--resume', '-r', default=None,
                        help='path to a checkpoint to resume from')
    parser.add_argument('--model_type', default='classic',
                        choices=['classic', 'conditional'])
    parser.add_argument('--inplanes', default=64, type=int,
                        help='ResNet inplanes')
    parser.add_argument(
        '--conditional', '-c', nargs='+', type=float, default=None,
        help='Train separate ResNet heads for specific numbers of channels. '
        'If not specified, one robust head will be trained.')
    parser.add_argument(
        '--k', default=None, type=int,
        help='Number of condition samples for one robust head.')
    parser.add_argument('--scheduler',
                        choices=['None', 'CyclicLR', 'LambdaLR'],
                        default='CyclicLR')
    parser.add_argument(
        '--model',
        choices=[
            'ResNet18',
            # 'ResNet34', 'ResNet50'
        ],
        default='ResNet18')
    args = parser.parse_args()

    if args.conditional is not None:
        assert np.prod(args.conditional) > 0
        args.conditional.sort()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or the last checkpoint epoch

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root=args.data_root,
                                            train=True,
                                            download=False,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.bs,
                                              shuffle=True,
                                              num_workers=4)
    testset = torchvision.datasets.CIFAR10(root=args.data_root,
                                           train=False,
                                           download=False,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.bs,
                                             shuffle=False,
                                             num_workers=4)

    # Model
    print('==> Building model..')
    model_str = f"{args.model_type}_{args.model}-{args.inplanes}_inplanes"
    if args.model_type == "classic":
        assert args.conditional is None and args.k is None
        net = resnet18(in_planes=args.inplanes, num_classes=10)
    else:
        if args.conditional is not None:
            assert args.k is None
            # Values below 1 are interpreted as fractions of inplanes.
            if any(i < 1 for i in args.conditional):
                args.conditional = [
                    ceil(args.inplanes * i) for i in args.conditional
                ]
            else:
                args.conditional = [int(i) for i in args.conditional]
            model_str = f"{model_str}-{len(args.conditional)}_heads"
            net = cond_resnet18(in_planes=args.inplanes,
                                fc_for_channels=args.conditional)
        else:
            assert args.k is not None
            model_str = f"{model_str}-robust_head_{args.k}_samples"
            net = cond_resnet18(in_planes=args.inplanes)
    print(f'\033[0;1;33m{model_str}\033[0m')
    args.save_dir = f"{args.save_dir}/{model_str}"
    net = net.to(device)

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert Path(args.resume).is_file(), \
            'Error: resume file does not exist!'
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           weight_decay=0)
    if args.scheduler == 'CyclicLR':
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=1e-4,
            max_lr=1.2 * args.lr,
            step_size_up=10,
            mode="exp_range",
            scale_mode='cycle',
            cycle_momentum=False,
            gamma=(1e-4 / args.lr)**(1 / (0.9 * args.epochs)))
    elif args.scheduler == 'LambdaLR':
        # Decay exponentially towards 1e-4 over the first 90% of training,
        # then hold the learning rate constant.
        gamma = (1e-4 / args.lr)**(1 / (0.9 * args.epochs))
        lr_lambda = (lambda epoch: gamma**epoch
                     if epoch < 0.9 * args.epochs else 1e-4 / args.lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)
    else:
        scheduler = None

    args.save_dir = (
        f"{args.save_dir}/{datetime.now().strftime('%Y-%m-%d_%H%M%S')}")
    args.path_checkpoint = f'{args.save_dir}/checkpoint'
    args.path_logs = f'{args.save_dir}/logs'

    print('\033[0;32m', '=' * 15, ' Parameters ', '=' * 15, '\033[0m',
          sep='')
    for arg in vars(args):
        print(f'\033[0;32m{arg}: {getattr(args, arg)}\033[0m')
    print('\033[0;32m', '=' * 46, '\033[0m', sep='')

    Path(args.path_checkpoint).mkdir(exist_ok=True, parents=True)
    with open(f"{args.save_dir}/config.yaml", 'w') as yaml_file:
        yaml.safe_dump(vars(args),
                       yaml_file,
                       default_style=None,
                       default_flow_style=False,
                       sort_keys=False)
    writer = SummaryWriter(args.path_logs)

    epoch_tqdm = tqdm(range(start_epoch, start_epoch + args.epochs),
                      desc="Training")
    for epoch in epoch_tqdm:
        ############
        # Training #
        ############
        net.train()
        total = correct = train_loss = 0
        for batch_idx, (inputs, targets) in tqdm(enumerate(trainloader),
                                                 total=len(trainloader),
                                                 leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            if args.model_type == "classic":
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            else:
                if args.conditional is None:
                    # Robust head: sample args.k random channel counts,
                    # always include the full width, and sum the losses.
                    samples = np.arange(1, args.inplanes)
                    np.random.shuffle(samples)
                    samples = samples[:args.k]
                    samples = list(samples) + [args.inplanes]
                    loss = 0
                    _, inter = net(inputs,
                                   return_intermediate=True,
                                   main_fc_ks=samples)
                    for n_channels, outputs in inter[MAIN_FC_KS].items():
                        loss += criterion(outputs, targets)
                else:
                    # Separate heads: sum the losses over all of them.
                    _, inter = net(inputs, return_intermediate=True)
                    loss = 0
                    for n_channels, outputs in inter[
                            FC_FOR_CHANNELS].items():
                        loss += criterion(outputs, targets)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # For conditional models, `outputs` holds the predictions
                # of the last (widest) head from the loop above.
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                correct += predicted.eq(targets).sum().item()
                total += targets.size(0)
                # ===================logger========================
                writer.add_scalar('train/loss', loss.item(),
                                  epoch * len(trainloader) + batch_idx)
        acc_train = correct / total
        train_loss /= len(trainloader)

        ############
        # Testing  #
        ############
        net.eval()
        total = 0
        correct = defaultdict(float)
        test_loss = defaultdict(float)
        with torch.no_grad():
            for batch_idx, (inputs, targets) in tqdm(enumerate(testloader),
                                                     total=len(testloader),
                                                     leave=False):
                inputs, targets = inputs.to(device), targets.to(device)
                if args.model_type == "classic":
                    outputs = net(inputs)
                    loss = criterion(outputs, targets)
                    # ===================logger========================
                    writer.add_scalar(f'test_loss/{args.inplanes}',
                                      loss.item(),
                                      epoch * len(testloader) + batch_idx)
                    test_loss[args.inplanes] += loss.item()
                    _, predicted = outputs.max(1)
                    correct[args.inplanes] += predicted.eq(
                        targets).sum().item()
                else:
                    if args.conditional is None:
                        # Evaluate the robust head at every 8th channel
                        # count plus the full width.
                        ks_to_check = sorted(
                            set(
                                list(range(1, args.inplanes, 8)) +
                                [args.inplanes]))
                        _, inter = net(inputs,
                                       return_intermediate=True,
                                       main_fc_ks=ks_to_check)
                        out_key = MAIN_FC_KS
                    else:
                        _, inter = net(inputs, return_intermediate=True)
                        out_key = FC_FOR_CHANNELS
                    for n_channels, outputs in inter[out_key].items():
                        loss = criterion(outputs, targets)
                        test_loss[n_channels] += loss.item()
                        _, predicted = outputs.max(1)
                        correct[n_channels] += predicted.eq(
                            targets).sum().item()
                        writer.add_scalar(
                            f'test_loss/{n_channels}', loss.item(),
                            epoch * len(testloader) + batch_idx)
                total += targets.size(0)

        # ===================logger========================
        writer.add_scalar('train/loss_per_epoch', train_loss, epoch)
        writer.add_scalar('train/acc_per_epoch', acc_train, epoch)
        for n_channels in correct:
            acc_test = correct[n_channels] / total
            test_loss[n_channels] /= len(testloader)
            writer.add_scalar(f'test_loss_per_epoch/{n_channels}',
                              test_loss[n_channels], epoch)
            writer.add_scalar(f'test_acc_per_epoch/{n_channels}', acc_test,
                              epoch)
        # Report the accuracy of the widest head and the summed loss.
        acc_test = correct[max(correct.keys())] / total
        test_loss = sum(test_loss.values())

        ############################
        #        Save model        #
        ############################
        if acc_test > best_acc:
            state = {
                'net': net.state_dict(),
                'acc': acc_test,
                'epoch': epoch,
            }
            torch.save(state, f'{args.save_dir}/checkpoint/model.pth')
            best_acc = acc_test

        ############################
        #        Scheduler         #
        ############################
        if scheduler is None:
            writer.add_scalar('scheduler', args.lr, epoch)
        else:
            scheduler.step()
            writer.add_scalar('scheduler', scheduler.get_last_lr()[0],
                              epoch)

        epoch_tqdm.set_description(
            f"Train: loss={train_loss:.4f}, acc={acc_train:.4f}, "
            f"Test: loss={test_loss:.4f}, acc={acc_test:.4f}")
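# Standard entry point, plus example invocations. The script name `train.py`
# is an assumption, not confirmed by this excerpt:
#   python train.py --model_type classic
#   python train.py --model_type conditional --k 8
#   python train.py --model_type conditional -c 0.125 0.25 0.5 1.0
if __name__ == '__main__':
    main()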