Code Example #1
    def train_VAE(self, max_epochs, train_dataloader, validation_dataloader):
        early_stopping = EarlyStopping('{}/{}_autoencoder.pt'.format(
            self.state_path, self.model_name),
                                       patience=10)

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            train_loss = 0
            # Labels are ignored: the VAE is trained unsupervised.
            for batch_idx, (data, _) in enumerate(train_dataloader):
                self.VAE.train()

                data = data.to(self.device)

                self.VAE_optim.zero_grad()

                batch_params = self.VAE(data)

                loss = self.VAE_criterion(batch_params, data)

                train_loss += loss.item()

                loss.backward()
                self.VAE_optim.step()

            if validation_dataloader is not None:
                validation_loss = self.unsupervised_validation_loss(
                    validation_dataloader)
                early_stopping(validation_loss, self.VAE)

        if validation_dataloader is not None:
            early_stopping.load_checkpoint(self.VAE)
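
All six examples depend on an EarlyStopping helper whose implementation is not shown. Below is a minimal sketch consistent with the call sites: the constructor takes a checkpoint path and an optional patience, the instance is called with a metric to minimise together with the model, it exposes an early_stop flag, and load_checkpoint restores the best weights. The real class may differ.

import torch

class EarlyStopping:
    """Checkpoint on improvement; stop after `patience` epochs without it."""

    def __init__(self, checkpoint_path, patience=7):
        self.checkpoint_path = checkpoint_path
        self.patience = patience
        self.best_score = None
        self.num_bad_epochs = 0
        self.early_stop = False

    def __call__(self, metric, model):
        # Callers pass a quantity to minimise (a loss, or 1 - accuracy).
        if self.best_score is None or metric < self.best_score:
            self.best_score = metric
            self.num_bad_epochs = 0
            torch.save(model.state_dict(), self.checkpoint_path)
        else:
            self.num_bad_epochs += 1
            if self.num_bad_epochs >= self.patience:
                self.early_stop = True

    def load_checkpoint(self, model):
        # Restore the weights saved at the best validation score.
        model.load_state_dict(torch.load(self.checkpoint_path))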
Code Example #2
    def train_classifier(self, max_epochs, train_dataloader,
                         validation_dataloader, comparison):
        epochs = []
        train_losses = []
        validation_accs = []

        early_stopping = EarlyStopping('{}/{}_classifier.pt'.format(
            self.model_name, self.dataset_name))

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            for batch_idx, (data, labels) in enumerate(train_dataloader):
                self.Classifier.train()

                data = data.float().to(self.device)
                labels = labels.to(self.device)

                self.Classifier_optim.zero_grad()

                # Freeze the encoder: compute features without gradient
                # tracking so only the classifier is updated.
                with torch.no_grad():
                    z = self.Encoder(data)

                pred = self.Classifier(z)

                loss = self.Classifier_criterion(pred, labels)

                loss.backward()
                self.Classifier_optim.step()

                # Optional per-batch logging for run comparison (costly:
                # a full validation pass after every batch).
                if comparison:
                    epochs.append(epoch)
                    train_losses.append(loss.item())
                    validation_accs.append(
                        self.accuracy(validation_dataloader))

            val = self.accuracy(validation_dataloader)

            print('Supervised Epoch: {} Validation acc: {}'.format(epoch, val))

            # EarlyStopping minimises its metric, hence 1 - accuracy.
            early_stopping(1 - val, self.Classifier)

        if early_stopping.early_stop:
            early_stopping.load_checkpoint(self.Classifier)

        return epochs, train_losses, validation_accs
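
The classifier examples also call a self.accuracy(...) method that is not shown. A plausible free-function version is sketched below, assuming a single-output encoder as in Example #2; classification_accuracy is a hypothetical name, and Example #3's tuple-returning encoder (or Example #5's extra argument) would need small adjustments.

import torch

def classification_accuracy(encoder, classifier, dataloader, device):
    # Fraction of correctly classified examples in the loader.
    classifier.eval()
    correct = total = 0
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.float().to(device)
            labels = labels.to(device)
            logits = classifier(encoder(data))
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total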
Code Example #3
    def train_classifier(self, max_epochs, train_dataloader,
                         validation_dataloader):
        epochs = []
        train_losses = []
        validation_accs = []

        early_stopping = EarlyStopping('{}/{}_classifier.pt'.format(
            self.state_path, self.model_name))

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            train_loss = 0
            for batch_idx, (data, labels) in enumerate(train_dataloader):
                self.Classifier.train()

                data = data.to(self.device)
                labels = labels.to(self.device)

                self.Classifier_optim.zero_grad()

                # Encoder is frozen; keep only the latent code z and
                # discard its remaining outputs.
                with torch.no_grad():
                    z, _, _ = self.Encoder(data)

                pred = self.Classifier(z)

                loss = self.Classifier_criterion(pred, labels)

                loss.backward()
                self.Classifier_optim.step()

                train_loss += loss.item()

            if validation_dataloader is not None:
                acc = self.accuracy(validation_dataloader)
                validation_accs.append(acc)

                early_stopping(1 - acc, self.Classifier)

            epochs.append(epoch)
            train_losses.append(train_loss / len(train_dataloader))

        if validation_dataloader is not None:
            early_stopping.load_checkpoint(self.Classifier)

        return epochs, train_losses, validation_accs
Code Example #4
    def train_autoencoder(self, max_epochs, train_dataloader,
                          validation_dataloader):
        early_stopping = EarlyStopping('{}/{}_autoencoder.pt'.format(
            self.model_name, self.dataset_name),
                                       patience=10)

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_dataloader):
                self.Autoencoder.train()

                data = data.to(self.device)

                self.Autoencoder_optim.zero_grad()

                recons = self.Autoencoder(data)

                loss = self.Autoencoder_criterion(recons, data)

                train_loss += loss.item()

                loss.backward()
                self.Autoencoder_optim.step()

            validation_loss = unsupervised_validation_loss(
                self.Autoencoder, validation_dataloader,
                self.Autoencoder_criterion, self.device)

            early_stopping(validation_loss, self.Autoencoder)

            print('Unsupervised Epoch: {} Loss: {} Validation loss: {}'.format(
                epoch, train_loss, validation_loss))

        if early_stopping.early_stop:
            early_stopping.load_checkpoint(self.Autoencoder)
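
Example #4 calls unsupervised_validation_loss as a free function (Example #1 uses a method of the same name). A minimal sketch matching the call site, assuming the criterion takes (reconstruction, input) as in the training loop:

import torch

def unsupervised_validation_loss(model, dataloader, criterion, device):
    # Mean reconstruction loss over the validation set, gradients off.
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data, _ in dataloader:
            data = data.to(device)
            total += criterion(model(data), data).item()
    return total / len(dataloader)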
Code Example #5
    def train_ladder(self, max_epochs, supervised_dataloader,
                     unsupervised_dataloader, validation_dataloader):
        epochs = []
        train_losses = []
        validation_accs = []

        early_stopping = EarlyStopping('{}/{}_inner.pt'.format(
            self.state_path, self.model_name))

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            train_loss = 0
            for batch_idx, (labelled_data, unlabelled_data) in enumerate(
                    zip(cycle(supervised_dataloader),
                        unsupervised_dataloader)):
                self.ladder.train()

                self.optimizer.zero_grad()

                labelled_images, labels = labelled_data
                labelled_images = labelled_images.to(self.device)
                labels = labels.to(self.device)

                unlabelled_images, _ = unlabelled_data
                unlabelled_images = unlabelled_images.to(self.device)

                inputs = torch.cat((labelled_images, unlabelled_images), 0)

                batch_size = labelled_images.size(0)

                # Two encoder passes: a corrupted one (Gaussian noise of
                # std noise_std) and a clean one providing the denoising
                # targets.
                y_c, corr = self.ladder.forward_encoders(
                    inputs, self.noise_std, True, batch_size)
                y, clean = self.ladder.forward_encoders(
                    inputs, 0.0, True, batch_size)

                # Decoder path denoises the corrupted activations.
                z_est_bn = self.ladder.forward_decoders(
                    F.softmax(y_c, dim=1), corr, clean, batch_size)

                # Supervised cost on the labelled slice of the corrupted
                # encoder output.
                cost = self.supervised_cost_function.forward(
                    labeled(y_c, batch_size), labels)

                zs = clean['unlabeled']['z']

                # Layer-wise denoising cost, weighted by denoising_cost.
                u_cost = 0
                for l in range(self.L, -1, -1):
                    u_cost += self.unsupervised_cost_function.forward(
                        z_est_bn[l], zs[l]) * self.denoising_cost[l]

                loss = cost + u_cost

                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            if validation_dataloader is not None:
                acc = self.accuracy(validation_dataloader, 0)
                validation_accs.append(acc)
                early_stopping(1 - acc, self.ladder)

            epochs.append(epoch)
            train_losses.append(train_loss / len(unsupervised_dataloader))

        if validation_dataloader is not None:
            early_stopping.load_checkpoint(self.ladder)

        return epochs, train_losses, validation_accs
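
The ladder example slices the supervised part of each output back out with a labeled() helper. Given the torch.cat((labelled_images, unlabelled_images), 0) ordering above, a plausible definition is:

def labeled(tensor, batch_size):
    # The batch was built with the labelled examples first, so they
    # occupy the first batch_size rows.
    return tensor[:batch_size]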
Code Example #6
    def train_m2(self, max_epochs, labelled_loader, unlabelled_loader,
                 validation_loader):

        # Weight on the classification term, scaled by the ratio of
        # unlabelled to labelled examples.
        if unlabelled_loader is None:
            alpha = 1
        else:
            alpha = 0.1 * len(unlabelled_loader.dataset) / len(
                labelled_loader.dataset)

        epochs = []
        train_losses = []
        validation_accs = []

        early_stopping = EarlyStopping('{}/{}_inner.pt'.format(
            self.state_path, self.model_name))

        for epoch in range(max_epochs):
            if early_stopping.early_stop:
                break

            # Pair each unlabelled batch with a labelled one, repeating
            # the smaller labelled loader; with no unlabelled data, pair
            # each labelled batch with None instead.
            if unlabelled_loader is not None:
                data_iterator = zip(cycle(labelled_loader), unlabelled_loader)
            else:
                data_iterator = zip(labelled_loader, cycle([None]))

            train_loss = 0
            for batch_idx, (labelled_data,
                            unlabelled_data) in enumerate(data_iterator):
                self.M2.train()
                self.optimizer.zero_grad()

                labelled_images, labels = labelled_data
                labelled_images = labelled_images.float().to(self.device)
                labels = labels.to(self.device)

                labelled_predictions = self.M2.classify(labelled_images)
                labelled_loss = F.cross_entropy(labelled_predictions, labels)

                # labelled images ELBO
                L = self.elbo(labelled_images, y=labels)

                loss = L + alpha * labelled_loss

                if unlabelled_data is not None:
                    unlabelled_images, _ = unlabelled_data
                    unlabelled_images = unlabelled_images.float().to(
                        self.device)

                    # unlabelled images ELBO
                    U = self.elbo(unlabelled_images)

                    loss += U

                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            if validation_loader is not None:
                acc = self.accuracy(validation_loader)
                validation_accs.append(acc)
                early_stopping(1 - acc, self.M2)

            epochs.append(epoch)
            train_losses.append(train_loss)

        if validation_loader is not None:
            early_stopping.load_checkpoint(self.M2)

        return epochs, train_losses, validation_accs
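
Examples #5 and #6 both pair their loaders with zip(cycle(...), ...): the epoch length follows the longer unlabelled loader while the smaller labelled loader repeats. A toy illustration with placeholder batches:

from itertools import cycle

labelled_batches = ['L0', 'L1']
unlabelled_batches = ['U0', 'U1', 'U2', 'U3', 'U4']
print(list(zip(cycle(labelled_batches), unlabelled_batches)))
# [('L0', 'U0'), ('L1', 'U1'), ('L0', 'U2'), ('L1', 'U3'), ('L0', 'U4')]

With no unlabelled data, Example #6 instead zips the labelled loader against cycle([None]), pairing every labelled batch with None so the unsupervised term is skipped.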