Example #1
def train_mask():
    """
    Train an explainable mask for an ImageNet image, using a pretrained model.
    """
    model = torchvision.models.vgg19(pretrained=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.ToTensor(), normalize
    ])
    accuracy_measure = AccuracyArgmax()
    monitor = Monitor(test_loader=None, accuracy_measure=accuracy_measure)
    monitor.open(env_name='mask')
    monitor.normalize_inverse = NormalizeInverse(mean=normalize.mean,
                                                 std=normalize.std)
    image = Image.open(IMAGES_DIR / "flute.jpg")
    image = transform(image)
    mask_trainer = MaskTrainer(accuracy_measure=accuracy_measure,
                               image_shape=image.shape,
                               show_progress=True)
    monitor.log(repr(mask_trainer))
    if torch.cuda.is_available():
        model = model.cuda()
        image = image.cuda()
    outputs = model(image.unsqueeze(dim=0))
    proba = accuracy_measure.predict_proba(outputs)
    proba_max, label_true = proba[0].max(dim=0)
    print(f"True label: {label_true} (confidence {proba_max:.5f})")
    monitor.plot_mask(model=model,
                      mask_trainer=mask_trainer,
                      image=image,
                      label=label_true)
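
This example assumes the standard third-party imports below; the remaining names (IMAGES_DIR, AccuracyArgmax, Monitor, NormalizeInverse, MaskTrainer) are project-specific helpers assumed to be importable from the surrounding codebase, so their module paths are not shown. A minimal preamble sketch:

import torch
import torchvision
from PIL import Image

# Project-local helpers used above (import paths depend on the project layout):
# IMAGES_DIR, AccuracyArgmax, Monitor, NormalizeInverse, MaskTrainer
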
Example #2
class Trainer(ABC):
    watch_modules = (nn.Linear, nn.Conv2d)

    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 dataset_name: str,
                 accuracy_measure: Accuracy = None,
                 env_suffix='',
                 checkpoint_dir=CHECKPOINTS_DIR):
        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model
        self.criterion = criterion
        self.dataset_name = dataset_name
        self.checkpoint_dir = Path(checkpoint_dir)
        self.train_loader = get_data_loader(dataset_name, train=True)
        self.timer = timer
        self.timer.init(batches_in_epoch=len(self.train_loader))
        self.env_name = f"{time.strftime('%Y.%m.%d')} {self.model.__class__.__name__}: " \
                        f"{self.dataset_name} {self.__class__.__name__} {self.criterion.__class__.__name__}"
        if env_suffix:
            self.env_name = self.env_name + f' {env_suffix}'
        if accuracy_measure is None:
            if isinstance(self.criterion, PairLoss):
                accuracy_measure = AccuracyEmbedding()
            else:
                # cross entropy loss
                accuracy_measure = AccuracyArgmax()
        self.accuracy_measure = accuracy_measure
        self.monitor = Monitor(test_loader=get_data_loader(self.dataset_name,
                                                           train=False),
                               accuracy_measure=self.accuracy_measure)
        for name, layer in find_named_layers(self.model,
                                             layer_class=self.watch_modules):
            self.monitor.register_layer(layer, prefix=name)
        images, labels = next(iter(self.train_loader))
        self.mask_trainer = MaskTrainer(accuracy_measure=self.accuracy_measure,
                                        image_shape=images[0].shape)

    @property
    def checkpoint_path(self):
        return self.checkpoint_dir / (self.env_name + '.pt')

    def monitor_functions(self):
        pass

    def log_trainer(self):
        self.monitor.log(f"Criterion: {self.criterion}")
        self.monitor.log(repr(self.mask_trainer))

    @abstractmethod
    def train_batch(self, images, labels):
        raise NotImplementedError()

    def save(self):
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            torch.save(self.state_dict(), self.checkpoint_path)
        except PermissionError as error:
            print(error)

    def state_dict(self):
        return {
            "model_state": self.model.state_dict(),
            "epoch": self.timer.epoch,
            "env_name": self.env_name,
        }

    def restore(self, checkpoint_path=None, strict=True):
        """
        :param checkpoint_path: train checkpoint path to restore
        :param strict: model's load_state_dict strict argument
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        if not checkpoint_path.exists():
            print(
                f"Checkpoint '{checkpoint_path}' doesn't exist. Nothing to restore."
            )
            return None
        map_location = None
        if not torch.cuda.is_available():
            map_location = 'cpu'
        checkpoint_state = torch.load(checkpoint_path,
                                      map_location=map_location)
        try:
            self.model.load_state_dict(checkpoint_state['model_state'],
                                       strict=strict)
        except RuntimeError as error:
            print(
                f"An error occurred while restoring {checkpoint_path}: {error}"
            )
            return None
        self.env_name = checkpoint_state['env_name']
        self.timer.set_epoch(checkpoint_state['epoch'])
        self.monitor.open(env_name=self.env_name)
        print(f"Restored model state from {checkpoint_path}.")
        return checkpoint_state

    def _epoch_finished(self, epoch, outputs, labels):
        loss = self.criterion(outputs, labels)
        self.monitor.update_loss(loss, mode='full train')
        self.save()
        return loss

    def train_mask(self):
        """
        Train a mask to see which part of the image is crucial from the network's perspective.
        """
        images, labels = next(iter(self.train_loader))
        mode_saved = prepare_eval(self.model)
        if torch.cuda.is_available():
            images = images.cuda()
        with torch.no_grad():
            proba = self.accuracy_measure.predict_proba(self.model(images))
        proba_max, _ = proba.max(dim=1)
        sample_max_proba = proba_max.argmax()
        image = images[sample_max_proba]
        label = labels[sample_max_proba]
        self.monitor.plot_mask(self.model,
                               mask_trainer=self.mask_trainer,
                               image=image,
                               label=label)
        mode_saved.restore(self.model)
        return image, label

    def get_adversarial_examples(self, noise_ampl=100, n_iter=10):
        """
        :param noise_ampl: adversarial noise amplitude
        :param n_iter: adversarial iterations
        :return: adversarial examples
        """
        images, labels = next(iter(self.train_loader))
        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()
        images_orig = images.clone()
        images.requires_grad_(True)
        mode_saved = prepare_eval(self.model)
        for i in range(n_iter):
            images.grad = None
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            with torch.no_grad():
                adv_noise = noise_ampl * images.grad
                images += adv_noise
        images.requires_grad_(False)
        mode_saved.restore(self.model)
        return AdversarialExamples(original=images_orig,
                                   adversarial=images,
                                   labels=labels)

    def update_batch_accuracy(self, outputs, labels):
        self.accuracy_measure.save(outputs, labels)
        labels_predicted = self.accuracy_measure.predict(outputs)
        accuracy = calc_accuracy(labels, labels_predicted)
        self.monitor.update_accuracy(accuracy=accuracy, mode='batch')

    def train_epoch(self, epoch):
        """
        :param epoch: epoch id
        :return: last batch loss
        """
        loss_batch_average = MeanOnline()
        outputs = None
        use_cuda = torch.cuda.is_available()
        for images, labels in tqdm(self.train_loader,
                                   desc="Epoch {:d}".format(epoch),
                                   leave=False):
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            outputs, loss = self.train_batch(images, labels)
            loss_batch_average.update(loss.detach().cpu())
            for name, param in self.model.named_parameters():
                if torch.isnan(param).any():
                    warnings.warn(f"NaN parameters in '{name}'")
            self.monitor.batch_finished(self.model)

            # uncomment to see more detailed progress - at each batch instead of epoch
            # self.monitor.update_loss(loss=loss, mode='batch')
            # self.update_batch_accuracy(outputs, labels)
            # self.monitor.update_sparsity(outputs, mode='batch')
            # self.monitor.update_density(outputs, mode='batch')
            # self.monitor.activations_heatmap(outputs, labels)

        self.monitor.update_loss(loss=loss_batch_average.get_mean(),
                                 mode='batch')
        if not isinstance(self.accuracy_measure, AccuracyArgmax):
            self.monitor.update_sparsity(outputs, mode='batch')
            self.monitor.update_density(outputs, mode='batch')

    def train(self,
              n_epoch=10,
              epoch_update_step=1,
              mutual_info_layers=1,
              adversarial=False,
              mask_explain=False):
        """
        :param n_epoch: number of training epochs
        :param epoch_update_step: epoch step to run full evaluation
        :param mutual_info_layers: number of last layers to be monitored for mutual information;
                                   pass '0' to turn off this feature.
        :param adversarial: whether to run an adversarial attack test
        :param mask_explain: whether to train an image mask that 'explains' the network's behaviour
        """
        print(self.model)
        if not self.monitor.is_active:
            # new environment
            self.monitor.open(env_name=self.env_name)
            self.monitor.clear()
        self.monitor_functions()
        self.monitor.log_model(self.model)
        self.monitor.log_self()
        self.log_trainer()
        print(f"Training '{self.model.__class__.__name__}'")

        eval_loader = torch.utils.data.DataLoader(
            dataset=self.train_loader.dataset,
            batch_size=self.train_loader.batch_size,
            shuffle=False,
            num_workers=self.train_loader.num_workers)

        full_forward_pass_eval = partial(full_forward_pass, loader=eval_loader)
        update_wrapper(wrapper=full_forward_pass_eval,
                       wrapped=full_forward_pass)

        if mutual_info_layers > 0:
            full_forward_pass_eval = self.monitor.mutual_info.decorate_evaluation(
                full_forward_pass_eval)
            self.monitor.mutual_info.prepare(
                eval_loader,
                model=self.model,
                monitor_layers_count=mutual_info_layers)

        for epoch in range(self.timer.epoch, self.timer.epoch + n_epoch):
            self.train_epoch(epoch=epoch)
            if epoch % epoch_update_step == 0:
                outputs_full, labels_full = full_forward_pass_eval(self.model)
                self.accuracy_measure.save(outputs_train=outputs_full,
                                           labels_train=labels_full)
                self.monitor.epoch_finished(self.model, outputs_full,
                                            labels_full)
                if adversarial:
                    self.monitor.plot_adversarial_examples(
                        self.model, self.get_adversarial_examples())
                if mask_explain:
                    self.train_mask()
                self._epoch_finished(epoch, outputs_full, labels_full)
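
Trainer is abstract: a subclass must implement train_batch(images, labels) and return an (outputs, loss) pair, which is what train_epoch expects. Below is a minimal, hypothetical sketch of such a subclass; the optimizer argument and the plain gradient-descent step are illustrative assumptions, not part of the original class.

class GradientDescentTrainer(Trainer):
    """Hypothetical concrete Trainer: one optimizer step per batch."""

    def __init__(self, model, criterion, dataset_name, optimizer, **kwargs):
        super().__init__(model=model, criterion=criterion,
                         dataset_name=dataset_name, **kwargs)
        self.optimizer = optimizer

    def train_batch(self, images, labels):
        # Forward pass, loss, backward pass, parameter update.
        self.optimizer.zero_grad()
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return outputs, loss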