Example 1
    def on_epoch_end(self, tracker: MetricTracker, storage: SharedStorage):
        """Log statistics of the query-gallery distances accumulated over the epoch."""
        if self.train:
            # during training the same feature set acts as both query and gallery
            qf = gf = storage.get_data("features")
            qpids = gpids = storage.get_data("pids")
            qcamids = gcamids = storage.get_data("camids")
            prefix = ""
        else:
            qf = storage.get_data("qf")
            gf = storage.get_data("gf")
            qpids = storage.get_data("qpids")
            gpids = storage.get_data("gpids")
            qcamids = storage.get_data("qcamids")
            gcamids = storage.get_data("gcamids")
            prefix = "valid_"

        distmat = storage.get_data("distmat")
        if distmat is None:
            # compute the pairwise query-gallery distance matrix once and cache it
            distmat = compute_distances(qf, gf)
            storage.set_data("distmat", distmat)
        # boolean (num_query, num_gallery) masks over identity and camera id
        same_pid = gpids.eq(qpids.reshape(-1, 1))
        same_cam = gcamids.eq(qcamids.reshape(-1, 1))
        negative: torch.Tensor = ~same_pid
        positive: torch.Tensor = same_pid
        positive_same_cam = torch.logical_and(same_pid, same_cam)
        positive_diff_cam = torch.logical_and(same_pid, ~same_cam)

        if self.train:
            # filter out identical instances from positive distances
            same_image = torch.eye(qf.size(0), dtype=torch.bool, device=qf.device)
            positive.logical_and_(~same_image)
            positive_same_cam.logical_and_(~same_image)

        tracker.update(prefix + "global_dist_pos_same_cam_mean", distmat[positive_same_cam].mean().item())
        tracker.update(prefix + "global_dist_pos_diff_cam_mean", distmat[positive_diff_cam].mean().item())
        tracker.update(prefix + "global_dist_pos_mean", distmat[positive].mean().item())
        tracker.update(prefix + "global_dist_neg_mean", distmat[negative].mean().item())
        tracker.append_histogram(prefix + "global_dist_pos_same_cam", distmat[positive_same_cam])
        tracker.append_histogram(prefix + "global_dist_pos_diff_cam", distmat[positive_diff_cam])
        tracker.append_histogram(prefix + "global_dist_pos", distmat[positive])
        tracker.append_histogram(prefix + "global_dist_neg", distmat[negative])
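
The helper compute_distances is not part of this excerpt. A minimal sketch, assuming it computes pairwise Euclidean distances with the same a^2 + b^2 - 2ab expansion used in Example 2 (the name and signature come from the call above; everything else is an assumption):

import torch


def compute_distances(qf: torch.Tensor, gf: torch.Tensor) -> torch.Tensor:
    """Pairwise Euclidean distances between query and gallery features.

    qf: (num_query, dim), gf: (num_gallery, dim) -> (num_query, num_gallery).
    """
    q_norms = qf.square().sum(dim=1, keepdim=True)  # (num_query, 1)
    g_norms = gf.square().sum(dim=1, keepdim=True)  # (num_gallery, 1)
    distmat = q_norms + g_norms.t()                 # a^2 + b^2
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)    # a^2 + b^2 - 2ab
    return distmat.clamp(min=1e-12).sqrt()          # numerical floor before sqrt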
Example 2
    def forward(self, items: collections.abc.Mapping, tracker: MetricTracker,
                storage: SharedStorage):
        """Batch-hard ranking loss over pairwise Euclidean feature distances."""
        features, target = items[self.output_key], items[self.target_key]
        n = batch_size = features.size(0)
        same_target = target.eq(target.view(batch_size, 1))
        norms = features.square().sum(dim=1, keepdim=True).expand(
            batch_size, batch_size)
        distmat = norms + norms.t()  # a^2 + b^2
        distmat.addmm_(beta=1, alpha=-2, mat1=features,
                       mat2=features.T)  # a^2 + b^2 - 2ab
        distmat = distmat.clamp(min=1e-12).sqrt()  # euclid

        # exclude self-pairs (the diagonal) from the positive distances
        pos_dists = distmat[same_target & ~torch.eye(
            n, dtype=torch.bool, device=distmat.device)]
        neg_dists = distmat[~same_target]
        # batch-hard mining; the view() calls assume every row has the same
        # number of positives and negatives (e.g. PK-style identity sampling)
        hard_pos_dists, _ = distmat[same_target].view(batch_size, -1).max(dim=1)
        hard_neg_dists, _ = neg_dists.view(batch_size, -1).min(dim=1)
        if self.track_distances:
            hard_pos_mean = hard_pos_dists.mean().item()
            hard_neg_mean = hard_neg_dists.mean().item()
            tracker.append_histogram("batch_hard_dist_ap", hard_pos_dists)
            tracker.append_histogram("batch_hard_dist_an", hard_neg_dists)
            tracker.append_histogram("batch_dist_pos", pos_dists)
            tracker.append_histogram("batch_dist_neg", neg_dists)
            tracker.append_histogram("batch_hard_delta",
                                     hard_pos_dists - hard_neg_dists)
            tracker.update("batch_hard_dist_ap_mean", hard_pos_mean, n=n)
            tracker.update("batch_hard_dist_an_mean", hard_neg_mean, n=n)
            # means of the full positive/negative sets, under keys matching
            # the corresponding histograms
            tracker.update("batch_dist_pos_mean", pos_dists.mean().item(), n=n)
            tracker.update("batch_dist_neg_mean", neg_dists.mean().item(), n=n)
            tracker.update("batch_hard_delta_mean",
                           hard_pos_mean - hard_neg_mean, n=n)

        if self.margin:
            # ranking loss on (hardest negative, hardest positive) pairs
            # with target y = 1: negatives should be farther away
            loss = self.ranking_loss(
                hard_neg_dists, hard_pos_dists,
                torch.ones(batch_size,
                           device=features.device,
                           dtype=features.dtype))
        else:
            # margin-free variant applied to the difference of hard distances
            loss = self.ranking_loss(
                hard_neg_dists - hard_pos_dists,
                torch.ones(batch_size,
                           device=features.device,
                           dtype=features.dtype))
        return loss
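
Neither self.ranking_loss nor the configuration attributes are defined in this excerpt. Given the two call patterns above, one with (x1, x2, y) arguments and one with (x, y), a plausible constructor is the usual batch-hard triplet setup; the class name and default keys below are assumptions:

import torch.nn as nn


class TripletLoss(nn.Module):
    """Hypothetical container for the forward() shown above."""

    def __init__(self, margin=None, track_distances=False,
                 output_key="features", target_key="targets"):
        super().__init__()
        self.margin = margin
        self.track_distances = track_distances
        self.output_key = output_key
        self.target_key = target_key
        if margin:
            # ranking_loss(x1, x2, y) = max(0, -y * (x1 - x2) + margin)
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            # ranking_loss(x, y) = mean(log(1 + exp(-y * x)))
            self.ranking_loss = nn.SoftMarginLoss()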
Example 3
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion: TrackingLoss,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 train_metrics: List[ActiveMetric] = None,
                 valid_metrics: List[ActiveMetric] = None,
                 valid_data_loader=None,
                 lr_schedulers=None,
                 len_epoch=None,
                 train_storage_keys=None,
                 valid_storage_keys=None,
                 items_len_key="targets"):
        super().__init__(model, criterion, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        self.train_metrics = train_metrics if train_metrics else []
        self.valid_metrics = valid_metrics if valid_metrics else []
        assert isinstance(self.train_metrics, collections.abc.Collection)
        assert isinstance(self.valid_metrics, collections.abc.Collection)

        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None

        if lr_schedulers is None:
            self.lr_schedulers = []
        else:
            assert isinstance(lr_schedulers, Sequence)
            self.lr_schedulers = lr_schedulers
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_storage_keys = train_storage_keys if train_storage_keys else []
        self.valid_storage_keys = valid_storage_keys if valid_storage_keys else []
        self.items_len_key = items_len_key

        self.tracker = MetricTracker(writer=self.writer)

    def _to_device(self, items):
        """Move data loader output (dictionary form) to appropriate device
        """
        result = dict()
        for key, value in items.items():
            if isinstance(value, torch.Tensor):
                result[key] = value.to(self.device)
            else:
                result[key] = value
        return result

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        storage = SharedStorage()
        self.model.train()
        self.tracker.reset()
        self.writer.set_step(epoch - 1)

        for batch_idx, items in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            items = self._to_device(items)
            items.update(self.model(items))
            for key in self.train_storage_keys:
                storage.add_data_batch(key, items[key].detach())

            loss = self.criterion(items, self.tracker, storage)
            loss.backward()
            self.tracker.update("loss", loss.item(),
                                len(items[self.items_len_key]))
            self.optimizer.step()

            if batch_idx == 0:
                self.writer.add_image(
                    'input',
                    make_grid(items["images"].cpu(), nrow=8, normalize=True))
            for metric in self.train_metrics:
                metric.on_step_end(items, self.tracker, storage)

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            # len_epoch is intentionally not used to break out of the loop
            # here, so hard-sampling metrics always see the full data loader

        for scheduler in self.lr_schedulers:
            scheduler.step()

        for metric in self.train_metrics:
            metric.on_epoch_end(self.tracker, storage)
        log = self.tracker.commit()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        storage = SharedStorage()
        self.model.eval()
        self.writer.set_step(epoch - 1, 'valid')

        with torch.no_grad():
            for batch_idx, items in enumerate(self.valid_data_loader):
                items = self._to_device(items)
                items.update(self.model(items))
                for key in self.valid_storage_keys:
                    storage.add_data_batch(key, items[key].detach())

                loss = self.criterion(items, self.tracker, storage)
                self.tracker.update("loss", loss.item(),
                                    len(items[self.items_len_key]))

                for metric in self.valid_metrics:
                    metric.on_step_end(items, self.tracker, storage)

                if batch_idx == 0:
                    self.writer.add_image(
                        'input',
                        make_grid(items["images"].cpu(),
                                  nrow=8,
                                  normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.tracker.append_histogram(name, p)

        for metric in self.valid_metrics:
            metric.on_epoch_end(self.tracker, storage)
        log = self.tracker.commit()
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
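
Two helpers used above are defined elsewhere. inf_loop only needs to re-yield the data loader endlessly for iteration-based training, and SharedStorage only needs the three methods called in these examples; minimal sketches under those assumptions:

from itertools import repeat

import torch


def inf_loop(data_loader):
    """Yield batches from data_loader endlessly, restarting it on each pass."""
    for loader in repeat(data_loader):
        yield from loader


class SharedStorage:
    """Accumulates per-batch tensors so epoch-end hooks can see the whole epoch."""

    def __init__(self):
        self._batches = {}  # key -> list of per-batch tensors
        self._data = {}     # key -> explicitly cached tensors (e.g. "distmat")

    def add_data_batch(self, key, value):
        self._batches.setdefault(key, []).append(value)

    def set_data(self, key, value):
        self._data[key] = value

    def get_data(self, key):
        # an explicit cache entry wins; otherwise concatenate accumulated batches
        if key in self._data:
            return self._data[key]
        if key in self._batches:
            return torch.cat(self._batches[key])
        return None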