def get_omniglot(ways, shots):
    omniglot = l2l.vision.datasets.FullOmniglot(
        root='~/data',
        transform=transforms.Compose([
            transforms.Resize(28, interpolation=LANCZOS),
            transforms.ToTensor(),
            lambda x: 1.0 - x,
        ]),
        download=True)
    dataset = l2l.data.MetaDataset(omniglot)
    classes = list(range(1623))
    random.shuffle(classes)
    train_transforms = [
        FilterLabels(dataset, classes[:1100]),
        NWays(dataset, ways),
        KShots(dataset, 2 * shots),
        LoadData(dataset),
        RemapLabels(dataset),
        ConsecutiveLabels(dataset),
        RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
    ]
    train_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        FilterLabels(dataset, classes[1100:1200]),
        NWays(dataset, ways),
        KShots(dataset, 2 * shots),
        LoadData(dataset),
        RemapLabels(dataset),
        ConsecutiveLabels(dataset),
        RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
    ]
    valid_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=1024)
    test_transforms = [
        FilterLabels(dataset, classes[1200:]),
        NWays(dataset, ways),
        KShots(dataset, 2 * shots),
        LoadData(dataset),
        RemapLabels(dataset),
        ConsecutiveLabels(dataset),
        RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
    ]
    test_tasks = l2l.data.TaskDataset(dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=1024)
    return train_tasks, valid_tasks, test_tasks
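# A minimal usage sketch for get_omniglot() (illustrative only; the helper name
# demo_omniglot_task and the use of `torch` here are assumptions, not part of
# the original script). It samples one task and splits the 2 * shots samples
# per class into adaptation and evaluation halves, as the MAML-style training
# loops further below do.
def demo_omniglot_task(ways=5, shots=1):
    train_tasks, valid_tasks, test_tasks = get_omniglot(ways, shots)
    data, labels = train_tasks.sample()  # one task: ways * (2 * shots) samples
    # Even indices go to adaptation, odd indices to evaluation.
    adaptation_indices = torch.zeros(data.size(0), dtype=torch.bool)
    adaptation_indices[torch.arange(shots * ways) * 2] = True
    evaluation_indices = ~adaptation_indices
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
    return adaptation_data, adaptation_labels, evaluation_data, evaluation_labels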
def fc100_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    **kwargs,
):
    """Tasksets for FC100 benchmarks."""
    data_transform = tv.transforms.ToTensor()
    train_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=data_transform,
                                              mode='train',
                                              download=True)
    valid_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=data_transform,
                                              mode='validation',
                                              download=True)
    test_dataset = l2l.vision.datasets.FC100(root=root,
                                             transform=data_transform,
                                             mode='test',
                                             download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
def get_mini_imagenet(ways, shots):
    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='train',
                                                     download=True)
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='validation',
                                                     download=True)
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                    mode='test',
                                                    download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, 2 * shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, 2 * shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, 2 * shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)
    return train_tasks, valid_tasks, test_tasks
def mini_imagenet_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    **kwargs,
):
    """Tasksets for mini-ImageNet benchmarks."""
    train_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                     mode='train',
                                                     download=True)
    valid_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                     mode='validation',
                                                     download=True)
    test_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                    mode='test',
                                                    download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
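# Hedged usage sketch: mini_imagenet_tasksets() (like fc100_tasksets() above)
# returns datasets and transform lists rather than TaskDatasets, so a caller is
# expected to wrap them roughly as below. The wrapper name and the num_tasks
# values are illustrative assumptions.
def make_mini_imagenet_taskdatasets(num_train_tasks=20000, num_eval_tasks=600, **kwargs):
    datasets, transforms_lists = mini_imagenet_tasksets(**kwargs)
    train_dataset, valid_dataset, test_dataset = datasets
    train_transforms, valid_transforms, test_transforms = transforms_lists
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=num_train_tasks)
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=num_eval_tasks)
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=num_eval_tasks)
    return train_tasks, valid_tasks, test_tasks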
def test_instanciation(self):
    data = th.randn(NUM_DATA, X_SHAPE)
    labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = TensorDataset(data, labels)
    task_dataset = TaskDataset(dataset,
                               task_transforms=[LoadData(dataset)],
                               num_tasks=NUM_TASKS)
    self.assertEqual(len(task_dataset), NUM_TASKS)
def create_task_pool(dataset=None, num_tasks=100, ways=5, shot=1):
    dataset = l2l.data.MetaDataset(dataset)
    transforms = [
        NWays(dataset, ways),
        KShots(dataset, shot),
        LoadData(dataset),
    ]
    task_pool = l2l.data.TaskDataset(dataset,
                                     task_transforms=transforms,
                                     num_tasks=num_tasks)
    return task_pool
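# Hedged usage sketch for create_task_pool(): any torch Dataset with integer
# labels works; the toy TensorDataset, sizes, and helper name below are
# illustrative, not part of the original code.
def demo_task_pool():
    import torch
    from torch.utils.data import TensorDataset
    data = torch.randn(100, 8)
    labels = torch.randint(0, 10, (100, ))
    pool = create_task_pool(TensorDataset(data, labels), num_tasks=10, ways=5, shot=1)
    for X, y in pool:  # each task holds ways * shot samples
        assert X.shape == (5, 8) and y.shape == (5, )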
def test_load_data(self):
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = MetaDataset(TensorDataset(data, labels))
    task_dataset = TaskDataset(dataset,
                               task_transforms=[LoadData(dataset)],
                               num_tasks=NUM_TASKS)
    for task in task_dataset:
        self.assertTrue(isinstance(task[0], torch.Tensor))
        self.assertTrue(isinstance(task[1], torch.Tensor))
def test_infinite_tasks(self):
    data = th.randn(NUM_DATA, X_SHAPE)
    labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = TensorDataset(data, labels)
    task_dataset = TaskDataset(
        dataset, task_transforms=[LoadData(dataset), random_subset])
    self.assertEqual(len(task_dataset), 1)
    prev = task_dataset.sample()
    for i, task in enumerate(task_dataset):
        self.assertFalse(task_equal(prev, task))
        prev = task
        if i > 4:
            break
def test_filter_labels(self):
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    chosen_labels = random.sample(list(range(Y_SHAPE)), k=Y_SHAPE // 2)
    dataset = MetaDataset(TensorDataset(data, labels))
    task_dataset = TaskDataset(dataset,
                               task_transforms=[
                                   FilterLabels(dataset, chosen_labels),
                                   LoadData(dataset)
                               ],
                               num_tasks=NUM_TASKS)
    for task in task_dataset:
        for label in task[1]:
            self.assertTrue(label in chosen_labels)
def test_n_ways(self):
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = MetaDataset(TensorDataset(data, labels))
    for ways in range(1, 10):
        task_dataset = TaskDataset(
            dataset,
            task_transforms=[NWays(dataset, n=ways), LoadData(dataset)],
            num_tasks=NUM_TASKS)
        for task in task_dataset:
            bins = task[1].bincount()
            num_classes = len(bins) - (bins == 0).sum()
            self.assertEqual(num_classes, ways)
def test_remap_labels(self):
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = MetaDataset(TensorDataset(data, labels))
    for ways in range(1, 5):
        task_dataset = TaskDataset(dataset,
                                   task_transforms=[
                                       NWays(dataset, ways),
                                       LoadData(dataset),
                                       RemapLabels(dataset)
                                   ],
                                   num_tasks=NUM_TASKS)
        for task in task_dataset:
            for label in range(ways):
                self.assertTrue(label in task[1])
def test_dataloader(self):
    data = th.randn(NUM_DATA, X_SHAPE)
    labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = TensorDataset(data, labels)
    task_dataset = TaskDataset(
        dataset,
        task_transforms=[LoadData(dataset), random_subset],
        num_tasks=NUM_TASKS)
    task_loader = DataLoader(task_dataset,
                             shuffle=True,
                             batch_size=META_BSZ,
                             num_workers=WORKERS,
                             drop_last=True)
    for task_batch in task_loader:
        self.assertEqual(task_batch[0].shape, (META_BSZ, X_SHAPE))
        self.assertEqual(task_batch[1].shape, (META_BSZ, 1))
def test_task_transforms(self):
    data = th.randn(NUM_DATA, X_SHAPE)
    labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = TensorDataset(data, labels)
    task_dataset = TaskDataset(
        dataset,
        task_transforms=[LoadData(dataset), random_subset],
        num_tasks=NUM_TASKS)
    for task in task_dataset:
        # Tests transforms on the task_description
        self.assertEqual(len(task[0]), SUBSET_SIZE)
        self.assertEqual(len(task[1]), SUBSET_SIZE)
        # Tests transforms on the data
        self.assertEqual(task[0].size(1), X_SHAPE)
        self.assertLessEqual(task[1].max(), Y_SHAPE - 1)
        self.assertGreaterEqual(task[1].max(), 0)
def test_k_shots(self):
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = MetaDataset(TensorDataset(data, labels))
    for replacement in [False, True]:
        for shots in range(1, 10):
            task_dataset = TaskDataset(dataset,
                                       task_transforms=[
                                           KShots(dataset,
                                                  k=shots,
                                                  replacement=replacement),
                                           LoadData(dataset)
                                       ],
                                       num_tasks=NUM_TASKS)
            for task in task_dataset:
                bins = task[1].bincount()
                correct = (bins == shots).sum()
                self.assertEqual(correct, Y_SHAPE)
def test_task_caching(self):
    data = th.randn(NUM_DATA, X_SHAPE)
    labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = TensorDataset(data, labels)
    task_dataset = TaskDataset(dataset,
                               task_transforms=[LoadData(dataset)],
                               num_tasks=NUM_TASKS)
    tasks = []
    for i, task in enumerate(task_dataset, 1):
        tasks.append(task)
    self.assertEqual(i, NUM_TASKS)
    for ref, task in zip(tasks, task_dataset):
        self.assertTrue(task_equal(ref, task))
    for i in range(NUM_TASKS):
        ref = tasks[i]
        task = task_dataset[i]
        self.assertTrue(task_equal(ref, task))
def main(
        ways=5,
        shots=5,
        meta_lr=0.002,
        fast_lr=0.1,  # original 0.1
        reg_lambda=0,
        adapt_steps=5,  # original: 5
        meta_bsz=32,
        iters=1000,  # original: 1000
        cuda=1,
        seed=42,
):
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, 2 * shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, 2 * shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, 2 * shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    # features = l2l.vision.models.MiniImagenetCNN(ways)
    features = l2l.vision.models.ConvBase(output_size=32,
                                          channels=3,
                                          max_pool=True)
    # for p in features.parameters():
    #     print(p.shape)
    features = torch.nn.Sequential(features,
                                   Lambda(lambda x: x.view(-1, 1600)))
    features.to(device)
    head = torch.nn.Linear(1600, ways)
    head = l2l.algorithms.MAML(head, lr=fast_lr)
    head.to(device)

    # Setup optimization
    all_parameters = list(features.parameters())
    # optimizer = torch.optim.Adam(all_parameters, lr=meta_lr)
    ## use different learning rates for w and theta
    optimizer = torch.optim.Adam(all_parameters, lr=meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')
    training_accuracy = torch.ones(iters)
    test_accuracy = torch.ones(iters)
    running_time = np.ones(iters)

    import time
    start_time = time.time()
    for iteration in range(iters):
        optimizer.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_bsz):
            # Compute meta-training loss
            learner = head.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, reg_lambda, adapt_steps,
                shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = head.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, reg_lambda, adapt_steps,
                shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = head.clone()
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, reg_lambda, adapt_steps,
                shots, ways, device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        training_accuracy[iteration] = meta_train_accuracy / meta_bsz
        test_accuracy[iteration] = meta_test_accuracy / meta_bsz

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('Meta Valid Error', meta_valid_error / meta_bsz)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
        print('Meta Test Error', meta_test_error / meta_bsz)
        print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            p.grad.data.mul_(1.0 / meta_bsz)
        # print('head')
        # for p in list(head.parameters()):
        #     print(torch.max(torch.abs(p.grad.data)))
        # print('feature')
        # for p in list(features.parameters()):
        #     print(torch.max(torch.abs(p.grad.data)))
        optimizer.step()
        end_time = time.time()
        running_time[iteration] = end_time - start_time
    print('total running time', end_time - start_time)
    return training_accuracy.numpy(), test_accuracy.numpy(), running_time
def main(
        ways=5,
        shots=5,
        meta_lr=0.001,
        fast_lr=0.1,  # original 0.1
        reg_lambda=0,
        adapt_steps=5,  # original: 5
        meta_bsz=32,
        iters=1000,  # original: 1000
        cuda=1,
        seed=42,
):
    print('hlr=' + str(meta_lr) + ' flr=' + str(fast_lr) + ' reg=' + str(reg_lambda))
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.FC100(root='~/data',
                                              transform=tv.transforms.ToTensor(),
                                              mode='train')
    valid_dataset = l2l.vision.datasets.FC100(root='~/data',
                                              transform=tv.transforms.ToTensor(),
                                              mode='validation')
    test_dataset = l2l.vision.datasets.FC100(root='~/data',
                                             transform=tv.transforms.ToTensor(),
                                             mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        FusedNWaysKShots(train_dataset, n=ways, k=2 * shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        FusedNWaysKShots(valid_dataset, n=ways, k=2 * shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        FusedNWaysKShots(test_dataset, n=ways, k=2 * shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    features = l2l.vision.models.ConvBase(output_size=64,
                                          channels=3,
                                          max_pool=True)
    features = torch.nn.Sequential(features,
                                   Lambda(lambda x: x.view(-1, 256)))
    features.to(device)
    head_dim = 256

    # Setup optimization
    all_parameters = list(features.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=meta_lr)
    # optimizer = torch.optim.SGD(all_parameters, lr=meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')
    training_accuracy = torch.ones(iters)
    test_accuracy = torch.ones(iters)
    running_time = np.ones(iters)

    import time
    start_time = time.time()
    for iteration in range(iters):
        optimizer.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_bsz):
            # Compute meta-training loss
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, head_dim, features, loss, fast_lr, reg_lambda,
                adapt_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, head_dim, features, loss, fast_lr, reg_lambda,
                adapt_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, head_dim, features, loss, fast_lr, reg_lambda,
                adapt_steps, shots, ways, device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        training_accuracy[iteration] = meta_train_accuracy / meta_bsz
        test_accuracy[iteration] = meta_test_accuracy / meta_bsz

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('Meta Valid Error', meta_valid_error / meta_bsz)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
        print('Meta Test Error', meta_test_error / meta_bsz)
        print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            p.grad.data.mul_(1.0 / meta_bsz)
        optimizer.step()
        end_time = time.time()
        running_time[iteration] = end_time - start_time
        print('time per iteration', end_time - start_time)
    return training_accuracy.numpy(), test_accuracy.numpy(), running_time
def cifarfs_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    device=None,
    **kwargs,
):
    """Tasksets for CIFAR-FS benchmarks."""
    data_transform = tv.transforms.ToTensor()
    train_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=data_transform,
                                                mode='train',
                                                download=True)
    valid_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=data_transform,
                                                mode='validation',
                                                download=True)
    test_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                               transform=data_transform,
                                               mode='test',
                                               download=True)
    if device is not None:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
meta_curetsr_lvl5 = l2l.data.MetaDataset(curetsr_lvl5)
train_dataset = meta_curetsr_lvl0
valid_dataset = meta_curetsr_lvl0
test_dataset = meta_curetsr_lvl5
classes = list(range(14))  # 14 classes of stop signs
random.shuffle(classes)
# Changes, end!

train_dataset = l2l.data.MetaDataset(train_dataset)
train_transforms = [
    FilterLabels(train_dataset, classes[:8]),
    NWays(train_dataset, args.train_way),
    KShots(train_dataset, args.train_query + args.shot),
    LoadData(train_dataset),
    RemapLabels(train_dataset),
]
train_tasks = l2l.data.TaskDataset(train_dataset,
                                   task_transforms=train_transforms)
train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True)

valid_dataset = l2l.data.MetaDataset(valid_dataset)
valid_transforms = [
    FilterLabels(valid_dataset, classes[8:14]),
    NWays(valid_dataset, args.test_way),
    KShots(valid_dataset, args.test_query + args.test_shot),
    LoadData(valid_dataset),
    RemapLabels(valid_dataset),
]
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms)
def train(args):
    logger = SummaryWriter(comment=args.comment)
    device = torch.device('cpu')
    if args.gpu and torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')
    if args.backbone2d:
        model = Convnet2D(st_attention=args.st_attention)
    else:
        model = Convnet(st_attention=args.st_attention)
    model.to(device)
    relation_head = RelationHead(ways=args.train_way, dynamic=args.dynhead)
    relation_head.to(device)
    if args.temporal_align:
        temporal_align = TemporalAlignMoudle(10, shot=args.shot)
        # d = torch.load('./ckpt/pretrain_ta/ckpt.pth', 'cpu')
        # temporal_align.load_state_dict(d['align_dict'])
        temporal_align.to(device)
        # del d

    # train_augmentation = get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)
    train_augmentation = Compose(
        [GroupScale([128, 128]),
         GroupRandomHorizontalFlip(is_flow=False)])
    normalize = IdentityTransform()
    num_class, args.train_list, args.val_list, args.root_path, prefix, anno_prefix = dataset_config.return_dataset(
        args.dataset, 'RGB')
    args.test_list = args.train_list.replace('train', 'test')
    path_data = args.root_path
    num_segments = 20
    train_dataset = TSNDataSet(path_data,
                               args.train_list,
                               num_segments=num_segments,
                               new_length=1,
                               modality='RGB',
                               image_tmpl=prefix,
                               transform=Compose([
                                   train_augmentation,
                                   StackBatch(roll=False),
                                   To3DTorchFormatTensor(div=True)
                               ]),
                               dense_sample=False)
    train_dataset = l2l.data.MetaDataset(
        train_dataset, indices_to_labels=train_dataset.indices_to_labels)
    train_transforms = [
        NWays(train_dataset, args.train_way),
        KShots(train_dataset, args.train_query + args.shot),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms)
    # train_loader = DataLoader(train_tasks, num_workers=1, pin_memory=True, shuffle=True)
    train_prefetcher = DataPrefetcher(train_tasks)
    valid_dataset = TSNDataSet(path_data,
                               args.val_list,
                               num_segments=num_segments,
                               new_length=1,
                               modality='RGB',
                               image_tmpl=prefix,
                               transform=Compose([
                                   GroupScale([128, 128]),
                                   StackBatch(roll=False),
                                   To3DTorchFormatTensor(div=True)
                               ]),
                               dense_sample=False)
    valid_dataset = l2l.data.MetaDataset(
        valid_dataset, indices_to_labels=valid_dataset.indices_to_labels)
    valid_transforms = [
        NWays(valid_dataset, args.test_way),
        KShots(valid_dataset, args.test_query + args.test_shot),
        LoadData(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=100)
    valid_loader = DataLoader(valid_tasks,
                              num_workers=1,
                              pin_memory=True,
                              shuffle=True)
    # valid_prefetcher = DataPrefetcher(valid_tasks)

    models = [model, relation_head]
    param_groups = [{'params': m.parameters()} for m in models]
    if args.temporal_align:
        models.append(temporal_align)
        param_groups.append({
            'params': temporal_align.parameters(),
            'lr': 1e-3
        })
    optimizer = torch.optim.Adam(param_groups, lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=25,
                                                   gamma=0.5)
    print('Start training')
    best_metric = 0
    for epoch in range(1, args.max_epoch + 1):
        for m in models:
            m.train()
        loss_ctr = 0
        n_loss = 0
        n_acc = 0
        batch = train_prefetcher.next()
        for i in range(100):
            # batch = next(iter(train_loader))
            loss, acc = fast_adapt(models,
                                   batch,
                                   args.train_way,
                                   args.shot,
                                   args.train_query,
                                   metric=pairwise_distances_logits,
                                   device=device,
                                   epoch=epoch)
            loss_ctr += 1
            n_loss += loss.item()
            n_acc += acc.item()
            batch = train_prefetcher.next()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 10 == 0:
                sys.stdout.write('\rWorking.... i=%d \033[K' % (i + 1))
        lr_scheduler.step()
        print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
            epoch, n_loss / loss_ctr, n_acc / loss_ctr))
        logger.add_scalar('train_acc', n_acc / loss_ctr, global_step=epoch)
        ckpt = {
            'epoch': epoch + 1,
            'arch': 'raw_relation',
            'state_dict': model.state_dict(),
            'head_dict': relation_head.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_acc': -1,
        }
        if args.temporal_align:
            ckpt['align_dict'] = temporal_align.state_dict()
        save_checkpoint(ckpt, False)

        for m in models:
            m.eval()
        loss_ctr = 0
        n_loss = 0
        n_acc = 0
        # batch = valid_prefetcher.next()
        with torch.no_grad():
            for i, batch in enumerate(valid_loader):
                # while batch is not None and batch[0] is not None:
                loss, acc = fast_adapt(models,
                                       batch,
                                       args.test_way,
                                       args.test_shot,
                                       args.test_query,
                                       metric=pairwise_distances_logits,
                                       device=device,
                                       epoch=epoch)
                loss_ctr += 1
                n_loss += loss.item()
                n_acc += acc.item()
                # batch = valid_prefetcher.next()
                if (i + 1) % 10 == 0:
                    sys.stdout.write('\rEvaluating.... i=%d \033[K' % (i + 1))
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
            epoch, n_loss / loss_ctr, n_acc / loss_ctr))
        logger.add_scalar('val_acc', n_acc / loss_ctr, global_step=epoch)
        metric_for_best = n_acc / loss_ctr
        is_best = metric_for_best > best_metric
        best_metric = max(metric_for_best, best_metric)
        ckpt = {
            'epoch': epoch + 1,
            'arch': 'raw_relation',
            'state_dict': model.state_dict(),
            'head_dict': relation_head.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_acc': best_metric,
        }
        if args.temporal_align:
            ckpt['align_dict'] = temporal_align.state_dict()
        save_checkpoint(ckpt, is_best)
        torch.cuda.empty_cache()
def test(args):
    device = torch.device('cpu')
    if args.gpu and torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')
    if args.backbone2d:
        model = Convnet2D(st_attention=args.st_attention)
    else:
        model = Convnet(st_attention=args.st_attention)
    model.to(device)
    relation_head = RelationHead(ways=args.test_way, dynamic=args.dynhead)
    relation_head.to(device)
    if args.temporal_align:
        temporal_align = TemporalAlignMoudle(10, shot=args.test_shot)
        temporal_align.to(device)
    models = [model, relation_head]
    if args.temporal_align:
        models.append(temporal_align)
    num_class, args.train_list, args.val_list, args.root_path, prefix, anno_prefix = dataset_config.return_dataset(
        args.dataset, 'RGB')
    args.test_list = args.train_list.replace('train', 'test')
    path_data = args.root_path
    num_segments = 20
    test_dataset = TSNDataSet(path_data,
                              args.test_list,
                              num_segments=num_segments,
                              new_length=1,
                              modality='RGB',
                              image_tmpl=prefix,
                              transform=Compose([
                                  GroupScale([128, 128]),
                                  StackBatch(roll=False),
                                  To3DTorchFormatTensor(div=True)
                              ]),
                              dense_sample=False,
                              test_mode=True)
    test_dataset = l2l.data.MetaDataset(
        test_dataset, indices_to_labels=test_dataset.indices_to_labels)
    test_transforms = [
        NWays(test_dataset, args.test_way),
        KShots(test_dataset, args.test_query + args.test_shot),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=2000)
    test_loader = DataLoader(test_tasks,
                             num_workers=1,
                             pin_memory=True,
                             shuffle=True)
    # test_prefetcher = DataPrefetcher(test_tasks)
    loss_ctr = 0
    n_acc = 0
    model.load_state_dict(
        torch.load('%s/%s/ckpt.pth' %
                   (args.root_model, args.store_name))['state_dict'])
    relation_head.load_state_dict(
        torch.load('%s/%s/ckpt.pth' %
                   (args.root_model, args.store_name))['head_dict'])
    if args.temporal_align:
        temporal_align.load_state_dict(
            torch.load('%s/%s/ckpt.pth' %
                       (args.root_model, args.store_name))['align_dict'])
    for m in models:
        m.eval()
    for i, batch in enumerate(test_loader, 1):
        loss, acc = fast_adapt(models,
                               batch,
                               args.test_way,
                               args.test_shot,
                               args.test_query,
                               metric=pairwise_distances_logits,
                               device=device)
        loss_ctr += 1
        n_acc += acc.item()
        sys.stdout.write('\rbatch {}: {:.2f}({:.2f}) \033[K'.format(
            i, n_acc / loss_ctr * 100, acc * 100))
    print()
def main(
        ways=5,
        shots=5,
        meta_lr=0.001,
        fast_lr=0.1,
        adapt_steps=5,
        meta_bsz=32,
        iters=1000,
        cuda=1,
        seed=42,
):
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.FC100(root='~/data',
                                              transform=tv.transforms.ToTensor(),
                                              mode='train')
    valid_dataset = l2l.vision.datasets.FC100(root='~/data',
                                              transform=tv.transforms.ToTensor(),
                                              mode='validation')
    test_dataset = l2l.vision.datasets.FC100(root='~/data',
                                             transform=tv.transforms.ToTensor(),
                                             mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        FusedNWaysKShots(train_dataset, n=ways, k=2 * shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        FusedNWaysKShots(valid_dataset, n=ways, k=2 * shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        FusedNWaysKShots(test_dataset, n=ways, k=2 * shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    features = l2l.vision.models.ConvBase(output_size=64, channels=3, max_pool=True)
    features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, 256)))
    features.to(device)
    head = torch.nn.Linear(256, ways)
    head = l2l.algorithms.MAML(head, lr=fast_lr)
    head.to(device)

    # Setup optimization
    all_parameters = list(features.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(iters):
        optimizer.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_bsz):
            # Compute meta-training loss
            learner = head.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, adapt_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = head.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, adapt_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = head.clone()
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, features, loss, adapt_steps, shots, ways, device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('Meta Valid Error', meta_valid_error / meta_bsz)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
        print('Meta Test Error', meta_test_error / meta_bsz)
        print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            p.grad.data.mul_(1.0 / meta_bsz)
        optimizer.step()
def main(
        ways=5,
        train_shots=15,
        test_shots=5,
        meta_lr=1.0,
        meta_mom=0.0,
        meta_bsz=5,
        fast_lr=0.001,
        train_bsz=10,
        test_bsz=15,
        train_adapt_steps=8,
        test_adapt_steps=50,
        num_iterations=100000,
        test_interval=100,
        adam=0,  # Use adam or sgd for fast-adapt
        meta_decay=1,  # Linearly decay the meta-lr or not
        cuda=1,
        seed=42,
):
    cuda = bool(cuda)
    use_adam = bool(adam)
    meta_decay = bool(meta_decay)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, 2 * train_shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, 2 * test_shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, 2 * test_shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    if use_adam:
        opt = optim.Adam(model.parameters(), meta_lr, betas=(meta_mom, 0.999))
    else:
        opt = optim.SGD(model.parameters(), lr=meta_lr, momentum=meta_mom)
    adapt_opt = optim.Adam(model.parameters(), lr=fast_lr, betas=(0, 0.999))
    adapt_opt_state = adapt_opt.state_dict()
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        # anneal meta-lr
        if meta_decay:
            frac_done = float(iteration) / num_iterations
            new_lr = frac_done * meta_lr + (1 - frac_done) * meta_lr
            for pg in opt.param_groups:
                pg['lr'] = new_lr

        # zero-grad the parameters
        for p in model.parameters():
            p.grad = torch.zeros_like(p.data)

        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_bsz):
            # Compute meta-training loss
            learner = deepcopy(model)
            adapt_opt = optim.Adam(learner.parameters(), lr=fast_lr, betas=(0, 0.999))
            adapt_opt.load_state_dict(adapt_opt_state)
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, fast_lr, loss,
                adapt_steps=train_adapt_steps, batch_size=train_bsz,
                opt=adapt_opt, shots=train_shots, ways=ways, device=device)
            adapt_opt_state = adapt_opt.state_dict()
            for p, l in zip(model.parameters(), learner.parameters()):
                p.grad.data.add_(-1.0, l.data)
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            if iteration % test_interval == 0:
                # Compute meta-validation loss
                learner = deepcopy(model)
                adapt_opt = optim.Adam(learner.parameters(), lr=fast_lr, betas=(0, 0.999))
                adapt_opt.load_state_dict(adapt_opt_state)
                batch = valid_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(
                    batch, learner, fast_lr, loss,
                    adapt_steps=test_adapt_steps, batch_size=test_bsz,
                    opt=adapt_opt, shots=test_shots, ways=ways, device=device)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

                # Compute meta-testing loss
                learner = deepcopy(model)
                adapt_opt = optim.Adam(learner.parameters(), lr=fast_lr, betas=(0, 0.999))
                adapt_opt.load_state_dict(adapt_opt_state)
                batch = test_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(
                    batch, learner, fast_lr, loss,
                    adapt_steps=test_adapt_steps, batch_size=test_bsz,
                    opt=adapt_opt, shots=test_shots, ways=ways, device=device)
                meta_test_error += evaluation_error.item()
                meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        if iteration % test_interval == 0:
            print('Meta Valid Error', meta_valid_error / meta_bsz)
            print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
            print('Meta Test Error', meta_test_error / meta_bsz)
            print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_bsz).add_(p.data)
        opt.step()
def main(
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, 2 * shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, 2 * shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)
    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, 2 * shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')
    training_accuracy = torch.ones(num_iterations)
    test_accuracy = torch.ones(num_iterations)
    running_time = np.ones(num_iterations)

    import time
    start_time = time.time()
    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-test loss
            learner = maml.clone()
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        training_accuracy[iteration] = meta_train_accuracy / meta_batch_size
        test_accuracy[iteration] = meta_test_accuracy / meta_batch_size

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()
        end_time = time.time()
        running_time[iteration] = end_time - start_time
    print('total running time', end_time - start_time)

    # meta_test_error = 0.0
    # meta_test_accuracy = 0.0
    # for task in range(meta_batch_size):
    #     # Compute meta-testing loss
    #     learner = maml.clone()
    #     batch = test_tasks.sample()
    #     evaluation_error, evaluation_accuracy = fast_adapt(
    #         batch, learner, loss, adaptation_steps, shots, ways, device)
    #     meta_test_error += evaluation_error.item()
    #     meta_test_accuracy += evaluation_accuracy.item()
    # print('Meta Test Error', meta_test_error / meta_batch_size)
    # print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)
    return training_accuracy.numpy(), test_accuracy.numpy(), running_time
def tiered_imagenet_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    data_augmentation=None,
    device=None,
    **kwargs,
):
    """Tasksets for tiered-ImageNet benchmarks."""
    if data_augmentation is None:
        to_tensor = ToTensor() if device is None else lambda x: x
        train_data_transforms = Compose([
            to_tensor,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'normalize':
        to_tensor = ToTensor() if device is None else lambda x: x
        train_data_transforms = Compose([
            to_tensor,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'lee2019':
        normalize = Normalize(
            mean=[120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0],
            std=[70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0],
        )
        to_pil = ToPILImage() if device is not None else lambda x: x
        to_tensor = ToTensor() if device is None else lambda x: x
        train_data_transforms = Compose([
            to_pil,
            RandomCrop(84, padding=8),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])
        test_data_transforms = Compose([
            to_tensor,
            normalize,
        ])
    else:
        raise ValueError('Invalid data_augmentation argument.')
    train_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='train',
        transform=ToTensor(),
        download=True,
    )
    valid_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='validation',
        transform=ToTensor(),
        download=True,
    )
    test_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='test',
        transform=ToTensor(),
        download=True,
    )
    if device is None:
        train_dataset.transform = train_data_transforms
        valid_dataset.transform = test_data_transforms
        test_dataset.transform = test_data_transforms
    else:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            transform=train_data_transforms,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            transform=test_data_transforms,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            transform=test_data_transforms,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
def get_few_shot_tasksets(
    root='data',
    dataset='cifar10-fc100',
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    n_train_tasks=2000,
    n_test_tasks=1000,
):
    """
    Fetch the train, valid, and test meta-tasks of the given dataset.

    :param root: data directory.
    :param dataset: name of the dataset.
    :param train_ways: number of ways for few-shot training.
    :param train_samples: number of per-way samples in a training task.
    :param test_ways: number of ways for few-shot validation and testing.
    :param test_samples: number of per-way samples in a valid or test task.
    :param n_train_tasks: total number of train tasks.
    :param n_test_tasks: total number of valid and test tasks.
    :return: a BenchmarkTasksets of (train, valid, test) TaskDatasets.
    """
    train_dataset, valid_dataset, test_dataset = get_normal_tasksets(
        root=root, dataset=dataset)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)
    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]

    # Instantiate the tasksets
    train_tasks = l2l.data.TaskDataset(
        dataset=train_dataset,
        task_transforms=train_transforms,
        num_tasks=n_train_tasks,
    )
    valid_tasks = l2l.data.TaskDataset(
        dataset=valid_dataset,
        task_transforms=test_transforms,
        num_tasks=n_test_tasks,
    )
    test_tasks = l2l.data.TaskDataset(
        dataset=test_dataset,
        task_transforms=test_transforms,
        num_tasks=n_test_tasks,
    )
    return BenchmarkTasksets(train_tasks, valid_tasks, test_tasks)
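# Hedged usage sketch: consuming the BenchmarkTasksets returned by
# get_few_shot_tasksets(). The tuple unpacking below only assumes the
# (train, valid, test) ordering used in the return statement above; the
# helper name is illustrative.
def demo_few_shot_tasksets():
    train_tasks, valid_tasks, test_tasks = get_few_shot_tasksets(
        root='data', dataset='cifar10-fc100', train_ways=5, train_samples=10)
    X, y = train_tasks.sample()  # one 5-way training task, 10 samples per way
    print('train task:', X.shape, y.shape)
    X, y = test_tasks.sample()
    print('test task:', X.shape, y.shape)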
def main(
        ways=4,
        shots=5,
        meta_lr=0.01,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=10000,
        cuda=True,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda:
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create dataset
    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, shots * 2),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset)
    ]
    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, shots * 2),
        LoadData(valid_dataset),
        RemapLabels(valid_dataset),
        ConsecutiveLabels(valid_dataset)
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=1000)
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=200)

    # Create model
    model = Discriminator(ways=ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')
    best_acc = 0.0
    writer = SummaryWriter("/home/sever2users/Desktop/MetaGAN/maml/viz")

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        if iteration % 500 == 0:
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / meta_batch_size)
            print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
            print('Meta Valid Error', meta_valid_error / meta_batch_size)
            print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        writer.add_scalar(
            "{}way_{}shot_{}train_as/accuracy/valid".format(
                ways, shots, adaptation_steps),
            (meta_valid_accuracy / meta_batch_size), iteration)
        writer.add_scalar(
            "{}way_{}shot_{}train_as/accuracy/train".format(
                ways, shots, adaptation_steps),
            (meta_train_accuracy / meta_batch_size), iteration)
        writer.add_scalar(
            "way_{}shot_{}train_as/loss/valid".format(ways, shots,
                                                      adaptation_steps),
            (meta_valid_error / meta_batch_size), iteration)
        writer.add_scalar(
            "way_{}shot_{}train_as/loss/train".format(ways, shots,
                                                      adaptation_steps),
            (meta_train_error / meta_batch_size), iteration)
        if (meta_valid_accuracy / meta_batch_size) > best_acc:
            maml_test = maml.clone()
            best_acc = (meta_valid_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    print("\n")
    print(ways, "WAYS", shots, "SHOTS", adaptation_steps, "train_adaptation_steps")
    x = []
    for a_s in [1, 3, 10, 20]:
        for fast_lr in [0.5, 0.1, 0.05, 0.01]:
            l, a = test(fast_lr, loss, a_s, shots, ways, device, maml_test, valid_tasks)
            x.append([a_s, fast_lr, a])
            print("test_a_s=", a_s, 'fast_lr', fast_lr, "acc=", a)
    np.savetxt(
        '/home/sever2users/Desktop/MetaGAN/maml/test_results/{}way_{}shot_{}train_adp_steps.txt'
        .format(ways, shots, adaptation_steps), np.array(x))
    torch.save(
        maml_test.module.state_dict(),
        '/home/sever2users/Desktop/MetaGAN/maml/saved_models/state_dict_{}way_{}shot_{}train_adp_steps.pth'
        .format(ways, shots, adaptation_steps))
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, )),
    lambda x: x.view(1, 28, 28),
])
# transformations = transforms.Compose([transforms.ToTensor()])
# d = CustomDatasetFromCsvLocation("./mnist-demo.csv")
d = CustomDatasetFromCsvData('./mnist-demo.csv', 28, 28, transformations)
# import torch.dataset
t = l2l.data.MetaDataset(d)
train_tasks = l2l.data.TaskDataset(t,
                                   task_transforms=[
                                       NWays(t, n=3),
                                       KShots(t, k=2),
                                       LoadData(t),
                                       RemapLabels(t),
                                       ConsecutiveLabels(t)
                                   ],
                                   num_tasks=1000)
model = Net(3)
# model = Net(ways)
maml_lr = 0.01
lr = 0.005
iterations = 1000
tps = 32
fas = 5
shots = 1
ways = 3
model.to(device)
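# A hedged continuation sketch (not part of the original snippet): wrap the
# model defined above in MAML and run a simple meta-training loop over
# train_tasks. The names meta_model/opt/loss_fn are illustrative, and for
# brevity adaptation and evaluation reuse the same task samples.
meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
opt = torch.optim.Adam(meta_model.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()
for iteration in range(iterations):
    opt.zero_grad()
    for _ in range(tps):  # tasks per step
        learner = meta_model.clone()
        X, y = train_tasks.sample()
        X, y = X.to(device), y.to(device)
        for _ in range(fas):  # fast-adaptation steps
            learner.adapt(loss_fn(learner(X), y))
        loss_fn(learner(X), y).backward()  # accumulate meta-gradients
    for p in meta_model.parameters():
        p.grad.data.mul_(1.0 / tps)
    opt.step()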