def test_dataloader(self):
    # 9 samples in batches of 4 -> batch shapes 4, 4 and 1
    dl = Datasets.random_dataloader_with_targets(4, 8, 9, 2)
    shapes = [[list(i.shape) for i in sample] for sample in dl]
    assert [[4, 8], [4, 1]] == shapes[0]
    assert [[4, 8], [4, 1]] == shapes[1]
    assert [[1, 8], [1, 1]] == shapes[2]
    targets = [sample[1] for sample in dl]
    # 2 target classes -> class ids never exceed 1
    assert 1 == torch.max(torch.cat(targets, dim=0))
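For context, a minimal stand-in for Datasets.random_dataloader_with_targets that reproduces the shapes asserted above; this is an assumption for illustration, not the project's actual implementation:

import torch
from torch.utils.data import DataLoader, TensorDataset

def random_dataloader_with_targets(batch_size, input_size, total_samples,
                                   num_classes, device='cpu'):
    # 9 samples with batch_size 4 yield batches of 4, 4 and 1; targets are
    # [N, 1] integer class ids drawn from range(num_classes)
    data = torch.randn(total_samples, input_size, device=device)
    targets = torch.randint(0, num_classes, (total_samples, 1), device=device)
    return DataLoader(TensorDataset(data, targets), batch_size=batch_size)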
Example #2
def load_teacher(experiment_id: int,
                 epoch: int) -> Tuple[MNISTTeacher, Config]:
    reader = sacred_utils.get_reader(experiment_id)
    config = Config(**reader.config)
    teacher = create_teacher(config)
    reader.load_model(teacher, 'teacher', epoch=epoch)
    teacher.eval()
    return teacher, config
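Hypothetical usage of the loader above (the experiment id and epoch are placeholders, not values from the source):

teacher, config = load_teacher(experiment_id=42, epoch=100)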
Example #3
def generate_teacher_images(batch_size: int, teacher: MNISTTeacher):
    # Single random batch with random class ids for the teacher to render
    loader = Datasets.random_dataloader_with_targets(
        batch_size,
        teacher.config.input_size,
        batch_size,
        teacher.config.target_classes,
        device='cuda')
    with torch.no_grad():
        data, target = next(iter(loader))
        target_one_hot = id_to_one_hot(
            target, teacher.config.target_classes).squeeze(1)
        teacher_output, teacher_target = teacher(data, target_one_hot)
    return teacher_output, teacher_target
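Every example here funnels targets through id_to_one_hot; a minimal stand-in consistent with the .squeeze(1) calls, assuming ids arrive as [batch, 1] integer tensors (again an assumption, not the project's code):

import torch
import torch.nn.functional as F

def id_to_one_hot(ids: torch.Tensor, num_classes: int) -> torch.Tensor:
    # [batch, 1] int64 class ids -> [batch, 1, num_classes] float one-hot,
    # which callers .squeeze(1) down to [batch, num_classes]
    return F.one_hot(ids, num_classes).float()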
Example #4
def generate_teacher_images_for_class(batch_size: int, teacher: MNISTTeacher,
                                      cls: int, config: Config):
    loader = Datasets.random_dataloader_with_targets(batch_size,
                                                     config.teacher_input_size,
                                                     batch_size,
                                                     config.mnist_classes,
                                                     device='cuda')
    with torch.no_grad():
        data, _ = next(iter(loader))  # the loader's random targets are discarded
        # Condition the whole batch on the requested class
        classes = torch.tensor([[cls]], device='cuda').expand(batch_size, 1)
        target_one_hot = id_to_one_hot(classes,
                                       config.mnist_classes).squeeze(1)
        teacher_output, teacher_target = teacher(data, target_one_hot)
    return teacher_output, teacher_target
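For instance, generate_teacher_images_for_class(64, teacher, 3, config) would render a batch of 64 synthetic inputs all conditioned on digit 3 (the argument values are illustrative).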
Example #5
def run_training_step(teacher: MNISTTeacher, c: Config,
                      input_data: Tensor) -> CompoundObserver:
    learner_factory = create_learner_factory(c)
    # The caller supplies the inputs; an earlier variant optionally shuffled
    # teacher.input_data here instead
    train_data = TrainData(lambda: input_data, teacher.input_target)
    test_loader = Datasets.mnist_dataloader(c.batch_size, train=False)
    learning_loop = LearningLoop(teacher,
                                 learner_factory,
                                 train_data,
                                 test_loader,
                                 learning_rate=c.learning_rate,
                                 learning_rate_learner=c.learning_rate_learner,
                                 train_samples=c.train_samples)

    observer = CompoundObserver(ObserverLevel.training)
    learning_loop.train_step(observer, c.inner_loop_steps, c.mnist_classes)
    return observer
Example #6
def main(_run: Run, _config):
    c = Config(**_config)
    learner_factory = create_learner_factory(c)
    teacher = create_teacher(c)
    train_data = TrainData(lambda: teacher.input_data, teacher.input_target)
    test_loader = Datasets.mnist_dataloader(c.batch_size, train=False)

    learning_loop = LearningLoop(teacher,
                                 learner_factory,
                                 train_data,
                                 test_loader,
                                 learning_rate=c.learning_rate,
                                 learning_rate_learner=c.learning_rate_learner,
                                 train_samples=c.train_samples)
    for epoch in tqdm(range(1, c.epochs + 1)):
        # Collect observations only every experiment_loss_save_period epochs
        observer = (CompoundObserver(ObserverLevel.training)
                    if epoch % c.experiment_loss_save_period == 0 else None)
        learning_loop.train_step(observer, c.inner_loop_steps, c.mnist_classes)
        if observer is not None:
            sacred_writer.save_observer(observer.main, epoch)

        if epoch % c.experiment_agent_save_period == 0:
            sacred_writer.save_model(teacher, 'teacher', epoch)
Example #7
    def train_step(self, observer: Optional[CompoundObserver],
                   inner_loop_steps: int, target_classes: int):
        for model in self.models:
            model.train()

        teacher_input = self.train_data.data.to(self.device)
        teacher_input_target = self.train_data.target.to(self.device)
        self.optimizer_teacher.zero_grad()

        learner = self.learner_factory()
        learner.to(self.device)
        learner.train()

        # Learnable inner-loop hyperparameters owned by the teacher
        learner_lr = self.teacher.learner_optim_params[0]
        learner_momentum = self.teacher.learner_optim_params[1]
        optim = SGD(learner.parameters(),
                    lr=learner_lr.item(),
                    momentum=learner_momentum.item())
        # `override` replaces the scalar lr/momentum with the differentiable
        # tensors, so the outer loss can backpropagate into them
        with higher.innerloop_ctx(learner,
                                  optim,
                                  override={
                                      'lr': [learner_lr],
                                      'momentum': [learner_momentum]
                                  }) as (flearner, diffopt):
            for step in range(inner_loop_steps):
                teacher_output, teacher_target = self.teacher(
                    teacher_input[step],
                    id_to_one_hot(teacher_input_target[step],
                                  target_classes).squeeze(1))
                learner_output = flearner(teacher_output)
                # Earlier variants used F.nll_loss / F.cross_entropy against
                # hard targets; F.kl_div expects log-probabilities as input
                # and a probability distribution as target
                loss = F.kl_div(learner_output, teacher_target)
                diffopt.step(loss)  # differentiable inner-loop update

                if observer is not None:
                    o = observer.rollout.add_observer()
                    o.add_tensor('teacher_output', teacher_output[0],
                                 ObserverLevel.inference)
                    o.add_tensor('teacher_target', teacher_target[0],
                                 ObserverLevel.inference)
                    o.add_tensor('learner_output', learner_output[0],
                                 ObserverLevel.inference)

            # Outer (meta) objective: evaluate the adapted learner on one
            # batch of real MNIST training data
            train_loader = Datasets.mnist_dataloader(self.train_samples,
                                                     train=True)
            data, target = next(iter(train_loader))
            data, target = data.to(self.device), target.to(self.device)
            output = flearner(data)
            loss = F.nll_loss(output, target)
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct = pred.eq(target.view_as(pred)).sum().item()
            accuracy_train = correct / self.train_samples

            # Compute accuracy on Test MNIST
            test_batch_size = 512
            test_loader = Datasets.mnist_dataloader(test_batch_size,
                                                    train=False)
            correct = 0
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = flearner(data)
                pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
            accuracy = correct / len(test_loader.dataset)

            if observer is not None:
                observer.main.add_scalar('loss', loss.item())
                observer.main.add_scalar('accuracy_train', accuracy_train)
                observer.main.add_scalar('accuracy', accuracy)
                observer.main.add_scalar('learner_lr', learner_lr.item())
                observer.main.add_scalar('learner_momentum',
                                         learner_momentum.item())
                observer.main.add_tensor('teacher_input', teacher_input,
                                         ObserverLevel.training)
                observer.main.add_tensor('teacher_target',
                                         teacher_input_target,
                                         ObserverLevel.training)
            # Backpropagate through the unrolled inner loop into the teacher
            # (and the learnable lr/momentum)
            loss.backward()

        self.optimizer_teacher.step()
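The inner loop above is built on the higher library. A self-contained sketch of the same pattern, with a toy model and data standing in for the teacher/learner setup, meta-learning SGD's lr and momentum through the differentiable inner loop:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
import higher

learner = nn.Linear(8, 2)
lr = torch.tensor(0.1, requires_grad=True)        # meta-learned hyperparameter
momentum = torch.tensor(0.9, requires_grad=True)  # meta-learned hyperparameter
inner_opt = SGD(learner.parameters(), lr=lr.item(), momentum=momentum.item())

x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))

with higher.innerloop_ctx(learner, inner_opt,
                          override={'lr': [lr], 'momentum': [momentum]}) as (flearner, diffopt):
    for _ in range(3):
        inner_loss = F.cross_entropy(flearner(x), y)
        diffopt.step(inner_loss)  # differentiable parameter update
    outer_loss = F.cross_entropy(flearner(x), y)
    outer_loss.backward()  # gradients reach lr and momentum

print(lr.grad, momentum.grad)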
Example #8
def test_dataloader(self):
    # 9 samples, batch size 4 -> batches of 4, 4 and 1 rows of 8 features
    dl = Datasets.random_dataloader(4, 8, 9, 'cpu')
    shapes = [list(i.shape) for i in dl]
    assert [4, 8] == shapes[0]
    assert [4, 8] == shapes[1]
    assert [1, 8] == shapes[2]