def on_epoch_end(self, tracker: MetricTracker, storage: SharedStorage):
    if self.train:
        # during training the batch acts as both query and gallery set
        qf = gf = storage.get_data("features")
        qpids = gpids = storage.get_data("pids")
        qcamids = gcamids = storage.get_data("camids")
        prefix = ""
    else:
        qf = storage.get_data("qf")
        gf = storage.get_data("gf")
        qpids = storage.get_data("qpids")
        gpids = storage.get_data("gpids")
        qcamids = storage.get_data("qcamids")
        gcamids = storage.get_data("gcamids")
        prefix = "valid_"

    # reuse a cached distance matrix if another component already computed one
    distmat = storage.get_data("distmat")
    if distmat is None:
        distmat = compute_distances(qf, gf)
        storage.set_data("distmat", distmat)

    same_pid = gpids.eq(qpids.reshape(-1, 1))
    same_cam = gcamids.eq(qcamids.reshape(-1, 1))
    negative: torch.Tensor = ~same_pid
    positive: torch.Tensor = same_pid
    positive_same_cam = torch.logical_and(same_pid, same_cam)
    positive_diff_cam = torch.logical_and(same_pid, ~same_cam)

    if self.train:
        # filter out identical instances from positive distances
        same_image = torch.diagflat(
            torch.ones(qf.size(0), dtype=torch.bool, device=qf.device))
        positive.logical_and_(~same_image)
        positive_same_cam.logical_and_(~same_image)

    tracker.update(prefix + "global_dist_pos_same_cam_mean",
                   distmat[positive_same_cam].mean().item())
    tracker.update(prefix + "global_dist_pos_diff_cam_mean",
                   distmat[positive_diff_cam].mean().item())
    tracker.update(prefix + "global_dist_pos_mean",
                   distmat[positive].mean().item())
    tracker.update(prefix + "global_dist_neg_mean",
                   distmat[negative].mean().item())
    tracker.append_histogram(prefix + "global_dist_pos_same_cam",
                             distmat[positive_same_cam])
    tracker.append_histogram(prefix + "global_dist_pos_diff_cam",
                             distmat[positive_diff_cam])
    tracker.append_histogram(prefix + "global_dist_pos", distmat[positive])
    tracker.append_histogram(prefix + "global_dist_neg", distmat[negative])
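
# `compute_distances` is called above but defined elsewhere in the project.
# For reference, a minimal sketch assuming plain pairwise euclidean distances
# (an assumption -- the real helper may differ, e.g. use cosine distance or
# re-ranking); the leading-underscore name is hypothetical to avoid shadowing it.
def _euclidean_distances_sketch(qf: torch.Tensor, gf: torch.Tensor) -> torch.Tensor:
    """Pairwise euclidean distances between query (m, d) and gallery (n, d)."""
    m, n = qf.size(0), gf.size(0)
    distmat = (qf.square().sum(dim=1, keepdim=True).expand(m, n) +
               gf.square().sum(dim=1, keepdim=True).expand(n, m).t())
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # a^2 + b^2 - 2ab
    return distmat.clamp(min=1e-12).sqrt()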
def forward(self, items: collections.abc.Mapping, tracker: MetricTracker,
            storage: SharedStorage):
    features, target = items[self.output_key], items[self.target_key]
    batch_size = features.size(0)
    same_target = target.eq(target.view(batch_size, 1))

    # pairwise euclidean distances via ||a - b||^2 = a^2 + b^2 - 2ab
    norms = features.square().sum(dim=1, keepdim=True).expand(
        batch_size, batch_size)
    distmat = norms + norms.t()                                # a^2 + b^2
    distmat.addmm_(features, features.t(), beta=1, alpha=-2)  # a^2 + b^2 - 2ab
    distmat = distmat.clamp(min=1e-12).sqrt()

    # all positive pairs (excluding the diagonal of self-distances) and all
    # negative pairs, kept for tracking
    not_self = ~torch.diagflat(
        torch.ones(batch_size, dtype=torch.bool, device=distmat.device))
    pos_dists = distmat[same_target & not_self]
    neg_dists = distmat[~same_target]

    # batch-hard mining: hardest positive (max) and hardest negative (min)
    # per anchor; the zero self-distance on the diagonal cannot win the max
    hard_pos_dists, _ = distmat[same_target].view(batch_size, -1).max(dim=1)
    hard_neg_dists, _ = neg_dists.view(batch_size, -1).min(dim=1)

    if self.track_distances:
        pos_mean = hard_pos_dists.mean().item()
        neg_mean = hard_neg_dists.mean().item()
        tracker.append_histogram("batch_hard_dist_ap", hard_pos_dists)
        tracker.append_histogram("batch_hard_dist_an", hard_neg_dists)
        tracker.append_histogram("batch_dist_pos", pos_dists)
        tracker.append_histogram("batch_dist_neg", neg_dists)
        tracker.append_histogram("batch_hard_delta",
                                 hard_pos_dists - hard_neg_dists)
        tracker.update("batch_hard_dist_ap_mean", pos_mean, n=batch_size)
        tracker.update("batch_hard_dist_an_mean", neg_mean, n=batch_size)
        tracker.update("batch_dist_pos_mean", pos_dists.mean().item(),
                       n=batch_size)
        tracker.update("batch_dist_neg_mean", neg_dists.mean().item(),
                       n=batch_size)
        tracker.update("batch_hard_delta_mean", pos_mean - neg_mean,
                       n=batch_size)

    ones = torch.ones(batch_size, device=features.device, dtype=features.dtype)
    if self.margin:
        # margin ranking loss: require hard_neg >= hard_pos + margin
        loss = self.ranking_loss(hard_neg_dists, hard_pos_dists, ones)
    else:
        # soft-margin variant on the (negative - positive) gap
        loss = self.ranking_loss(hard_neg_dists - hard_pos_dists, ones)
    return loss
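
# The two `ranking_loss` branches above match the call signatures of
# torch.nn.MarginRankingLoss and torch.nn.SoftMarginLoss respectively; this is
# an assumption, since the constructor is not shown in this section. A minimal
# standalone illustration of both variants on toy mined distances:
def _ranking_loss_demo():
    import torch.nn as nn
    hard_pos = torch.tensor([0.8, 1.1])  # hardest positive distance per anchor
    hard_neg = torch.tensor([1.5, 0.9])  # hardest negative distance per anchor
    y = torch.ones(2)                    # y = 1: first argument should rank higher
    # margin variant: mean(max(0, margin - (neg - pos))) = mean([0.0, 0.5]) = 0.25
    margin_loss = nn.MarginRankingLoss(margin=0.3)(hard_neg, hard_pos, y)
    # soft-margin variant: mean(log(1 + exp(-(neg - pos))))
    soft_loss = nn.SoftMarginLoss()(hard_neg - hard_pos, y)
    return margin_loss, soft_loss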
class Trainer(BaseTrainer):
    """
    Trainer class
    """

    def __init__(self, model, criterion: TrackingLoss, optimizer, config,
                 device, data_loader,
                 train_metrics: List[ActiveMetric] = None,
                 valid_metrics: List[ActiveMetric] = None,
                 valid_data_loader=None, lr_schedulers=None, len_epoch=None,
                 train_storage_keys=None, valid_storage_keys=None,
                 items_len_key="targets"):
        super().__init__(model, criterion, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.train_metrics = train_metrics if train_metrics else []
        self.valid_metrics = valid_metrics if valid_metrics else []
        assert isinstance(self.train_metrics, collections.abc.Collection)
        assert isinstance(self.valid_metrics, collections.abc.Collection)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        if lr_schedulers is None:
            self.lr_schedulers = []
        else:
            assert isinstance(lr_schedulers, Sequence)
            self.lr_schedulers = lr_schedulers
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_storage_keys = train_storage_keys if train_storage_keys else []
        self.valid_storage_keys = valid_storage_keys if valid_storage_keys else []
        self.items_len_key = items_len_key
        self.tracker = MetricTracker(writer=self.writer)

    def _to_device(self, items):
        """Move data loader output (dictionary form) to the appropriate device"""
        result = dict()
        for key, value in items.items():
            if isinstance(value, torch.Tensor):
                result[key] = value.to(self.device)
            else:
                result[key] = value
        return result

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics for this epoch.
        """
        storage = SharedStorage()
        self.model.train()
        self.tracker.reset()
        self.writer.set_step(epoch - 1)
        for batch_idx, items in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            items = self._to_device(items)
            items.update(self.model(items))
            for key in self.train_storage_keys:
                storage.add_data_batch(key, items[key].detach())
            loss = self.criterion(items, self.tracker, storage)
            loss.backward()
            self.tracker.update("loss", loss.item(),
                                len(items[self.items_len_key]))
            self.optimizer.step()

            if batch_idx == 0:
                self.writer.add_image(
                    'input',
                    make_grid(items["images"].cpu(), nrow=8, normalize=True))
            for metric in self.train_metrics:
                metric.on_step_end(items, self.tracker, storage)
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            # len_epoch is deliberately not enforced here: hard sampling
            # needs the full data loader, so the epoch is never cut short.

        for scheduler in self.lr_schedulers:
            scheduler.step()
        for metric in self.train_metrics:
            metric.on_epoch_end(self.tracker, storage)
        log = self.tracker.commit()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        storage = SharedStorage()
        self.model.eval()
        self.writer.set_step(epoch - 1, 'valid')
        with torch.no_grad():
            for batch_idx, items in enumerate(self.valid_data_loader):
                items = self._to_device(items)
                items.update(self.model(items))
                for key in self.valid_storage_keys:
                    storage.add_data_batch(key, items[key].detach())
                loss = self.criterion(items, self.tracker, storage)
                self.tracker.update("loss", loss.item(),
                                    len(items[self.items_len_key]))
                for metric in self.valid_metrics:
                    metric.on_step_end(items, self.tracker, storage)
                if batch_idx == 0:
                    self.writer.add_image(
                        'input',
                        make_grid(items["images"].cpu(), nrow=8,
                                  normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.tracker.append_histogram(name, p)
        for metric in self.valid_metrics:
            metric.on_epoch_end(self.tracker, storage)
        log = self.tracker.commit()
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
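
# Example wiring (a sketch: the concrete model, criterion, loaders, config and
# metric classes come from the surrounding project; the names below are
# hypothetical, and `train()` is assumed to be provided by BaseTrainer). The
# storage keys mirror what the epoch-end metric above reads from SharedStorage:
#
#   trainer = Trainer(model, criterion, optimizer, config, device,
#                     data_loader=train_loader,
#                     valid_data_loader=valid_loader,
#                     train_metrics=[GlobalDistanceMetric(train=True)],
#                     valid_metrics=[GlobalDistanceMetric(train=False)],
#                     lr_schedulers=[scheduler],
#                     train_storage_keys=["features", "pids", "camids"],
#                     valid_storage_keys=["qf", "gf", "qpids", "gpids",
#                                         "qcamids", "gcamids"])
#   trainer.train()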