Example #1
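The three examples on this page are methods of an incremental-learning trainer class, shown without their enclosing module. Judging from the calls they make, the surrounding file needs imports along these lines (utils is a project-specific helper module, not a published package):

from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.nn import MSELoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader

import utils  # project helpers: getOneHot, updateNet, augmentation, formatExemplars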
    def __train__(self, data, exemplars, net, n_classes, stabilize=False):
        step = int(n_classes / 10) - 1
        BATCH_SIZE = self.params['BATCH_SIZE']
        MOMENTUM = self.params['MOMENTUM']
        WEIGHT_DECAY = self.params['WEIGHT_DECAY']
        lambda_ = self.params['lambda']

        if not stabilize:
            print('\n ### Update Representation ###')
            # Anneal the weight decay linearly from its original value down
            # to one tenth of it across the 10 incremental steps
            WEIGHT_DECAY = np.linspace(WEIGHT_DECAY, WEIGHT_DECAY / 10,
                                       10)[step]
            EPOCHS = self.params['EPOCHS']
            LR = self.params['LR']
            delta = self.params['delta']
            lambda_ += delta * (step - 1)
            milestones = {49, 63}

            if len(exemplars) != 0:
                data = data + utils.formatExemplars(exemplars)
                # Save network for distillation
                old_net = deepcopy(net)
                old_net.eval()
                self.teachers.append(old_net)
                # Update network's last layer
                net = utils.updateNet(net, n_classes)

        else:
            print('\n ### Stabilize Network ###')
            EPOCHS = self.params['EPOCHS2']
            LR = self.params['LR2']
            milestones = {EPOCHS // 3, 2 * EPOCHS // 3}
            data = utils.formatExemplars(exemplars)

        # Define Loss
        criterion = MSELoss()
        # Define Dataloader
        loader = DataLoader(data,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)
        net = net.to(self.device)
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=LR,
                                    momentum=MOMENTUM,
                                    weight_decay=WEIGHT_DECAY)

        for epoch in range(EPOCHS):

            # LR step down policy
            if epoch + 1 in milestones:
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] / 5

            # Set module in training mode
            net.train()

            running_loss = 0.0
            for images, labels in loader:
                # Data augmentation
                images = images.to(self.device)
                images = torch.stack(
                    [utils.augmentation(image) for image in images])
                # Get One Hot Encoding for the labels
                labels = utils.getOneHot(labels, n_classes)
                labels = labels.to(self.device)

                # Zero-ing the gradients
                optimizer.zero_grad()
                # Forward pass to the network
                outputs = torch.sigmoid(net(images))

                # Compute Losses
                if n_classes == 10 or stabilize:
                    tot_loss = criterion(outputs, labels)
                else:
                    with torch.no_grad():
                        old_outputs = torch.sigmoid(
                            self.__getOldOutputs__(n_classes, images))
                    class_loss = criterion(outputs, labels)
                    distill_loss = criterion(
                        torch.sqrt(outputs[:, :n_classes - 10]),
                        torch.sqrt(old_outputs))
                    tot_loss = class_loss + distill_loss * lambda_

                # Update Running Loss
                running_loss += tot_loss.item() * images.size(0)

                tot_loss.backward()
                optimizer.step()

            # Train loss of current epoch
            train_loss = running_loss / len(data)
            print('\r   # Epoch: {}/{}, LR = {},  Train loss = {}'.format(
                epoch + 1, EPOCHS, optimizer.param_groups[0]['lr'],
                round(train_loss, 5)),
                  end='')
        print()

        return net
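Example #1 relies on two pieces that are not shown here: utils.updateNet, which grows the classifier head when new classes arrive, and self.__getOldOutputs__, which produces the stored teachers' responses for distillation. The sketches below are reconstructions under stated assumptions, not the original code: updateNet assumes the backbone exposes its classifier as net.fc, and getOldOutputs (written here as a free function taking the teacher list explicitly) assumes each teacher in self.teachers contributes the logits of the 10 classes it learned last.

def updateNet(net, n_classes):
    # Replace the classifier with a wider one, copying the old weights
    # so previously learned classes keep their trained output units.
    old_fc = net.fc
    net.fc = nn.Linear(old_fc.in_features, n_classes)
    with torch.no_grad():
        net.fc.weight[:old_fc.out_features] = old_fc.weight
        net.fc.bias[:old_fc.out_features] = old_fc.bias
    return net

def getOldOutputs(teachers, images):
    # Teacher i was trained on 10 * (i + 1) classes; keep only the 10
    # logits it learned last so no class is covered twice. Concatenated,
    # the slices span all n_classes - 10 previously seen classes.
    outputs = [teacher(images)[:, 10 * i:10 * (i + 1)]
               for i, teacher in enumerate(teachers)]
    return torch.cat(outputs, dim=1)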
Example #2
    def __trainTask__(self, data, net, n_classes):
        print('Training task')
        BATCH_SIZE = self.params['BATCH_SIZE']
        MOMENTUM = self.params['MOMENTUM']
        WEIGHT_DECAY = self.params['WEIGHT_DECAY']
        EPOCHS = self.params['EPOCHS']
        LR = self.params['LR']
        milestones = {int(7 / 10 * EPOCHS), int(9 / 10 * EPOCHS)}

        # Define Loss
        criterion = MSELoss()
        # Define Dataloader
        loader = DataLoader(data,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

        # Fresh 10-way classification head on top of the 64-dim features
        net.fc = nn.Linear(64, 10)
        net = net.to(self.device)
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=LR,
                                    momentum=MOMENTUM,
                                    weight_decay=WEIGHT_DECAY)

        for epoch in range(EPOCHS):

            # LR step down policy
            if epoch + 1 in milestones:
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] / 5

            # Set module in training mode
            net.train()

            running_loss = 0.0
            for images, labels in loader:
                # Data augmentation
                images = images.to(self.device)
                images = torch.stack(
                    [utils.augmentation(image) for image in images])
                # Remap global labels to the task-local range [0, 10)
                labels = labels - (n_classes - 10)
                # Get One Hot Encoding for the labels
                labels = utils.getOneHot(labels, 10)
                labels = labels.to(self.device)

                # Zero-ing the gradients
                optimizer.zero_grad()
                # Forward pass to the network
                outputs = torch.sigmoid(net(images))

                # Compute Losses
                tot_loss = criterion(outputs, labels)

                # Update Running Loss
                running_loss += tot_loss.item() * images.size(0)

                tot_loss.backward()
                optimizer.step()

            # Train loss of current epoch
            train_loss = running_loss / len(data)
            print('\r   # Epoch: {}/{}, LR = {},  Train loss = {}'.format(
                epoch + 1, EPOCHS, optimizer.param_groups[0]['lr'],
                round(train_loss, 5)),
                  end='')
        print()

        self.nets.append(deepcopy(net))

        return net
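All three examples encode the integer labels with utils.getOneHot before feeding them to an element-wise loss. The helper itself is not shown; since MSELoss and BCEWithLogitsLoss expect float targets of shape [batch, n_classes], a minimal version could look like this (an assumed reconstruction, not the original):

import torch.nn.functional as F

def getOneHot(labels, n_classes):
    # Integer class indices -> float one-hot rows, one per sample.
    return F.one_hot(labels.long(), num_classes=n_classes).float()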
Example #3
    def __updateRepresentation__(self,
                                 data,
                                 exemplars,
                                 net,
                                 n_classes,
                                 fineTune=False):
        print('\n ### Update Representation ###')
        EPOCHS = self.params['EPOCHS']
        BATCH_SIZE = self.params['BATCH_SIZE']
        LR = self.params['LR']
        MOMENTUM = self.params['MOMENTUM']
        WEIGHT_DECAY = self.params['WEIGHT_DECAY']

        # Anneal the weight decay linearly from its original value down to
        # one tenth of it across the 10 incremental steps
        if self.decay_policy:
            step = int(n_classes / 10) - 1
            WEIGHT_DECAY = np.linspace(WEIGHT_DECAY, WEIGHT_DECAY / 10,
                                       10)[step]

        # Define Loss
        criterion = BCEWithLogitsLoss()

        if len(exemplars) != 0:
            data = data + exemplars

        # Define Dataloader
        loader = DataLoader(data,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

        if n_classes != 10:
            # Save network for distillation
            old_net = deepcopy(net)
            old_net.eval()
            # Update network's last layer
            net = utils.updateNet(net, n_classes)

        net = net.to(self.device)
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=LR,
                                    momentum=MOMENTUM,
                                    weight_decay=WEIGHT_DECAY)

        for epoch in range(EPOCHS):

            # LR step down policy
            if epoch + 1 in {49, 63}:
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] / 5

            # Set module in training mode
            net.train()

            running_loss = 0.0
            for images, labels in loader:
                images = images.to(self.device)
                images = torch.stack(
                    [utils.augmentation(image) for image in images])

                # Zero-ing the gradients
                optimizer.zero_grad()
                # Forward pass to the network
                outputs = net(images)
                # Get One Hot Encoding for the labels
                labels = utils.getOneHot(labels, n_classes)
                labels = labels.to(self.device)

                # Compute Losses
                if n_classes == 10 or fineTune:
                    tot_loss = criterion(outputs, labels)
                else:
                    with torch.no_grad():
                        old_outputs = torch.sigmoid(old_net(images))
                    targets = torch.cat(
                        (old_outputs, labels[:, n_classes - 10:]), 1)
                    tot_loss = criterion(outputs, targets)

                # Update Running Loss
                running_loss += tot_loss.item() * images.size(0)

                tot_loss.backward()
                optimizer.step()

            # Train loss of current epoch
            train_loss = running_loss / len(data)
            print('\r   # Epoch: {}/{}, LR = {},  Train loss = {}'.format(
                epoch + 1, EPOCHS, optimizer.param_groups[0]['lr'],
                round(train_loss, 5)),
                  end='')
        print()

        return net
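Example #3 combines the two objectives differently from Example #1: instead of summing a classification and a distillation term, it builds a single target matrix (sigmoid outputs of the frozen network for the old classes, one-hot labels for the new ones) and scores everything with one BCEWithLogitsLoss, as in iCaRL. A standalone illustration of the target construction with dummy tensors (shapes as in the second incremental step, 10 old + 10 new classes):

import torch

batch, n_classes = 4, 20
old_outputs = torch.sigmoid(torch.randn(batch, n_classes - 10))     # frozen net, old classes
labels = torch.zeros(batch, n_classes)
labels[torch.arange(batch), torch.randint(10, 20, (batch,))] = 1.0  # one-hot over the new classes

# Old-class columns come from the teacher, new-class columns from the labels.
targets = torch.cat((old_outputs, labels[:, n_classes - 10:]), 1)
assert targets.shape == (batch, n_classes)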