def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run the encoder contrastive pretraining loop and return the meter report.

    Batches are drawn from ``self._pretrain_encoder_loader`` zipped with a
    progress indicator built over ``range(self._num_batches)``.  For each
    batch the image and its transformed counterpart are forwarded through the
    model in a single pass, the selected feature is projected and normalized,
    and ``self._contrastive_criterion`` is minimized with one optimizer step.

    Returns:
        The ``self.meters.tracking_status()`` report taken after the last
        processed batch, or an empty dict when no batch was processed.
    """
    self._model.train()
    assert self._model.training, self._model.training
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

    # Bound before the loop so that a loop with zero iterations returns an
    # empty report instead of raising UnboundLocalError.
    report_dict = {}
    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
        for _, data in zip(indicator, self._pretrain_encoder_loader):
            (img, _), (img_tf, _), _, partition_list, group_list = \
                self._preprocess_data(data, self._device)

            # Both views go through the model together along dim 0.
            _, *features = self._model(torch.cat([img, img_tf], dim=0), return_features=True)
            en = self._feature_extractor(features)[0]

            # Normalize the projections along dim 1, then split back into the two views.
            global_enc, global_tf_enc = torch.chunk(
                F.normalize(self._projection_head(en), dim=1), chunks=2, dim=0
            )

            labels = self._label_generation(partition_list, group_list)
            contrastive_loss = self._contrastive_criterion(
                torch.stack([global_enc, global_tf_enc], dim=1), labels=labels
            )

            self._optimizer.zero_grad()
            contrastive_loss.backward()
            self._optimizer.step()

            # Meter recording.
            with torch.no_grad():
                self.meters["contrastive_loss"].add(contrastive_loss.item())
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    return report_dict
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run the encoder pretraining loop with contrastive and IIC losses.

    Batches are drawn from ``self._pretrain_encoder_loader`` zipped with a
    progress indicator built over ``range(self._num_batches)``.  For each
    batch the image and its transformed counterpart are forwarded together;
    the selected feature feeds both the projection head (contrastive loss)
    and the projection classifier (IIC loss).  The optimized loss is the IIC
    loss alone when ``self._disable_contrastive`` is truthy, otherwise
    ``self._iic_weight * iic_loss + contrastive_loss``.

    Returns:
        The ``self.meters.tracking_status()`` report taken after the last
        processed batch, or an empty dict when no batch was processed.
    """
    self._model.train()
    assert self._model.training, self._model.training
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    self.meters["reg_weight"].add(self._iic_weight)

    # Bound before the loop so that a loop with zero iterations returns an
    # empty report instead of raising UnboundLocalError.
    report_dict = {}
    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
        for _, data in zip(indicator, self._pretrain_encoder_loader):
            (img, _), (img_tf, _), _, partition_list, group_list = \
                self._preprocess_data(data, self._device)

            # Both views go through the model together along dim 0.
            _, *features = self._model(torch.cat([img, img_tf], dim=0), return_features=True)
            en = self._feature_extractor(features)[0]

            # Normalize the projections along dim 1, then split back into the two views.
            global_enc, global_tf_enc = torch.chunk(
                F.normalize(self._projection_head(en), dim=1), chunks=2, dim=0
            )

            # projection_classifier gives a list of probabilities; each entry is
            # split into its two views and the pairs are regrouped per view.
            global_probs, global_tf_probs = list(
                zip(*[torch.chunk(x, chunks=2, dim=0) for x in self._projection_classifier(en)])
            )
            # FIXME: some IIC-related code is still missing here.

            labels = self._label_generation(partition_list, group_list)
            # The contrastive loss is always computed so that it can be recorded,
            # even when it is excluded from the optimized total.
            contrastive_loss = self._contrastive_criterion(
                torch.stack([global_enc, global_tf_enc], dim=1), labels=labels
            )

            # Only the first element returned by the IIC criterion is used as the loss.
            iic_loss_list = [
                self._iic_criterion(x, y)[0] for x, y in zip(global_probs, global_tf_probs)
            ]
            iic_loss = average_iter(iic_loss_list)

            if self._disable_contrastive:
                total_loss = iic_loss
            else:
                total_loss = self._iic_weight * iic_loss + contrastive_loss

            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()

            # Meter recording.
            with torch.no_grad():
                self.meters["contrastive_loss"].add(contrastive_loss.item())
                self.meters["iic_loss"].add(iic_loss.item())
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    return report_dict
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run the supervised training loop over ``self._labeled_loader``.

    The loader is zipped with a progress indicator built over
    ``range(self._num_batches)``.  Each batch is forwarded through the model,
    the softmax of the logits is scored against the one-hot target with
    ``self._sup_criterion``, and one optimizer step is taken.  The loss and
    the hard predictions are recorded in the ``"sup_loss"`` and ``"ds"``
    meters.

    Returns:
        The ``self.meters.tracking_status()`` report taken once the loop has
        finished.
    """
    self._model.train()
    assert self._model.training, self._model.training
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
        for _, batch in zip(indicator, self._labeled_loader):
            (images, targets), _, _, _, groups = self._preprocess_data(batch, self._device)

            logits = self._model(images)
            # The model is expected to emit raw logits, not probabilities.
            assert not simplex(logits), logits

            class_targets = targets.squeeze(1)
            # NOTE(review): the literal 4 is presumably the number of classes - confirm.
            onehot = class2one_hot(class_targets, 4)
            probabilities = logits.softmax(1)
            sup_loss = self._sup_criterion(probabilities, onehot)

            self._optimizer.zero_grad()
            sup_loss.backward()
            self._optimizer.step()

            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                hard_prediction = logits.max(1)[1]
                self.meters["ds"].add(
                    hard_prediction, class_targets, group_name=list(groups)
                )
                indicator.set_postfix_dict(self.meters.tracking_status())

    # Queried again after the loop so a report exists even if no batch ran.
    return self.meters.tracking_status()
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run the decoder contrastive pretraining loop and return the meter report.

    Batches are drawn from ``self._pretrain_decoder_loader`` zipped with a
    progress indicator built over ``range(self._num_batches)``.  For each
    batch, ``self._transformer`` is applied to the input image under a fixed
    seed and, under the same seed, to the features of the second view.  Both
    feature sets are projected, unfolded into a 2x2 partition, flattened and
    L2-normalized before ``self._contrastive_criterion`` is minimized with one
    optimizer step.

    Returns:
        The ``self.meters.tracking_status()`` report taken after the last
        processed batch, or an empty dict when no batch was processed.

    Raises:
        RuntimeError: if the contrastive loss is NaN.
    """
    self._model.train()
    assert self._model.training, self._model.training
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

    # Bound before the loop so that a loop with zero iterations returns an
    # empty report instead of raising UnboundLocalError.
    report_dict = {}
    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
        for _, data in zip(indicator, self._pretrain_decoder_loader):
            (img, _), (img_ctf, _), _, partition_list, group_list = \
                self._preprocess_data(data, self._device)

            # The same seed is reused for the image-level and the feature-level
            # transform; presumably FixRandomSeed makes both passes draw the
            # same random parameters - confirm against FixRandomSeed.
            seed = random.randint(0, int(1e5))
            with FixRandomSeed(seed):
                img_gtf = torch.stack([self._transformer(x) for x in img], dim=0)
            assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape)

            # Both views go through the model together along dim 0.
            _, *features = self._model(
                torch.cat([img_gtf, img_ctf], dim=0), return_features=True
            )
            dn = self._feature_extractor(features)[0]
            dn_gtf, dn_ctf = torch.chunk(dn, chunks=2, dim=0)

            with FixRandomSeed(seed):
                dn_ctf_gtf = torch.stack([self._transformer(x) for x in dn_ctf], dim=0)
            assert dn_ctf_gtf.shape == dn_ctf.shape, (dn_ctf_gtf.shape, dn_ctf.shape)

            dn_tf = torch.cat([dn_gtf, dn_ctf_gtf])
            local_enc_tf, local_enc_tf_ctf = torch.chunk(
                self._projection_head(dn_tf), chunks=2, dim=0
            )

            # TODO: convert representation to distance.
            local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2))
            local_tf_enc_unfold, _fold_partition = unfold_position(
                local_enc_tf_ctf, partition_num=(2, 2)
            )

            # Flatten everything after the leading dimension before normalizing.
            b, *_ = local_enc_unfold.shape
            local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1)
            local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1)

            labels = self._label_generation(partition_list, group_list, _fold_partition)
            contrastive_loss = self._contrastive_criterion(
                torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1),
                labels=labels,
            )
            # Fail fast instead of back-propagating a NaN loss.
            if torch.isnan(contrastive_loss):
                raise RuntimeError(contrastive_loss)

            self._optimizer.zero_grad()
            contrastive_loss.backward()
            self._optimizer.step()

            # Meter recording.
            with torch.no_grad():
                self.meters["contrastive_loss"].add(contrastive_loss.item())
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    return report_dict
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run the semi-supervised training loop with a regularization term.

    Labeled and unlabeled batches are drawn in lockstep with
    ``self._indicator``.  The unlabeled images are transformed with
    ``self._affine_transformer`` under a per-batch seed; the labeled,
    unlabeled and transformed-unlabeled images are forwarded together, and
    the same seeded transform is then applied to the unlabeled logits.  The
    optimized loss is ``sup_loss + self._reg_weight * reg_loss``, where the
    regularization term comes from ``self.regularization``.

    Returns:
        The ``self.meters.tracking_status()`` report taken after the last
        processed batch, or an empty dict when no batch was processed.
    """

    def affine_with_seed(batch, seed):
        # Apply the affine transformer item by item under a fixed seed and
        # restack along dim 0.
        with FixRandomSeed(seed):
            return torch.stack([self._affine_transformer(item) for item in batch], dim=0)

    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    self._model.train()
    assert self._model.training, self._model.training

    report_dict = {}
    with FeatureExtractor(self._model, self._feature_position) as self._fextractor:
        batches = zip(self._indicator, self._labeled_loader, self._unlabeled_loader)
        for _, labeled_batch, unlabeled_batch in batches:
            seed = random.randint(0, int(1e7))

            lab_image, lab_target, _, _, lab_group = \
                self._unzip_data(labeled_batch, self._device)
            unlab_image, _unlab_target, *_ = self._unzip_data(unlabeled_batch, self._device)

            unlab_image_tf = affine_with_seed(unlab_image, seed)
            assert unlab_image_tf.shape == unlab_image.shape, \
                (unlab_image_tf.shape, unlab_image.shape)

            # One forward pass for all three groups, split back by group size.
            stacked_input = torch.cat([lab_image, unlab_image, unlab_image_tf], dim=0)
            all_logits = self._model(stacked_input)
            group_sizes = [len(lab_image), len(unlab_image), len(unlab_image_tf)]
            lab_logits, unlab_logits, unlab_tf_logits = torch.split(
                all_logits, group_sizes, dim=0
            )

            # Transform the logits of the untransformed images with the same seed.
            unlab_logits_tf = affine_with_seed(unlab_logits, seed)
            assert unlab_logits_tf.shape == unlab_tf_logits.shape, \
                (unlab_logits_tf.shape, unlab_tf_logits.shape)

            # Supervised term.
            class_target = lab_target.squeeze(1)
            onehot = class2one_hot(class_target, self.num_classes)
            sup_loss = self._sup_criterion(lab_logits.softmax(1), onehot)

            # Regularization term.
            reg_loss = self.regularization(
                unlabeled_tf_logits=unlab_tf_logits,
                unlabeled_logits_tf=unlab_logits_tf,
                seed=seed,
                unlabeled_image=unlab_image,
                unlabeled_image_tf=unlab_image_tf,
            )
            total_loss = sup_loss + self._reg_weight * reg_loss

            # Gradient step.
            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()

            # Recording can be done here or inside the regularization method.
            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["sup_dice"].add(
                    lab_logits.max(1)[1], class_target, group_name=lab_group
                )
                self.meters["reg_loss"].add(reg_loss.item())
                report_dict = self.meters.tracking_status()
                self._indicator.set_postfix_dict(report_dict)
    return report_dict