Example 1
import gc

import numpy as np
from torch.utils.data import DataLoader, Subset
# LeNet, SGD, H (the project's tensor/autograd module), evaluate_dataset,
# fmnist_train and fmnist_test are assumed to come from the surrounding project.

m = 5000
train_data = Subset(fmnist_train, np.arange(m))          # Subset needs indices, not an int
val_data = Subset(fmnist_train, np.arange(m, m + 200))   # 200 held-out samples
val_data2 = Subset(fmnist_train, np.arange(m - 200, m))  # 200 samples also used for training
test_data = fmnist_test

net = LeNet(in_channels=1)
optimizer = SGD(net.parameters(), lr=0.003, momentum=0.9, weight_decay=5e-4)

batch_size = 64
data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
epochs = 100
for epoch in range(epochs):
    if epoch % 2 == 0:
        gc.collect()  # reclaim memory every other epoch
    print("Epoch %d" % (epoch + 1))
    for i, batch in enumerate(data_loader):
        inputs, labels = batch
        inputs = H.tensor(inputs.numpy())  # move the batch into the framework's tensors
        labels = H.tensor(labels.numpy())

        net.zero_grad()
        output = net(inputs)
        loss = H.CrossEntropyLoss(output, labels)
        loss.backward()
        optimizer.step()
    optimizer.reduce(0.95)  # decay the learning rate by 5% each epoch
    print(evaluate_dataset(net, val_data2, 64))  # accuracy on samples seen during training
    print(evaluate_dataset(net, val_data, 64))   # accuracy on held-out samples
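Two helpers used by these snippets, evaluate_dataset (above) and shuffle_unison (in the Reptile class below), are not defined in this section. A minimal sketch of both follows; it assumes H tensors expose their values through a .numpy() method, which is a guess about the custom framework rather than something the code above confirms.

def evaluate_dataset(net, dataset, batch_size):
    # Hypothetical helper: classification accuracy over a dataset.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct, total = 0, 0
    for inputs, labels in loader:
        output = net(H.tensor(inputs.numpy()))
        predictions = np.argmax(output.numpy(), axis=1)  # assumes H tensors have .numpy()
        correct += int((predictions == labels.numpy()).sum())
        total += len(labels)
    return correct / total

def shuffle_unison(x, y):
    # Hypothetical helper: shuffle two arrays with one shared permutation
    # so that sample/label pairs stay aligned.
    perm = np.random.permutation(len(x))
    return x[perm], y[perm]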
import os
from copy import deepcopy

import joblib
import numpy as np
import torch
import torch.nn as nn
# LeNet, TaskGen and shuffle_unison are assumed to come from the surrounding project.


class Reptile(object):
    def __init__(self, args):
        self.args = args
        self._load_model()

        self.model.to(args.device)
        self.task_generator = TaskGen(args.max_num_classes)
        self.outer_stepsize = args.outer_stepsize
        self.criterion = nn.CrossEntropyLoss()
        # self.optimizer = optim.Adam(self.model.parameters(), lr=args.inner_stepsize)

    def _load_model(self):
        self.model = LeNet()
        self.current_iteration = 0
        if os.path.exists(self.args.model_path):
            try:
                print("Loading model from: {}".format(self.args.model_path))
                self.model.load_state_dict(torch.load(self.args.model_path))
                self.current_iteration = joblib.load("{}.iter".format(
                    self.args.model_path))
            except Exception as e:
                print(
                    "Exception: {}\nCould not load model from {} - starting from scratch"
                    .format(e, self.args.model_path))

    def inner_training(self, x, y, num_iterations):
        """
        Run inner-loop training on a single task.
        """
        x, y = shuffle_unison(x, y)

        self.model.train()

        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        y = torch.tensor(y, dtype=torch.long, device=self.args.device)

        total_loss = 0
        for _ in range(num_iterations):
            start = np.random.randint(0,
                                      len(x) - self.args.inner_batch_size + 1)

            self.model.zero_grad()
            # self.optimizer.zero_grad()
            outputs = self.model(x[start:start + self.args.inner_batch_size])
            # print("output: {} - y: {}".format(outputs.shape, y.shape))
            # Variable is deprecated since PyTorch 0.4; a plain tensor works here,
            # and y is already a long tensor.
            loss = self.criterion(
                outputs, y[start:start + self.args.inner_batch_size])
            total_loss += loss.item()  # accumulate a float, not the autograd graph
            loss.backward()
            # self.optimizer.step()
            # Similar to calling optimizer.step()
            for param in self.model.parameters():
                param.data -= self.args.inner_stepsize * param.grad.data
        return total_loss / num_iterations

    def _meta_gradient_update(self, iteration, num_classes, weights_before):
        """
        Interpolate between the current weights and the weights trained on
        this task, i.e. treat (weights_before - weights_after) as the
        meta-gradient.

            - iteration: current iteration, used to anneal outer_stepsize
            - num_classes: number of classes of the current classifier (unused here)
            - weights_before: state of the weights before inner-loop training
        """
        weights_after = self.model.state_dict()
        outer_stepsize = self.outer_stepsize * (
            1 - iteration / self.args.n_iterations)  # linear schedule

        self.model.load_state_dict({
            name: weights_before[name] +
            (weights_after[name] - weights_before[name]) * outer_stepsize
            for name in weights_before
        })

    def meta_training(self):
        # Reptile training loop
        total_loss = 0
        try:
            while self.current_iteration < self.args.n_iterations:
                # Generate task
                data, labels, original_labels, num_classes = self.task_generator.get_train_task(
                    self.args.num_classes)

                weights_before = deepcopy(self.model.state_dict())
                loss = self.inner_training(data, labels,
                                           self.args.inner_iterations)
                total_loss += loss
                if self.current_iteration % self.args.log_every == 0:
                    print("-----------------------------")
                    print("iteration               {}".format(
                        self.current_iteration + 1))
                    print("Loss: {:.3f}".format(total_loss /
                                                (self.current_iteration + 1)))
                    print("Current task info: ")
                    print("\t- Number of classes: {}".format(num_classes))
                    print("\t- Batch size: {}".format(len(data)))
                    print("\t- Labels: {}".format(set(original_labels)))

                    self.test()

                self._meta_gradient_update(self.current_iteration, num_classes,
                                           weights_before)

                self.current_iteration += 1

            torch.save(self.model.state_dict(), self.args.model_path)

        except KeyboardInterrupt:
            print("Manual Interrupt...")
            print("Saving to: {}".format(self.args.model_path))
            torch.save(self.model.state_dict(), self.args.model_path)
            joblib.dump(self.current_iteration,
                        "{}.iter".format(self.args.model_path),
                        compress=1)

    def predict(self, x):
        self.model.eval()
        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = self.model(x)
        return outputs.cpu().numpy()

    def test(self):
        """
        Run tests:
            1. Create a task from the test set.
            2. Check zero-shot accuracy on the test set.
            3. Train for a few inner iterations on an enrollment task.
            4. Check accuracy on the test set again.
            5. Restore the pre-test weights.
        """

        test_data, test_labels, _, _ = self.task_generator.get_test_task(
            selected_labels=[1, 2, 3, 4, 5],
            num_samples=-1)  # -1 means use all available samples
        predicted_labels = np.argmax(self.predict(test_data), axis=1)
        accuracy = np.mean(predicted_labels == test_labels) * 100
        print("Accuracy before few-shot learning (i.e. zero-shot): {:.2f}%\n----"
              .format(accuracy))

        weights_before = deepcopy(
            self.model.state_dict())  # save snapshot before evaluation
        for i in range(1, 5):
            enroll_data, enroll_labels, _, _ = self.task_generator.get_enroll_task(
                selected_labels=[1, 2, 3, 4, 5], num_samples=i)
            # Note: the weights are not reset between shot counts, so each
            # pass trains on top of the previous one; the snapshot is only
            # restored after the loop.
            self.inner_training(enroll_data, enroll_labels,
                                self.args.inner_iterations_test)
            predicted_labels = np.argmax(self.predict(test_data), axis=1)
            accuracy = np.mean(predicted_labels == test_labels) * 100

            print("Accuracy after {} shot{} learning: {:.2f}%".format(
                i, "" if i == 1 else "s", accuracy))

        self.model.load_state_dict(weights_before)  # restore from snapshot
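
A hypothetical driver for the class above. Every value below is illustrative; only the attribute names are taken from how args is used in the code.

from argparse import Namespace

args = Namespace(
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_path="reptile_lenet.pt",  # checkpoint path (illustrative)
    max_num_classes=10,             # passed to TaskGen
    num_classes=5,                  # classes per training task
    n_iterations=10000,             # meta-training iterations
    log_every=100,                  # logging/testing frequency
    outer_stepsize=0.1,             # meta (outer-loop) step size
    inner_stepsize=0.02,            # inner-loop SGD step size
    inner_batch_size=10,
    inner_iterations=5,             # inner steps during meta-training
    inner_iterations_test=50,       # inner steps during few-shot evaluation
)

reptile = Reptile(args)
reptile.meta_training()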