def test_dataloader(self):
    """The with-targets loader yields (data, target) pairs; the last batch is short."""
    dl = Datasets.random_dataloader_with_targets(4, 8, 9, 2)
    shapes = [[list(i.shape) for i in sample] for sample in dl]
    assert [[4, 8], [4, 1]] == shapes[0]
    assert [[4, 8], [4, 1]] == shapes[1]
    assert [[1, 8], [1, 1]] == shapes[2]
    # with 2 classes, the highest target id seen should be 1
    targets = [sample[1] for sample in dl]
    assert 1 == torch.max(torch.cat(targets, dim=0)).item()

def load_teacher(experiment_id: int, epoch: int) -> Tuple[MNISTTeacher, Config]:
    """Restore a trained teacher and its config from a saved Sacred run."""
    reader = sacred_utils.get_reader(experiment_id)
    config = Config(**reader.config)
    teacher = create_teacher(config)
    teacher.eval()
    reader.load_model(teacher, 'teacher', epoch=epoch)
    return teacher, config

def generate_teacher_images(batch_size: int, teacher: MNISTTeacher):
    """Run one batch of random inputs through the teacher and return the generated images."""
    loader = Datasets.random_dataloader_with_targets(
        batch_size, teacher.config.input_size, batch_size,
        teacher.config.target_classes, device='cuda')
    with torch.no_grad():
        data, target = next(iter(loader))
        target_one_hot = id_to_one_hot(target, teacher.config.target_classes).squeeze(1)
        teacher_output, teacher_target = teacher(data, target_one_hot)
    return teacher_output, teacher_target

def generate_teacher_images_for_class(batch_size: int, teacher: MNISTTeacher,
                                      cls: int, config: Config):
    """Generate a batch of teacher images all conditioned on the same target class."""
    loader = Datasets.random_dataloader_with_targets(
        batch_size, config.teacher_input_size, batch_size,
        config.mnist_classes, device='cuda')
    with torch.no_grad():
        data, _ = next(iter(loader))
        classes = torch.tensor([[cls]], device='cuda').expand(batch_size, 1)
        target_one_hot = id_to_one_hot(classes, config.mnist_classes).squeeze(1)
        teacher_output, teacher_target = teacher(data, target_one_hot)
    return teacher_output, teacher_target

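# --- Hypothetical usage sketch (not from this repo) ---------------------------
# A quick way to eyeball what the two generator helpers above produce. All
# specifics below are assumptions for illustration only: the experiment id and
# epoch are placeholders, and the teacher is assumed to emit flat 784-dim
# MNIST-like vectors that can be reshaped to 28x28.
import matplotlib.pyplot as plt

def show_images_per_class():
    teacher, config = load_teacher(experiment_id=1, epoch=100)  # placeholder values
    fig, axes = plt.subplots(1, config.mnist_classes, figsize=(config.mnist_classes, 1.5))
    for cls, ax in enumerate(axes):
        images, _ = generate_teacher_images_for_class(1, teacher, cls, config)
        ax.imshow(images[0].cpu().view(28, 28), cmap='gray')  # 784 -> 28x28 assumed
        ax.set_title(str(cls))
        ax.axis('off')
    plt.show()
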
def run_training_step(teacher: MNISTTeacher, c: Config, input_data: Tensor) -> CompoundObserver:
    """Run a single teacher training step on the given input data and return the observations."""
    learner_factory = create_learner_factory(c)
    train_data = TrainData(lambda: input_data, teacher.input_target)
    test_loader = Datasets.mnist_dataloader(c.batch_size, train=False)
    learning_loop = LearningLoop(teacher, learner_factory, train_data, test_loader,
                                 learning_rate=c.learning_rate,
                                 learning_rate_learner=c.learning_rate_learner,
                                 train_samples=c.train_samples)
    observer = CompoundObserver(ObserverLevel.training)
    learning_loop.train_step(observer, c.inner_loop_steps, c.mnist_classes)
    return observer

def main(_run: Run, _config):
    c = Config(**_config)
    learner_factory = create_learner_factory(c)
    teacher = create_teacher(c)
    train_data = TrainData(lambda: teacher.input_data, teacher.input_target)
    test_loader = Datasets.mnist_dataloader(c.batch_size, train=False)
    learning_loop = LearningLoop(teacher, learner_factory, train_data, test_loader,
                                 learning_rate=c.learning_rate,
                                 learning_rate_learner=c.learning_rate_learner,
                                 train_samples=c.train_samples)
    for epoch in tqdm(range(1, c.epochs + 1)):
        observer = (CompoundObserver(ObserverLevel.training)
                    if epoch % c.experiment_loss_save_period == 0 else None)
        learning_loop.train_step(observer, c.inner_loop_steps, c.mnist_classes)
        if observer is not None:
            sacred_writer.save_observer(observer.main, epoch)
        if epoch % c.experiment_agent_save_period == 0:
            sacred_writer.save_model(teacher, 'teacher', epoch)

def train_step(self, observer: Optional[CompoundObserver], inner_loop_steps: int,
               target_classes: int):
    for model in self.models:
        model.train()
    teacher_input = self.train_data.data.to(self.device)
    teacher_input_target = self.train_data.target.to(self.device)
    self.optimizer_teacher.zero_grad()
    losses = []

    # a fresh learner every outer step, trained with the teacher's
    # meta-learned SGD hyperparameters
    learner = self.learner_factory()
    learner.to(self.device)
    learner.train()
    learner_lr = self.teacher.learner_optim_params[0]
    learner_momentum = self.teacher.learner_optim_params[1]
    optim = SGD(learner.parameters(), lr=learner_lr.item(),
                momentum=learner_momentum.item())

    with higher.innerloop_ctx(learner, optim,
                              override={'lr': [learner_lr],
                                        'momentum': [learner_momentum]}) as (flearner, diffopt):
        # inner loop: train the learner on teacher-generated data with
        # differentiable SGD updates
        for step in range(inner_loop_steps):
            teacher_output, teacher_target = self.teacher(
                teacher_input[step],
                id_to_one_hot(teacher_input_target[step], target_classes).squeeze(1))
            learner_output = flearner(teacher_output)
            # loss = F.nll_loss(learner_output, teacher_target)
            # loss = F.cross_entropy(learner_output, teacher_target)
            # kl_div expects log-probabilities as input and probabilities as target
            loss = F.kl_div(learner_output, teacher_target)
            diffopt.step(loss)
            losses.append(loss)
            if observer is not None:
                o = observer.rollout.add_observer()
                o.add_tensor('teacher_output', teacher_output[0], ObserverLevel.inference)
                o.add_tensor('teacher_target', teacher_target[0], ObserverLevel.inference)
                o.add_tensor('learner_output', learner_output[0], ObserverLevel.inference)

        # outer loss: evaluate the adapted learner on one batch of real train MNIST
        train_loader = Datasets.mnist_dataloader(self.train_samples, train=True)
        data, target = next(iter(train_loader))
        data, target = data.to(self.device), target.to(self.device)
        output = flearner(data)
        loss = F.nll_loss(output, target)
        pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
        correct = pred.eq(target.view_as(pred)).sum().item()
        accuracy_train = correct / self.train_samples

        # accuracy on test MNIST (reporting only)
        test_batch_size = 512
        test_loader = Datasets.mnist_dataloader(test_batch_size, train=False)
        correct = 0
        for data, target in test_loader:
            data, target = data.to(self.device), target.to(self.device)
            output = flearner(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(test_loader.dataset)

        if observer is not None:
            observer.main.add_scalar('loss', loss.item())
            observer.main.add_scalar('accuracy_train', accuracy_train)
            observer.main.add_scalar('accuracy', accuracy)
            observer.main.add_scalar('learner_lr', learner_lr.item())
            observer.main.add_scalar('learner_momentum', learner_momentum.item())
            observer.main.add_tensor('teacher_input', teacher_input, ObserverLevel.training)
            observer.main.add_tensor('teacher_target', teacher_input_target, ObserverLevel.training)

        # backpropagate the outer loss through the unrolled inner loop into the
        # teacher (and the meta-learned lr/momentum)
        loss.backward()

    self.optimizer_teacher.step()

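# --- Minimal sketch of the differentiable inner loop used in train_step -------
# Why: train_step relies on `higher` to make the learner's SGD updates part of
# the autograd graph, so the outer loss can backpropagate into whatever shaped
# the inner-loop training. Everything below is illustrative (toy model, made-up
# shapes and hyperparameters), not code from this repo.
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher
from torch.optim import SGD

learner = nn.Linear(8, 2)
meta_data = torch.randn(4, 8, requires_grad=True)   # stands in for teacher-generated data
labels = torch.randint(0, 2, (4,))
meta_opt = SGD([meta_data], lr=0.01)                # outer optimizer (cf. optimizer_teacher)

meta_opt.zero_grad()
inner_opt = SGD(learner.parameters(), lr=0.1)
with higher.innerloop_ctx(learner, inner_opt) as (flearner, diffopt):
    for _ in range(3):                              # inner-loop steps (cf. inner_loop_steps)
        inner_loss = F.cross_entropy(flearner(meta_data), labels)
        diffopt.step(inner_loss)                    # differentiable SGD update
    # outer loss on the adapted learner; gradients flow back through every
    # inner update into meta_data
    outer_loss = F.cross_entropy(flearner(meta_data), labels)
    outer_loss.backward()
meta_opt.step()
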
def test_dataloader(self):
    dl = Datasets.random_dataloader(4, 8, 9, 'cpu')
    shapes = [list(i.shape) for i in dl]
    assert [4, 8] == shapes[0]
    assert [4, 8] == shapes[1]
    assert [1, 8] == shapes[2]

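# --- Sketch: loaders satisfying the two dataloader tests -----------------------
# Why: this test and the with-targets test at the top of this section pin down
# the loader contract: `total` random rows of width `input_size`, yielded in
# batches of `batch_size`, with a short final batch (and, for the with-targets
# variant, a column of class ids per batch). The repo's real implementation may
# differ; this is only a minimal illustration of that contract.
import torch

def random_dataloader_sketch(batch_size, input_size, total, device='cpu'):
    data = torch.randn(total, input_size, device=device)
    return [data[i:i + batch_size] for i in range(0, total, batch_size)]

def random_dataloader_with_targets_sketch(batch_size, input_size, total, classes, device='cpu'):
    data = torch.randn(total, input_size, device=device)
    targets = torch.randint(0, classes, (total, 1), device=device)
    return [(data[i:i + batch_size], targets[i:i + batch_size])
            for i in range(0, total, batch_size)]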