def before_eval_epoch(self, *args, **kwargs):
    """Set up the tqdm progress indicator before an evaluation epoch begins."""
    # One progress tick per validation batch.
    self.tqdm_indicator = tqdm_(range(self._val_batches),
                                total=self._val_batches)
    current_epoch = kwargs.get("epoch")
    self._epoch = current_epoch
    if current_epoch is None:
        return
    self.tqdm_indicator.set_description(f"Evaluating Epoch {current_epoch}")
def _eval_loop(
    self,
    val_loader: DataLoader = None,
    epoch: int = 0,
    mode=ModelMode.EVAL,
    *args,
    **kwargs,
) -> float:
    """Validate the model for one epoch on `val_loader`.

    :param val_loader: DataLoader yielding (images, targets) pairs.
    :param epoch: epoch index, used only in progress/report messages.
    :param mode: mode pushed to the model wrapper; EVAL by default.
    :return: accuracy summary of the "val_acc" meter.
    """
    # set model mode
    self._model.set_mode(mode)
    # fix: truthiness check instead of `== False` (PEP 8), and the assert
    # message now shows the attribute actually being tested (the original
    # reported `self._model.training` while asserting on torchnet.training).
    assert not self._model.torchnet.training, self._model.torchnet.training
    # set tqdm-based trainer
    _val_loader = tqdm_(val_loader)
    _val_loader.set_description(f"Validating epoch {epoch}: ")
    for batch_id, (imgs, targets) in enumerate(_val_loader):
        imgs, targets = imgs.to(self._device), targets.to(self._device)
        preds = self._model(imgs)
        loss = self.ce_criterion(preds, targets)
        self.METERINTERFACE["val_loss"].add(loss.item())
        self.METERINTERFACE["val_acc"].add(preds.max(1)[1], targets)
        report_dict = self._eval_report_dict
        _val_loader.set_postfix(report_dict)
    print(
        colored(f"Validating epoch {epoch}: {nice_dict(report_dict)}",
                "green"))
    return self.METERINTERFACE["val_acc"].summary()["acc"]
def _train_loop(
    self,
    train_loader: DataLoader,
    epoch: int,
    mode=ModelMode.TRAIN,
    *args,
    **kwargs,
):
    """Train for one epoch on pairs of differently-transformed views.

    Each batch element carries several transformed versions of the same
    image; the first view (tf1) is replicated so it can be compared against
    every remaining view (tf2...) in a single forward pass.
    """
    super()._train_loop(*args, **kwargs)
    self.model.set_mode(mode)
    assert self.model.training
    _train_loader: tqdm = tqdm_(train_loader)
    for _batch_num, images_labels_indices in enumerate(_train_loader):
        images, labels, *_ = zip(*images_labels_indices)
        # Replicate the first view once per extra view so the concatenated
        # tf1 batch lines up 1:1 with the concatenated tf2 batch.
        tf1_images = torch.cat(tuple(
            [images[0] for _ in range(images.__len__() - 1)]),
            dim=0).to(self.device)
        tf2_images = torch.cat(tuple(images[1:]), dim=0).to(self.device)
        pred_tf1_simplex = self.model(tf1_images)
        pred_tf2_simplex = self.model(tf2_images)
        # Model outputs are expected to be lists of per-subhead simplexes.
        assert simplex(pred_tf1_simplex[0]), pred_tf1_simplex
        assert simplex(pred_tf2_simplex[0]), pred_tf2_simplex
        # Subclass-specific criterion combines both views into one loss.
        total_loss = self._trainer_specific_loss(tf1_images, tf2_images,
                                                 pred_tf1_simplex,
                                                 pred_tf2_simplex)
        self.model.zero_grad()
        total_loss.backward()
        self.model.step()
        report_dict = self._training_report_dict
        _train_loader.set_postfix(report_dict)
    report_dict_str = ", ".join(
        [f"{k}:{v:.3f}" for k, v in report_dict.items()])
    print(f"  Training epoch: {epoch} : {report_dict_str}")
def _eval_loop(self,
               val_loader: DataLoader,
               epoch: int,
               mode=ModelMode.EVAL,
               *args,
               **kwargs) -> float:
    """Evaluate all sub-heads; return the mean of the best-subhead accuracy.

    Collects per-subhead hard predictions and soft probabilities over the
    whole validation set, then remaps each subhead's cluster ids to ground
    truth via hungarian matching before scoring.
    """
    super(IMSAT_Trainer, self)._eval_loop(*args, **kwargs)
    self.model.set_mode(mode)
    assert not self.model.training
    _val_loader = tqdm_(val_loader)
    # Per-subhead hard predictions: (num_sub_heads, num_samples).
    preds = torch.zeros(
        self.model.arch_dict["num_sub_heads"],
        val_loader.dataset.__len__(),
        dtype=torch.long,
        device=self.device,
    )
    # Per-subhead soft predictions: (num_sub_heads, num_samples, output_k).
    probas = torch.zeros(
        self.model.arch_dict["num_sub_heads"],
        val_loader.dataset.__len__(),
        self.model.arch_dict["output_k"],
        dtype=torch.float,
        device=self.device,
    )
    # Ground-truth labels, filled batch by batch.
    gts = torch.zeros(val_loader.dataset.__len__(),
                      dtype=torch.long,
                      device=self.device)
    _batch_done = 0
    for _batch_num, images_labels_indices in enumerate(_val_loader):
        images, labels, *_ = zip(*images_labels_indices)
        # Only the first view / first label set is used for evaluation.
        images, labels = images[0].to(self.device), labels[0].to(
            self.device)
        pred = self.model(images)
        _bSlice = slice(_batch_done, _batch_done + images.shape[0])
        gts[_bSlice] = labels
        for subhead in range(pred.__len__()):
            preds[subhead][_bSlice] = pred[subhead].max(1)[1]
            probas[subhead][_bSlice] = pred[subhead]
        _batch_done += images.shape[0]
    # Fails if the loader drops the last partial batch.
    assert _batch_done == val_loader.dataset.__len__(), _batch_done
    # record
    subhead_accs = []
    for subhead in range(self.model.arch_dict["num_sub_heads"]):
        reorder_pred, remap = hungarian_match(
            flat_preds=preds[subhead],
            flat_targets=gts,
            preds_k=self.model.arch_dict["output_k"],
            targets_k=self.model.arch_dict["output_k"],
        )
        _acc = flat_acc(reorder_pred, gts)
        subhead_accs.append(_acc)
        # record average acc
        self.METERINTERFACE.val_average_acc.add(_acc)
    self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
    report_dict = self._eval_report_dict
    report_dict_str = ", ".join(
        [f"{k}:{v:.3f}" for k, v in report_dict.items()])
    print(f"Validating epoch: {epoch} : {report_dict_str}")
    return self.METERINTERFACE.val_best_acc.summary()["mean"]
def _sup_train_loop(train_loader, epoch):
    """Run one supervised training epoch with KL loss on one-hot targets."""
    self.model.train()
    indicator = tqdm_(train_loader)
    for _num, batch in enumerate(indicator):
        imgs, labels = zip(*batch)
        imgs, labels = imgs[0].to(self.device), labels[0].to(self.device)
        if self.use_sobel:
            imgs = self.sobel(imgs)
        # First subhead only.
        prediction = self.model.torchnet(imgs)[0]
        batch_loss = self.kl(prediction, class2one_hot(labels, 10).float())
        self.model.zero_grad()
        batch_loss.backward()
        self.model.step()
        linear_meters["train_loss"].add(batch_loss.item())
        linear_meters["train_acc"].add(prediction.max(1)[1], labels)
        summary = {
            "tra_acc": linear_meters["train_acc"].summary()["acc"],
            "loss": linear_meters["train_loss"].summary()["mean"],
        }
        indicator.set_postfix(summary)
    print(f"  Training epoch {epoch}: {nice_dict(summary)} ")
def plot_cluster_average_images(val_loader, soft_pred, num_clusters=10,
                                image_size=24):
    """Build per-cluster soft-weighted average images over `val_loader`.

    Each image is resized to `image_size` x `image_size` and accumulated into
    the bucket of its argmax cluster, weighted by its max soft probability.

    :param val_loader: loader yielding multi-view (images, labels, ...) tuples;
        only the first view is used.  NOTE(review): a previous assert suggested
        this is only exercised for MNIST — confirm for other datasets.
    :param soft_pred: (num_samples, num_clusters) soft assignment matrix,
        indexed in loader order.
    :param num_clusters: number of cluster buckets (generalized; default 10
        preserves the original hard-coded behavior).
    :param image_size: side length of the resized square images (generalized;
        default 24 preserves the original hard-coded behavior).
    :return: list of `num_clusters` averaged images on CPU.
    """
    from deepclustering.augment.tensor_augment import Resize
    import warnings
    resize_call = Resize((image_size, image_size), interpolation='bilinear')
    average_images = [
        torch.zeros(image_size, image_size) for _ in range(num_clusters)
    ]
    counter = 0
    for image_labels in tqdm_(val_loader):
        images, gt, *_ = list(zip(*image_labels))
        # only take the tf3 image and gts, put them to the GPU
        images, gt = images[0].cuda(), gt[0].cuda()
        for i, img in enumerate(images):
            with warnings.catch_warnings():
                # Resize emits deprecation/interp warnings; silence them.
                warnings.simplefilter("ignore")
                img = resize_call(img.unsqueeze(0))
            average_images[soft_pred[counter + i].argmax(
            )] += img.squeeze().cpu() * soft_pred[counter + i].max()
        counter += len(images)
    # Fails if the loader drops the last partial batch.
    assert counter == val_loader.dataset.__len__()
    # NOTE(review): normalization assumes roughly balanced clusters
    # (counter / num_clusters samples each) — confirm this is intended.
    average_images = [
        average_image / (counter / num_clusters)
        for average_image in average_images
    ]
    return average_images
def _train_loop(
    self,
    train_loader: DataLoader = None,
    epoch: int = 0,
    mode=ModelMode.TRAIN,
    *args,
    **kwargs,
):
    """Train the classifier for one epoch with cross-entropy.

    :param train_loader: DataLoader yielding (images, targets) pairs.
    :param epoch: epoch index for progress/report messages.
    :param mode: mode pushed to the model wrapper; TRAIN by default.
    """
    # set model mode
    self.model.set_mode(mode)
    # fix: truthiness check instead of `== True` comparison (PEP 8).
    assert self.model.torchnet.training
    # set tqdm-based trainer; record the current learning rate first
    self.METERINTERFACE["lr"].add(self.model.get_lr()[0])
    _train_loader = tqdm_(train_loader)
    _train_loader.set_description(
        f" Training epoch {epoch}: lr={self.METERINTERFACE['lr'].summary()['value']:.5f}"
    )
    for batch_id, (imgs, targets) in enumerate(_train_loader):
        imgs, targets = imgs.to(self.device), targets.to(self.device)
        preds = self.model(imgs)
        loss = self.ce_criterion(preds, targets)
        # Context manager zeroes grads, (optionally) scales the loss for
        # mixed precision, and steps the optimizer on exit.
        with ZeroGradientBackwardStep(loss, self.model) as scaled_loss:
            scaled_loss.backward()
        self.METERINTERFACE["train_loss"].add(loss.item())
        self.METERINTERFACE["train_acc"].add(preds.max(1)[1], targets)
        report_dict = self._training_report_dict
        _train_loader.set_postfix(report_dict)
    print(colored(f" Training epoch {epoch}: {nice_dict(report_dict)}", "red"))
def _train_loop(
    self,
    labeled_loader: DataLoader = None,
    unlabeled_loader: DataLoader = None,
    epoch: int = 0,
    mode=ModelMode.TRAIN,
    *args,
    **kwargs,
):
    """Semi-supervised training epoch: supervised KL on labeled batches plus
    a subclass-specific regularization on unlabeled batches.

    Iterates for at most `self.max_iter` steps (zip stops at the shortest
    of the progress range and the two loaders).
    """
    self._model.set_mode(mode)
    _max_iter = tqdm_(range(self.max_iter))
    _max_iter.set_description(f"Training Epoch {epoch}")
    self.METERINTERFACE["lr"].add(self._model.get_lr()[0])
    # zip yields (step, labeled_batch, unlabeled_batch) triples.
    for batch_num, (lab_img, lab_gt), (unlab_img, unlab_gt) in zip(
            _max_iter, labeled_loader, unlabeled_loader):
        lab_img, lab_gt = lab_img.to(self._device), lab_gt.to(self._device)
        lab_preds = self._model(lab_img)
        sup_loss = self.kl_criterion(
            lab_preds,
            class2one_hot(lab_gt, C=self._model.torchnet.num_classes).float(),
        )
        # NOTE(review): unlab_img/unlab_gt are not moved to self._device
        # here — presumably `_trainer_specific_loss` does that; confirm.
        reg_loss = self._trainer_specific_loss(unlab_img, unlab_gt)
        self.METERINTERFACE["traloss"].add(sup_loss.item())
        self.METERINTERFACE["traconf"].add(lab_preds.max(1)[1], lab_gt)
        # Context manager zeroes grads and steps the optimizer on exit.
        with ZeroGradientBackwardStep(sup_loss + reg_loss,
                                      self._model) as total_loss:
            total_loss.backward()
        report_dict = self._training_report_dict
        _max_iter.set_postfix(report_dict)
    print(f"Training Epoch {epoch}: {nice_dict(report_dict)}")
    self.writer.add_scalar_with_tag("train", report_dict, global_step=epoch)
def _eval_loop(
    self,
    val_loader: DataLoader = None,
    epoch: int = 0,
    mode=ModelMode.EVAL,
    *args,
    **kwargs,
) -> float:
    """Validate for one epoch with KL loss against one-hot targets."""
    self._model.set_mode(mode)
    loader = tqdm_(val_loader)
    loader.set_description(f"Validating Epoch {epoch}")
    num_classes = self._model.torchnet.num_classes
    for _num, (images, gts) in enumerate(loader):
        images, gts = images.to(self._device), gts.to(self._device)
        probs = self._model(images)
        onehot_gt = class2one_hot(gts, C=num_classes).float()
        batch_loss = self.kl_criterion(probs, onehot_gt, disable_assert=True)
        self.METERINTERFACE["valloss"].add(batch_loss.item())
        self.METERINTERFACE["valconf"].add(probs.max(1)[1], gts)
        report_dict = self._eval_report_dict
        loader.set_postfix(report_dict)
    print(f"Validating Epoch {epoch}: {nice_dict(report_dict)}")
    self.writer.add_scalar_with_tag(tag="eval",
                                    tag_scalar_dict=report_dict,
                                    global_step=epoch)
    return self.METERINTERFACE["valconf"].summary()["acc"]
def _eval_loop(
    self,
    val_loader: DataLoader = None,
    epoch: int = 0,
    mode=ModelMode.EVAL,
    *args,
    **kwargs,
) -> float:
    """Segmentation validation epoch: KL loss plus dice meters; returns
    the batch-dice value used for checkpoint selection."""
    self.model.set_mode(mode)
    loader = tqdm_(val_loader)
    loader.set_description(f"Validating Epoch {epoch}")
    for _num, ((images, gts), _paths) in enumerate(loader):
        images, gts = images.to(self.device), gts.to(self.device)
        probs = self.model(images, force_simplex=True)
        onehot_gt = class2one_hot(
            gts.squeeze(1), C=self.model.arch_dict["num_classes"]).float()
        batch_loss = self.kl_criterion(probs, onehot_gt, disable_assert=True)
        self.METERINTERFACE["valloss"].add(batch_loss.item())
        self.METERINTERFACE["valdice"].add(probs, gts)
        self.METERINTERFACE["valbdice"].add(probs, gts)
        report_dict = self._eval_report_dict
        loader.set_postfix(report_dict)
    print(f"Validating Epoch {epoch}: {nice_dict(report_dict)}")
    self.writer.add_scalar_with_tag(tag="eval",
                                    tag_scalar_dict=report_dict,
                                    global_step=epoch)
    return self.METERINTERFACE["valbdice"].value()[0][0].item()
def _eval_loop(
    self,
    val_loader: DataLoader = None,
    epoch: int = 0,
    mode: ModelMode = ModelMode.EVAL,
    **kwargs,
) -> float:
    """Evaluate every sub-head over the whole validation set.

    Remaps each sub-head's cluster ids to ground truth via hungarian
    matching, records average/best/worst sub-head accuracy, and returns
    the running mean of the best accuracy.
    """
    self.model.set_mode(mode)
    assert (
        not self.model.training
    ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
    val_loader_: tqdm = tqdm_(val_loader)
    # Per-subhead hard predictions: (num_sub_heads, num_samples).
    preds = torch.zeros(
        self.model.arch_dict["num_sub_heads"],
        val_loader.dataset.__len__(),
        dtype=torch.long,
        device=self.device,
    )
    # Ground-truth labels, filled batch by batch.
    target = torch.zeros(val_loader.dataset.__len__(),
                         dtype=torch.long,
                         device=self.device)
    slice_done = 0
    subhead_accs = []
    val_loader_.set_description(f"Validating epoch: {epoch}")
    for batch, image_labels in enumerate(val_loader_):
        images, gt, *_ = list(zip(*image_labels))
        # Only the first view and label set are used for evaluation.
        images, gt = images[0].to(self.device), gt[0].to(self.device)
        _pred = self.model.torchnet(images)
        assert (assert_list(simplex, _pred)
                and _pred.__len__() == self.model.arch_dict["num_sub_heads"])
        bSlicer = slice(slice_done, slice_done + images.shape[0])
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            preds[subhead][bSlicer] = _pred[subhead].max(1)[1]
        target[bSlicer] = gt
        slice_done += gt.shape[0]
    # Fails when dataloader.drop_last=True leaves samples unseen.
    assert slice_done == val_loader.dataset.__len__(
    ), "Slice not completed."
    for subhead in range(self.model.arch_dict["num_sub_heads"]):
        # NOTE(review): uses `output_k_B` for both preds_k and targets_k —
        # assumes this trainer evaluates head B; confirm for this class.
        reorder_pred, remap = hungarian_match(
            flat_preds=preds[subhead],
            flat_targets=target,
            preds_k=self.model.arch_dict["output_k_B"],
            targets_k=self.model.arch_dict["output_k_B"],
        )
        _acc = flat_acc(reorder_pred, target)
        subhead_accs.append(_acc)
        # record average acc
        self.METERINTERFACE.val_avg_acc.add(_acc)
    # record best acc
    self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
    self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
    report_dict = self._eval_report_dict
    report_dict_str = ", ".join(
        [f"{k}:{v:.3f}" for k, v in report_dict.items()])
    print(f"Validating epoch: {epoch} : {report_dict_str}")
    return self.METERINTERFACE.val_best_acc.summary()["mean"]
def _train_loop(self,
                train_loader=None,
                epoch=0,
                mode: ModelMode = ModelMode.TRAIN,
                **kwargs):
    """One training epoch combining a mutual-information loss on the first
    view with a VAT adversarial loss (batch_loss = vat - 0.1 * mi).

    :param train_loader: loader yielding multi-view (images, labels, indices).
    :param epoch: epoch index for the progress description.
    :param mode: mode pushed to the model wrapper; TRAIN by default.
    """
    self.model.set_mode(mode)
    assert (
        self.model.training
    ), f"Model should be in train() model, given {self.model.training}."
    train_loader_: tqdm = tqdm_(train_loader)
    train_loader_.set_description(f"Training epoch: {epoch}")
    for batch, image_labels in enumerate(train_loader_):
        images, _, (index, *_) = list(zip(*image_labels))
        # Replicate the first view once per extra view so shapes line up.
        tf1_images = torch.cat(
            [images[0] for _ in range(images.__len__() - 1)],
            dim=0).to(self.device)
        # fix: removed the redundant duplicated `.to(self.device)` call.
        tf2_images = torch.cat(images[1:], dim=0).to(self.device)
        # Kept for the (currently disabled) adaptive-eps VAT variant below.
        index = torch.cat([index for _ in range(images.__len__() - 1)], dim=0)
        assert tf1_images.shape == tf2_images.shape
        tf1_pred_logit = self.model.torchnet(tf1_images)
        tf2_pred_logit = self.model.torchnet(tf2_images)
        assert (assert_list(simplex, tf1_pred_logit)
                and tf1_pred_logit[0].shape == tf2_pred_logit[0].shape)
        # NOTE(review): the SAT branch was disabled (its accumulation was
        # commented out); the dead `self.SAT_criterion(...)` evaluation is
        # removed here to avoid wasted compute.  Restore it if re-enabling SAT.
        ml_losses = []
        for subhead_num, tf1_pred in enumerate(tf1_pred_logit):
            ml_loss, *_ = self.MI_criterion(tf1_pred)
            ml_losses.append(ml_loss)
        ml_losses = sum(ml_losses) / len(ml_losses)
        # eps fixed at 10; an adaptive eps via `self.nearest_dict[index]`
        # was previously experimented with.
        VAT_generator = VATLoss_Multihead(eps=10)
        vat_loss, adv_tf1_images, _ = VAT_generator(self.model.torchnet,
                                                    tf1_images)
        batch_loss: torch.Tensor = vat_loss - 0.1 * ml_losses
        self.METERINTERFACE["train_mi_loss"].add(ml_losses.item())
        self.METERINTERFACE["train_adv_loss"].add(vat_loss.item())
        self.model.zero_grad()
        batch_loss.backward()
        self.model.step()
        report_dict = self._training_report_dict
        train_loader_.set_postfix(report_dict)
def _linear_eval_loop(val_loader, epoch) -> Tensor:
    """Evaluate the linear probe on pre-extracted features for one epoch."""
    indicator = tqdm_(val_loader)
    for _num, (feat, labels) in enumerate(indicator):
        feat, labels = feat.to(self.device), labels.to(self.device)
        logits = linearnet(feat)
        linear_meters["val_acc"].add(logits.max(1)[1], labels)
        summary = {
            "val_acc": linear_meters["val_acc"].summary()["acc"]
        }
        indicator.set_postfix(summary)
    print(f"Validating epoch {epoch}: {nice_dict(summary)} ")
    return linear_meters["val_acc"].summary()["acc"]
def _train_loop(
    self, train_loader=None, epoch=0, mode=ModelMode.TRAIN, *args, **kwargs
):
    """Train the auto-encoder for one epoch on reconstruction loss."""
    self.model.train()
    train_loader_: tqdm = tqdm_(train_loader)
    for batch_num, data in enumerate(train_loader_):
        # Labels are discarded: reconstruction is unsupervised.
        img, _ = data
        img = img.to(self.device)
        # ===================forward=====================
        output = self.model(img)
        loss = self.criterion(output, img)
        # ===================backward====================
        self.model.zero_grad()
        loss.backward()
        self.model.step()
        self.METERINTERFACE.rec_loss.add(loss.item())
        # NOTE(review): `_training_report_dict` is *called* here (with
        # parens), whereas sibling trainers access it as a property —
        # confirm it is defined as a method on this class.
        train_loader_.set_postfix(self._training_report_dict())
def _linear_train_loop(train_loader, epoch):
    """Train the linear probe on pre-extracted features for one epoch."""
    indicator = tqdm_(train_loader)
    for _num, (feat, labels) in enumerate(indicator):
        feat, labels = feat.to(self.device), labels.to(self.device)
        logits = linearnet(feat)
        batch_loss = self.criterion(logits, labels)
        linearOptim.zero_grad()
        batch_loss.backward()
        linearOptim.step()
        linear_meters["train_loss"].add(batch_loss.item())
        linear_meters["train_acc"].add(logits.max(1)[1], labels)
        summary = {
            "tra_acc": linear_meters["train_acc"].summary()["acc"],
            "loss": linear_meters["train_loss"].summary()["mean"],
        }
        indicator.set_postfix(summary)
    print(f"  Training epoch {epoch}: {nice_dict(summary)} ")
def _eval_loop(self,
               val_loader: DataLoader = None,
               epoch: int = 0,
               mode=ModelMode.EVAL,
               *args,
               **kwargs) -> float:
    """Classification validation epoch; returns confusion-meter accuracy."""
    self.model.set_mode(mode)
    assert not self.model.training
    indicator: tqdm = tqdm_(val_loader)
    for _num, (images, labels) in enumerate(indicator):
        images, labels = images.to(self.device), labels.to(self.device)
        probs, _ = self.model(images)
        self.METERINTERFACE.val_conf.add(probs.max(1)[1], labels)
        report_dict = self._eval_report_dict
        indicator.set_postfix(report_dict)
    print(f'Validating epoch {epoch}: {nice_dict(report_dict)}')
    return self.METERINTERFACE.val_conf.summary()['acc']
def _sup_eval_loop(val_loader, epoch) -> Tensor:
    """Evaluate the supervised head for one epoch; returns accuracy."""
    self.model.eval()
    indicator = tqdm_(val_loader)
    for _num, batch in enumerate(indicator):
        imgs, labels = zip(*batch)
        imgs, labels = imgs[0].to(self.device), labels[0].to(self.device)
        if self.use_sobel:
            imgs = self.sobel(imgs)
        # First subhead only.
        prediction = self.model.torchnet(imgs)[0]
        linear_meters["val_acc"].add(prediction.max(1)[1], labels)
        summary = {
            "val_acc": linear_meters["val_acc"].summary()["acc"]
        }
        indicator.set_postfix(summary)
    print(f"Validating epoch {epoch}: {nice_dict(summary)} ")
    return linear_meters["val_acc"].summary()["acc"]
def training(self):
    """Toy clustering run on a synthetic sklearn classification problem.

    Two views are formed from the same points (identity and small gaussian
    perturbation); accuracy is tracked with and without hungarian remapping.
    """
    x1, y = make_classification(1000,
                                n_features=10,
                                n_informative=5,
                                n_classes=10)
    x1 = torch.from_numpy(x1).cuda().float()
    y = torch.from_numpy(y).cuda().long()
    itera: tqdm = tqdm_(range(100000))
    for i in itera:
        # Second view: same points under 0.1-scaled gaussian noise.
        noise = torch.randn_like(x1).cuda()
        x2 = x1 + 0.1 * noise
        p1 = self.model(x1)
        p2 = self.model(x2)
        loss = self._loss_function(x1, p1, x2, p2)
        # Align predicted cluster ids with labels before scoring.
        # NOTE(review): argument order `flat_acc(y, reordered)` is flipped
        # relative to sibling code — presumably flat_acc is symmetric; verify.
        reordered, _ = hungarian_match(p1.max(1)[1], y, 10, 10)
        print(reordered.unique())  # debug: which cluster ids are in use
        acc = flat_acc(y, reordered)  # accuracy after remapping
        acc2 = flat_acc(y, p1.max(1)[1])  # accuracy without remapping
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if i % 10 == 0:
            self.show(p1, p2)
        itera.set_postfix({"loss": loss.item(), "acc": acc, "acc2": acc2})
def _train_loop(self, labeled_loader: DataLoader = None, unlabeled_loader: DataLoader = None, epoch: int = 0, mode=ModelMode.TRAIN, *args, **kwargs): super(AdaNetTrainer, self)._train_loop(*args, **kwargs) # warnings self.model.set_mode(mode) assert self.model.training labeled_loader_ = DataIter(labeled_loader) unlabeled_loader_ = DataIter(unlabeled_loader) batch_num: tqdm = tqdm_(range(unlabeled_loader.__len__())) for _batch_num, ((label_img, label_gt), (unlabel_img, _), _) in enumerate( zip(labeled_loader_, unlabeled_loader_, batch_num)): label_img, label_gt, unlabel_img = label_img.to(self.device), \ label_gt.to(self.device), unlabel_img.to(self.device) label_pred, _ = self.model(label_img) self.METERINTERFACE.tra_conf.add(label_pred.max(1)[1], label_gt) sup_loss = self.ce_loss(label_pred, label_gt.squeeze()) self.METERINTERFACE.tra_sup_label.add(sup_loss.item()) reg_loss = self._trainer_specific_loss(label_img, label_gt, unlabel_img) self.METERINTERFACE.tra_reg_total.add(reg_loss.item()) with ZeroGradientBackwardStep(sup_loss + reg_loss, self.model) as loss: loss.backward() report_dict = self._training_report_dict batch_num.set_postfix(report_dict) print(f' Training epoch {epoch}: {nice_dict(report_dict)}')
def _train_loop( self, train_loader_A: DataLoader = None, train_loader_B: DataLoader = None, epoch: int = None, mode: ModelMode = ModelMode.TRAIN, head_control_param: OrderedDict = None, *args, **kwargs, ) -> None: """ :param train_loader_A: :param train_loader_B: :param epoch: :param mode: :param head_control_param: :param args: :param kwargs: :return: None """ # robustness asserts assert isinstance(train_loader_B, DataLoader) and isinstance( train_loader_A, DataLoader) assert (head_control_param and head_control_param.__len__() > 0), \ f"`head_control_param` must be provided, given {head_control_param}." assert set(head_control_param.keys()) <= {"A", "B", }, \ f"`head_control_param` key must be in `A` or `B`, given {set(head_control_param.keys())}" for k, v in head_control_param.items(): assert k in ("A", "B"), ( f"`head_control_param` key must be in `A` or `B`," f" given{set(head_control_param.keys())}") assert isinstance( v, int) and v >= 0, f"Iteration for {k} must be >= 0." # set training mode self.model.set_mode(mode) assert ( self.model.training ), f"Model should be in train() model, given {self.model.training}." assert len(train_loader_B) == len(train_loader_A), ( f'The length of the train_loaders should be the same,"' f"given `len(train_loader_A)`:{len(train_loader_A)} and `len(train_loader_B)`:{len(train_loader_B)}." ) for head_name, head_iterations in head_control_param.items(): assert head_name in ("A", "B"), head_name train_loader = eval(f"train_loader_{head_name}" ) # change the dataset for different head for head_epoch in range(head_iterations): # given one head, one iteration in this head, and one train_loader. 
train_loader_: tqdm = tqdm_( train_loader) # reinitialize the train_loader train_loader_.set_description( f"Training epoch: {epoch} head:{head_name}, head_epoch:{head_epoch + 1}/{head_iterations}" ) for batch, image_labels in enumerate(train_loader_): images, *_ = list(zip(*image_labels)) # extract tf1_images, tf2_images and put then to self.device tf1_images = torch.cat(tuple( [images[0] for _ in range(len(images) - 1)]), dim=0).to(self.device) tf2_images = torch.cat(tuple(images[1:]), dim=0).to(self.device) assert tf1_images.shape == tf2_images.shape, f"`tf1_images` should have the same size as `tf2_images`," \ f"given {tf1_images.shape} and {tf2_images.shape}." # if images are processed with sobel filters if self.use_sobel: tf1_images = self.sobel(tf1_images) tf2_images = self.sobel(tf2_images) assert tf1_images.shape == tf2_images.shape # Here you have two kinds of geometric transformations # todo: functions to be overwritten batch_loss = self._trainer_specific_loss( tf1_images, tf2_images, head_name) # update model with self-defined context manager support Apex module with ZeroGradientBackwardStep(batch_loss, self.model) as loss: loss.backward() # write value to tqdm module for system monitoring report_dict = self._training_report_dict train_loader_.set_postfix(report_dict) # for tensorboard recording self.writer.add_scalar_with_tag("train", report_dict, epoch) # for std recording print(f"Training epoch: {epoch} : {nice_dict(report_dict)}")
def _eval_loop(
    self,
    val_loader: DataLoader = None,
    epoch: int = 0,
    mode: ModelMode = ModelMode.EVAL,
    return_soft_predict=False,
    *args,
    **kwargs,
) -> float:
    """Evaluate head B's sub-heads over the full validation set.

    Remaps each sub-head's cluster ids to ground truth via hungarian
    matching and records average/best/worst accuracy.  When
    `return_soft_predict` is True, also returns (targets, remapped soft
    predictions of the best sub-head).
    """
    assert isinstance(
        val_loader, DataLoader)  # make sure a validation loader is passed.
    self.model.set_mode(mode)  # set model to be eval mode, by default.
    # make sure the model is in eval mode.
    assert (
        not self.model.training
    ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
    val_loader_: tqdm = tqdm_(val_loader)
    # prediction initialization with shape: (num_sub_heads, num_samples)
    preds = torch.zeros(self.model.arch_dict["num_sub_heads"],
                        val_loader.dataset.__len__(),
                        dtype=torch.long,
                        device=self.device)
    # soft_prediction initialization with shape (num_sub_heads, num_sample, num_classes)
    if return_soft_predict:
        soft_preds = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            val_loader.dataset.__len__(),
            self.model.arch_dict["output_k_B"],
            dtype=torch.float,
            device=torch.device("cpu"))  # kept on CPU to save GPU memory
    # target initialization with shape: (num_samples)
    target = torch.zeros(val_loader.dataset.__len__(),
                         dtype=torch.long,
                         device=self.device)
    # begin index
    slice_done = 0
    subhead_accs = []
    val_loader_.set_description(f"Validating epoch: {epoch}")
    for batch, image_labels in enumerate(val_loader_):
        images, gt, *_ = list(zip(*image_labels))
        # only take the tf3 image and gts, put them to self.device
        images, gt = images[0].to(self.device), gt[0].to(self.device)
        # if use sobel filter
        if self.use_sobel:
            images = self.sobel(images)
        # using default head_B for inference, _pred should be a list of simplex by default.
        _pred = self.model.torchnet(images, head="B")
        assert assert_list(simplex,
                           _pred), "pred should be a list of simplexes."
        assert _pred.__len__() == self.model.arch_dict["num_sub_heads"]
        # slice window definition
        bSlicer = slice(slice_done, slice_done + images.shape[0])
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            # save predictions for each subhead for each batch
            preds[subhead][bSlicer] = _pred[subhead].max(1)[1]
            if return_soft_predict:
                soft_preds[subhead][bSlicer] = _pred[subhead]
        # save target for each batch
        target[bSlicer] = gt
        # update slice index
        slice_done += gt.shape[0]
    # make sure that all the dataset has been done. Errors will raise if dataloader.drop_last=True
    assert slice_done == val_loader.dataset.__len__(
    ), "Slice not completed."
    for subhead in range(self.model.arch_dict["num_sub_heads"]):
        # remap pred for each head and compare with target to get subhead_acc
        reorder_pred, remap = hungarian_match(
            flat_preds=preds[subhead],
            flat_targets=target,
            preds_k=self.model.arch_dict["output_k_B"],
            targets_k=self.model.arch_dict["output_k_B"],
        )
        _acc = flat_acc(reorder_pred, target)
        subhead_accs.append(_acc)
        # record average acc
        self.METERINTERFACE.val_average_acc.add(_acc)
        if return_soft_predict:
            # permute soft prediction columns to the remapped cluster ids
            soft_preds[subhead][:, list(remap.values(
            ))] = soft_preds[subhead][:, list(remap.keys())]
            # sanity check: remapped argmax must agree with reordered preds
            assert torch.allclose(soft_preds[subhead].max(1)[1],
                                  reorder_pred.cpu())
    # record best acc
    self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
    # record worst acc
    self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
    report_dict = self._eval_report_dict
    # record results for std
    print(f"Validating epoch: {epoch} : {nice_dict(report_dict)}")
    # record results for tensorboard
    self.writer.add_scalar_with_tag("val", report_dict, epoch)
    # using multithreads to call histogram interface of tensorboard.
    pred_histgram(self.writer, preds, epoch=epoch)
    # return the current score to save the best checkpoint.
    if return_soft_predict:
        return self.METERINTERFACE.val_best_acc.summary()["mean"], (
            target.cpu(), soft_preds[np.argmax(subhead_accs)]
        )  # type ignore
    return self.METERINTERFACE.val_best_acc.summary()["mean"]