class Trainer:
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )

        # Backbone parameters train at the base LR, decoder parameters at 10x
        train_params = [
            {"params": model.get_1x_lr_params(), "lr": args.lr},
            {"params": model.get_10x_lr_params(), "lr": args.lr * 10},
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # Parentheses ensure the filename is built before being appended
            # to the dataset directory path
            classes_weights_path = DATASETS_DIRS[args.dataset] / (
                args.dataset + "_classes_weights.npy"
            )
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass
                )
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda
        ).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator (tracks seen and unseen class indices separately)
        self.evaluator = Evaluator(
            self.nclass,
            args.seen_classes_idx_metric,
            args.unseen_classes_idx_metric,
        )

        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)
        )

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(
                self.model, device_ids=self.args.gpu_ids
            )
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.random_last_layer:
                # Re-initialize the last classification layer so it covers all
                # nclass outputs instead of the checkpoint's seen-only weights
                checkpoint["state_dict"]["decoder.pred_conv.weight"] = torch.rand((
                    self.nclass,
                    checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[1],
                    checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[2],
                    checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[3],
                ))
                checkpoint["state_dict"]["decoder.pred_conv.bias"] = torch.rand(
                    self.nclass
                )
            if args.nonlinear_last_layer:
                if args.cuda:
                    self.model.module.deeplab.load_state_dict(checkpoint["state_dict"])
                else:
                    self.model.deeplab.load_state_dict(checkpoint["state_dict"])
            else:
                if args.cuda:
                    self.model.module.load_state_dict(checkpoint["state_dict"])
                else:
                    self.model.load_state_dict(checkpoint["state_dict"])
            if not args.ft:
                if not args.nonlinear_last_layer:
                    self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def validation(self, epoch, args):
        self.model.eval()
        self.evaluator.reset()
        all_target = []
        all_pred = []
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                if args.nonlinear_last_layer:
                    output = self.model(image, image.size()[2:])
                else:
                    output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            all_target.append(target)
            all_pred.append(pred)

        # Fast test during the training
        Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy()
        (
            Acc_class,
            Acc_class_by_class,
            Acc_class_seen,
            Acc_class_unseen,
        ) = self.evaluator.Pixel_Accuracy_Class()
        (
            mIoU,
            mIoU_by_class,
            mIoU_seen,
            mIoU_unseen,
        ) = self.evaluator.Mean_Intersection_over_Union()
        (
            FWIoU,
            FWIoU_seen,
            FWIoU_unseen,
        ) = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val_overall/mIoU", mIoU, epoch)
        self.writer.add_scalar("val_overall/Acc", Acc, epoch)
        self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch)
        self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch)
        self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch)
        self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch)
        self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch)
        self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch)
        self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch)

        print("Validation:")
        print(
            "[Epoch: %d, numImages: %5d]"
            % (epoch, i * self.args.batch_size + image.data.shape[0])
        )
        print(f"Loss: {test_loss:.3f}")
        print(
            f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}"
        )
        print(
            "Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen
            )
        )
        print(
            "Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen
            )
        )

        for class_name, acc_value, mIoU_value in zip(
            CLASSES_NAMES, Acc_class_by_class, mIoU_by_class
        ):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)
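
# --- Illustrative sketch (added, not part of the original trainer) ----------
# The trainer above builds two parameter groups (backbone at args.lr, decoder
# at args.lr * 10) and calls LR_Scheduler once per iteration during training.
# Below is a minimal poly-decay schedule that preserves that 1x/10x ratio; the
# class name, signature and decay rule are assumptions for illustration, not
# the repository's actual LR_Scheduler implementation.
import torch


class PolyLRScheduler:
    """Toy poly schedule: lr = base_lr * (1 - progress) ** power."""

    def __init__(self, base_lr, num_epochs, iters_per_epoch, power=0.9):
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch
        self.power = power

    def __call__(self, optimizer, i, epoch, best_pred=None):
        # best_pred is accepted only to mirror how the trainer calls its
        # scheduler; it is unused in this sketch.
        cur_iter = epoch * self.iters_per_epoch + i
        lr = self.base_lr * (1 - cur_iter / self.total_iters) ** self.power
        optimizer.param_groups[0]["lr"] = lr            # backbone group: 1x
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]["lr"] = lr * 10   # decoder group: 10x


# Usage (mirrors self.scheduler(self.optimizer, i, epoch, self.best_pred)):
#   scheduler = PolyLRScheduler(base_lr=0.007, num_epochs=50, iters_per_epoch=len(train_loader))
#   scheduler(optimizer, i, epoch, best_pred)
# -----------------------------------------------------------------------------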
class Trainer: def __init__(self, args): self.args = args # Define Saver self.saver = Saver(args) self.saver.save_experiment_config() # Define Tensorboard Summary self.summary = TensorboardSummary(self.saver.experiment_dir) self.writer = self.summary.create_summary() """ Get dataLoader """ # config = get_config(args.config) # vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, _, cls_map, cls_map_test = get_split(config) # assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) # print('seen_classes', vals_cls) # print('novel_classes', valu_cls) # print('all_labels', all_labels) # print('visible_classes', visible_classes) # print('visible_classes_test', visible_classes_test) # print('train', train[:10], len(train)) # print('val', val[:10], len(val)) # print('cls_map', cls_map) # print('cls_map_test', cls_map_test) # Define Dataloader kwargs = {"num_workers": args.workers, "pin_memory": True} ( self.train_loader, self.val_loader, _, self.nclass, ) = make_data_loader(args, load_embedding=args.load_embedding, w2c_size=args.w2c_size, **kwargs) print('self.nclass', self.nclass) # 33 model = DeepLab( num_classes=self.nclass, output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, global_avg_pool_bn=args.global_avg_pool_bn, imagenet_pretrained_path=args.imagenet_pretrained_path, ) train_params = [ { "params": model.get_1x_lr_params(), "lr": args.lr }, { "params": model.get_10x_lr_params(), "lr": args.lr * 10 }, ] # Define Optimizer optimizer = torch.optim.SGD( train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov, ) # Define Generator generator = GMMNnetwork(args.noise_dim, args.embed_dim, args.hidden_size, args.feature_dim) optimizer_generator = torch.optim.Adam(generator.parameters(), lr=args.lr_generator) class_weight = torch.ones(self.nclass) class_weight[args.unseen_classes_idx_metric] = args.unseen_weight if args.cuda: class_weight = class_weight.cuda() self.criterion = SegmentationLosses( weight=class_weight, cuda=args.cuda).build_loss(mode=args.loss_type) self.model, self.optimizer = model, optimizer self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80], cuda=args.cuda).build_loss() self.generator, self.optimizer_generator = generator, optimizer_generator # Define Evaluator self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric, args.unseen_classes_idx_metric) # Define lr scheduler self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) # Using cuda if args.cuda: self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) patch_replication_callback(self.model) self.model = self.model.cuda() self.generator = self.generator.cuda() # Resuming checkpoint self.best_pred = 0.0 if args.resume is not None: if not os.path.isfile(args.resume): raise RuntimeError( f"=> no checkpoint found at '{args.resume}'") checkpoint = torch.load(args.resume) # args.start_epoch = checkpoint['epoch'] if args.random_last_layer: checkpoint["state_dict"][ "decoder.pred_conv.weight"] = torch.rand(( self.nclass, checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[1], checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[2], checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[3], )) checkpoint["state_dict"][ "decoder.pred_conv.bias"] = torch.rand(self.nclass) if args.cuda: self.model.module.load_state_dict(checkpoint["state_dict"]) else: self.model.load_state_dict(checkpoint["state_dict"]) # 
self.best_pred = checkpoint['best_pred'] print( f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" ) # Clear start epoch if fine-tuning if args.ft: args.start_epoch = 0 def training(self, epoch, args): train_loss = 0.0 self.model.train() tbar = tqdm(self.train_loader) num_img_tr = len(self.train_loader) for i, sample in enumerate(tbar): if len(sample["image"]) > 1: image, target, embedding = ( sample["image"], sample["label"], sample["label_emb"], ) if self.args.cuda: image, target, embedding = ( image.cuda(), target.cuda(), embedding.cuda(), ) self.scheduler(self.optimizer, i, epoch, self.best_pred) # ===================real feature extraction===================== with torch.no_grad(): real_features = self.model.module.forward_before_class_prediction( image) # ===================fake feature generation===================== fake_features = torch.zeros(real_features.shape) if args.cuda: fake_features = fake_features.cuda() generator_loss_batch = 0.0 for ( count_sample_i, (real_features_i, target_i, embedding_i), ) in enumerate(zip(real_features, target, embedding)): generator_loss_sample = 0.0 ## reduce to real feature size real_features_i = (real_features_i.permute( 1, 2, 0).contiguous().view((-1, args.feature_dim))) target_i = nn.functional.interpolate( target_i.view(1, 1, target_i.shape[0], target_i.shape[1]), size=(real_features.shape[2], real_features.shape[3]), mode="nearest", ).view(-1) embedding_i = nn.functional.interpolate( embedding_i.view( 1, embedding_i.shape[0], embedding_i.shape[1], embedding_i.shape[2], ), size=(real_features.shape[2], real_features.shape[3]), mode="nearest", ) embedding_i = (embedding_i.permute(0, 2, 3, 1).contiguous().view( (-1, args.embed_dim))) fake_features_i = torch.zeros(real_features_i.shape) if args.cuda: fake_features_i = fake_features_i.cuda() unique_class = torch.unique(target_i) ## test if image has unseen class pixel, if yes means no training for generator and generated features for the whole image has_unseen_class = False for u_class in unique_class: if u_class in args.unseen_classes_idx_metric: has_unseen_class = True for idx_in in unique_class: if idx_in != 255: self.optimizer_generator.zero_grad() idx_class = target_i == idx_in real_features_class = real_features_i[idx_class] embedding_class = embedding_i[idx_class] z = torch.rand( (embedding_class.shape[0], args.noise_dim)) if args.cuda: z = z.cuda() fake_features_class = self.generator( embedding_class, z.float()) if (idx_in in args.seen_classes_idx_metric and not has_unseen_class): ## in order to avoid CUDA out of memory random_idx = torch.randint( low=0, high=fake_features_class.shape[0], size=(args.batch_size_generator, ), ) g_loss = self.criterion_generator( fake_features_class[random_idx], real_features_class[random_idx], ) generator_loss_sample += g_loss.item() g_loss.backward() self.optimizer_generator.step() fake_features_i[ idx_class] = fake_features_class.clone() generator_loss_batch += generator_loss_sample / len( unique_class) if args.real_seen_features and not has_unseen_class: fake_features[count_sample_i] = real_features_i.view(( fake_features.shape[2], fake_features.shape[3], args.feature_dim, )).permute(2, 0, 1) else: fake_features[count_sample_i] = fake_features_i.view(( fake_features.shape[2], fake_features.shape[3], args.feature_dim, )).permute(2, 0, 1) # ===================classification===================== self.optimizer.zero_grad() output = self.model.module.forward_class_prediction( fake_features.detach(), image.size()[2:]) loss = 
self.criterion(output, target) loss.backward() self.optimizer.step() train_loss += loss.item() # ===================log===================== tbar.set_description(f" G loss: {generator_loss_batch:.3f}" + " C loss: %.3f" % (train_loss / (i + 1))) self.writer.add_scalar("train/total_loss_iter", loss.item(), i + num_img_tr * epoch) self.writer.add_scalar("train/generator_loss", generator_loss_batch, i + num_img_tr * epoch) # Show 10 * 3 inference results each epoch if i % (num_img_tr // 10) == 0: global_step = i + num_img_tr * epoch self.summary.visualize_image( self.writer, self.args.dataset, image, target, output, global_step, ) self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch) print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Loss: {train_loss:.3f}") if self.args.no_val: # save checkpoint every epoch is_best = False self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, ) def validation(self, epoch, args): self.model.eval() self.evaluator.reset() tbar = tqdm(self.val_loader, desc="\r") test_loss = 0.0 saved_images = {} saved_target = {} saved_prediction = {} for idx_unseen_class in args.unseen_classes_idx_metric: saved_images[idx_unseen_class] = [] saved_target[idx_unseen_class] = [] saved_prediction[idx_unseen_class] = [] targets, outputs = [], [] log_file = './logs_context_step_2_GMMN.txt' logger = logWritter(log_file) for i, sample in enumerate(tbar): image, target, embedding = ( sample["image"], sample["label"], sample["label_emb"], ) if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) loss = self.criterion(output, target) test_loss += loss.item() tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) ## save image for tensorboard for idx_unseen_class in args.unseen_classes_idx_metric: if len((target.reshape(-1) == idx_unseen_class).nonzero()) > 0: if len(saved_images[idx_unseen_class] ) < args.saved_validation_images: saved_images[idx_unseen_class].append( image.clone().cpu()) saved_target[idx_unseen_class].append( target.clone().cpu()) saved_prediction[idx_unseen_class].append( output.clone().cpu()) pred = output.data.cpu().numpy() target = target.cpu().numpy().astype(np.int64) pred = np.argmax(pred, axis=1) for o, t in zip(pred, target): outputs.append(o) targets.append(t) # Add batch sample into evaluator self.evaluator.add_batch(target, pred) config = get_config(args.config) vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, _, cls_map, cls_map_test = get_split( config) assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) score, class_iou = scores_gzsl(targets, outputs, n_class=len(visible_classes_test), seen_cls=cls_map_test[vals_cls], unseen_cls=cls_map_test[valu_cls]) print("Test results:") logger.write("Test results:") for k, v in score.items(): print(k + ': ' + json.dumps(v)) logger.write(k + ': ' + json.dumps(v)) score["Class IoU"] = {} visible_classes_test = sorted(visible_classes_test) for i in range(len(visible_classes_test)): score["Class IoU"][all_labels[ visible_classes_test[i]]] = class_iou[i] print("Class IoU: " + json.dumps(score["Class IoU"])) logger.write("Class IoU: " + json.dumps(score["Class IoU"])) print("Test finished.\n\n") logger.write("Test finished.\n\n") # Fast test during the training Acc, Acc_seen, Acc_unseen = 
self.evaluator.Pixel_Accuracy() ( Acc_class, Acc_class_by_class, Acc_class_seen, Acc_class_unseen, ) = self.evaluator.Pixel_Accuracy_Class() ( mIoU, mIoU_by_class, mIoU_seen, mIoU_unseen, ) = self.evaluator.Mean_Intersection_over_Union() ( FWIoU, FWIoU_seen, FWIoU_unseen, ) = self.evaluator.Frequency_Weighted_Intersection_over_Union() self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch) self.writer.add_scalar("val_overall/mIoU", mIoU, epoch) self.writer.add_scalar("val_overall/Acc", Acc, epoch) self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch) self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch) self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch) self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch) self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch) self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch) self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch) self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch) self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch) self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch) print("Validation:") print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Loss: {test_loss:.3f}") print( f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}" ) print("Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen)) print("Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen)) for class_name, acc_value, mIoU_value in zip(CLASSES_NAMES, Acc_class_by_class, mIoU_by_class): self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value) new_pred = mIoU_unseen is_best = True self.best_pred = new_pred self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, generator_state={ "epoch": epoch + 1, "state_dict": self.generator.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, ) global_step = epoch + 1 for idx_unseen_class in args.unseen_classes_idx_metric: if len(saved_images[idx_unseen_class]) > 0: nb_image = len(saved_images[idx_unseen_class]) if nb_image > args.saved_validation_images: nb_image = args.saved_validation_images for i in range(nb_image): self.summary.visualize_image_validation( self.writer, self.args.dataset, saved_images[idx_unseen_class][i], saved_target[idx_unseen_class][i], saved_prediction[idx_unseen_class][i], global_step, name="validation_" + CLASSES_NAMES[idx_unseen_class] + "_" + str(i), nb_image=1, ) self.evaluator.reset()
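
# --- Illustrative sketch (added, not part of the original trainer) ----------
# GMMNLoss(sigma=[2, 5, 10, 20, 40, 80]) above builds the moment-matching loss
# used to train the feature generator. A common formulation is the maximum
# mean discrepancy (MMD) between real and generated feature batches under a
# bank of Gaussian kernels with those bandwidths. The sketch below assumes only
# that interface (two (N, D) batches in, a scalar loss out); the exact kernel
# normalisation in this code base may differ.
import torch


def gaussian_kernel(x, y, sigma):
    """k(x, y) = exp(-||x - y||^2 / (2 * sigma)) for all pairs of rows."""
    sq_dist = torch.cdist(x, y, p=2.0) ** 2
    return torch.exp(-sq_dist / (2.0 * sigma))


def mmd_loss(fake, real, sigmas=(2, 5, 10, 20, 40, 80)):
    """MMD between two (N, D) feature batches, summed over kernel bandwidths."""
    loss = 0.0
    for sigma in sigmas:
        k_ff = gaussian_kernel(fake, fake, sigma).mean()
        k_rr = gaussian_kernel(real, real, sigma).mean()
        k_fr = gaussian_kernel(fake, real, sigma).mean()
        loss = loss + k_ff + k_rr - 2.0 * k_fr
    return loss


# Usage (matches the per-class call pattern in training()):
#   g_loss = mmd_loss(fake_features_class[random_idx], real_features_class[random_idx])
#   g_loss.backward(); optimizer_generator.step()
# -----------------------------------------------------------------------------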
class Trainer(BaseTrainer): def __init__(self, args): self.args = args # Define Saver self.saver = Saver(args) self.saver.save_experiment_config() # Define Tensorboard Summary self.summary = TensorboardSummary(self.saver.experiment_dir) self.writer = self.summary.create_summary() """ Get dataLoader """ # config = get_config(args.config) # vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, _, cls_map, cls_map_test = get_split(config) # assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) # print('seen_classes', vals_cls) # print('novel_classes', valu_cls) # print('all_labels', all_labels) # print('visible_classes', visible_classes) # print('visible_classes_test', visible_classes_test) # print('train', train[:10], len(train)) # print('val', val[:10], len(val)) # print('cls_map', cls_map) # print('cls_map_test', cls_map_test) kwargs = {"num_workers": args.workers, "pin_memory": True} ( self.train_loader, self.val_loader, _, self.nclass, ) = make_data_loader(args, **kwargs) print('self.nclass', self.nclass) # Define network model = DeepLab( num_classes=self.nclass, output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=False, pretrained=args.imagenet_pretrained, imagenet_pretrained_path=args.imagenet_pretrained_path, ) train_params = [ { "params": model.get_1x_lr_params(), "lr": args.lr }, { "params": model.get_10x_lr_params(), "lr": args.lr * 10 }, ] # Define Optimizer optimizer = torch.optim.SGD( train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov, ) # Define Criterion # whether to use class balanced weights if args.use_balanced_weights: classes_weights_path = ( DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy") if os.path.isfile(classes_weights_path): weight = np.load(classes_weights_path) else: weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) weight = torch.from_numpy(weight.astype(np.float32)) else: weight = None self.criterion = SegmentationLosses( weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) self.model, self.optimizer = model, optimizer if args.imagenet_pretrained_path is not None: state_dict = torch.load(args.imagenet_pretrained_path) if 'state_dict' in state_dict.keys(): self.model.load_state_dict(state_dict['state_dict']) else: #print(model.state_dict().keys())#['scale.layer1.conv1.conv.weight']) #print(state_dict.items().keys()) new_dict = {} for k, v in state_dict.items(): #print(k[11:]) new_dict[k[11:]] = v self.model.load_state_dict( new_dict, strict=False ) # make strict=True to debug if checkpoint is loaded correctly or not if performance is low # Define Evaluator self.evaluator = Evaluator(self.nclass) # Define lr scheduler self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) # Using cuda if args.cuda: self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) patch_replication_callback(self.model) self.model = self.model.cuda() # Resuming checkpoint self.best_pred = 0.0 if args.resume is not None: if not os.path.isfile(args.resume): raise RuntimeError( f"=> no checkpoint found at '{args.resume}'") checkpoint = torch.load(args.resume) args.start_epoch = checkpoint["epoch"] if args.cuda: self.model.module.load_state_dict(checkpoint["state_dict"]) else: self.model.load_state_dict(checkpoint["state_dict"]) if not args.ft: self.optimizer.load_state_dict(checkpoint["optimizer"]) self.best_pred = checkpoint["best_pred"] print( f"=> loaded 
checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" ) # Clear start epoch if fine-tuning if args.ft: args.start_epoch = 0 def validation(self, epoch, args): class_names = CLASSES_NAMES[:20] self.model.eval() self.evaluator.reset() tbar = tqdm(self.val_loader, desc="\r") test_loss = 0.0 torch.set_printoptions(profile="full") targets, outputs = [], [] log_file = './logs_voc12_step_1.txt' logger = logWritter(log_file) for i, sample in enumerate(tbar): image, target = sample["image"], sample["label"] if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) target = resize_target(target, s=output.size()[2:]).cuda() loss = self.criterion(output, target) test_loss += loss.item() tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) pred = output.data.cpu().numpy() target = target.cpu().numpy() pred = np.argmax(pred, axis=1) # print('pred', pred[:, 100:105, 100:120]) # print('target', target[:, 100:105, 100:120]) for o, t in zip(pred, target): outputs.append(o) targets.append(t) # Add batch sample into evaluator self.evaluator.add_batch(target, pred) config = get_config(args.config) vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, _, cls_map, cls_map_test = get_split( config) assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) score, class_iou = scores_gzsl(targets, outputs, n_class=len(visible_classes_test), seen_cls=cls_map_test[vals_cls], unseen_cls=cls_map_test[valu_cls]) print("Test results:") logger.write("Test results:") for k, v in score.items(): print(k + ': ' + json.dumps(v)) logger.write(k + ': ' + json.dumps(v)) score["Class IoU"] = {} for i in range(len(visible_classes_test)): score["Class IoU"][all_labels[ visible_classes_test[i]]] = class_iou[i] print("Class IoU: " + json.dumps(score["Class IoU"])) logger.write("Class IoU: " + json.dumps(score["Class IoU"])) print("Test finished.\n\n") logger.write("Test finished.\n\n") # Fast test during the training Acc = self.evaluator.Pixel_Accuracy() Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class() mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union() FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch) self.writer.add_scalar("val/mIoU", mIoU, epoch) self.writer.add_scalar("val/Acc", Acc, epoch) self.writer.add_scalar("val/Acc_class", Acc_class, epoch) self.writer.add_scalar("val/fwIoU", FWIoU, epoch) print("Validation:") print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") print(f"Loss: {test_loss:.3f}") for i, (class_name, acc_value, mIoU_value) in enumerate( zip(class_names, Acc_class_by_class, mIoU_by_class)): self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) print(class_names[i], "- acc:", acc_value, " mIoU:", mIoU_value) new_pred = mIoU is_best = False if new_pred > self.best_pred: is_best = True self.best_pred = new_pred self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, )
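
# --- Illustrative sketch (added, not part of the original trainer) ----------
# scores_gzsl above reports generalized zero-shot metrics split into seen and
# unseen classes. The usual aggregation is the mean IoU over each subset plus
# their harmonic mean; the helper below sketches that computation from a
# per-class IoU vector and is an assumption about the metric definition, not a
# copy of the repository's scores_gzsl.
import numpy as np


def gzsl_summary(class_iou, seen_cls, unseen_cls):
    """Aggregate per-class IoU into seen / unseen mIoU and their harmonic mean.

    class_iou : array of IoU values indexed by class id
    seen_cls, unseen_cls : index arrays of seen and unseen class ids
    """
    miou_seen = np.nanmean(class_iou[seen_cls])
    miou_unseen = np.nanmean(class_iou[unseen_cls])
    if miou_seen + miou_unseen == 0:
        harmonic = 0.0
    else:
        harmonic = 2 * miou_seen * miou_unseen / (miou_seen + miou_unseen)
    return {
        "Mean IoU (seen)": miou_seen,
        "Mean IoU (unseen)": miou_unseen,
        "Harmonic mean": harmonic,
    }


# Usage (hypothetical Pascal VOC split, last 5 classes unseen):
#   gzsl_summary(class_iou, seen_cls=np.arange(16), unseen_cls=np.arange(16, 21))
# -----------------------------------------------------------------------------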
class Trainer(BaseTrainer): def __init__(self, args): self.args = args # Define Saver self.saver = Saver(args) self.saver.save_experiment_config() # Define Tensorboard Summary self.summary = TensorboardSummary(self.saver.experiment_dir) self.writer = self.summary.create_summary() """ Get dataLoader """ config = get_config(args.config) vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, visibility_mask, cls_map, cls_map_test = get_split( config) assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) dataset = get_dataset(config['DATAMODE'])( train=train, test=None, root=config['ROOT'], split=config['SPLIT']['TRAIN'], base_size=513, crop_size=config['IMAGE']['SIZE']['TRAIN'], mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'], config['IMAGE']['MEAN']['R']), warp=config['WARP_IMAGE'], scale=(0.5, 1.5), flip=True, visibility_mask=visibility_mask) print('train dataset:', len(dataset)) loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=config['BATCH_SIZE']['TRAIN'], num_workers=config['NUM_WORKERS'], sampler=sampler) dataset_test = get_dataset(config['DATAMODE'])( train=None, test=val, root=config['ROOT'], split=config['SPLIT']['TEST'], base_size=513, crop_size=config['IMAGE']['SIZE']['TEST'], mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'], config['IMAGE']['MEAN']['R']), warp=config['WARP_IMAGE'], scale=None, flip=False) print('test dataset:', len(dataset_test)) loader_test = torch.utils.data.DataLoader( dataset=dataset_test, batch_size=config['BATCH_SIZE']['TEST'], num_workers=config['NUM_WORKERS'], shuffle=False) self.train_loader = loader self.val_loader = loader_test self.nclass = 34 # Define Dataloader # kwargs = {"num_workers": args.workers, "pin_memory": True} # (self.train_loader, self.val_loader, _, self.nclass,) = make_data_loader( # args, **kwargs # ) # Define network model = DeepLab( num_classes=self.nclass, output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, pretrained=args.imagenet_pretrained, imagenet_pretrained_path=args.imagenet_pretrained_path, ) train_params = [ { "params": model.get_1x_lr_params(), "lr": args.lr }, { "params": model.get_10x_lr_params(), "lr": args.lr * 10 }, ] # Define Optimizer optimizer = torch.optim.SGD( train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov, ) # Define Criterion # whether to use class balanced weights if args.use_balanced_weights: classes_weights_path = ( DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy") if os.path.isfile(classes_weights_path): weight = np.load(classes_weights_path) else: weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) weight = torch.from_numpy(weight.astype(np.float32)) else: weight = None self.criterion = SegmentationLosses( weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) self.model, self.optimizer = model, optimizer # Define Evaluator self.evaluator = Evaluator(self.nclass) # Define lr scheduler self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) # Using cuda if args.cuda: self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) patch_replication_callback(self.model) self.model = self.model.cuda() # Resuming checkpoint self.best_pred = 0.0 if args.resume is not None: if not os.path.isfile(args.resume): raise RuntimeError( f"=> no checkpoint found at '{args.resume}'") checkpoint = torch.load(args.resume) args.start_epoch = 
checkpoint["epoch"] if args.cuda: self.model.module.load_state_dict(checkpoint["state_dict"]) else: self.model.load_state_dict(checkpoint["state_dict"]) if not args.ft: self.optimizer.load_state_dict(checkpoint["optimizer"]) self.best_pred = checkpoint["best_pred"] print( f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" ) # Clear start epoch if fine-tuning if args.ft: args.start_epoch = 0 def validation(self, epoch): self.model.eval() self.evaluator.reset() tbar = tqdm(self.val_loader, desc="\r") test_loss = 0.0 for i, sample in enumerate(tbar): # image, target = sample["image"], sample["label"] image, target = sample[0], sample[1] if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) loss = self.criterion(output, target) test_loss += loss.item() tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) pred = output.data.cpu().numpy() target = target.cpu().numpy() pred = np.argmax(pred, axis=1) # Add batch sample into evaluator self.evaluator.add_batch(target, pred) # Fast test during the training Acc = self.evaluator.Pixel_Accuracy() Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class() mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union() FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch) self.writer.add_scalar("val/mIoU", mIoU, epoch) self.writer.add_scalar("val/Acc", Acc, epoch) self.writer.add_scalar("val/Acc_class", Acc_class, epoch) self.writer.add_scalar("val/fwIoU", FWIoU, epoch) print("Validation:") print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") print(f"Loss: {test_loss:.3f}") for i, (class_name, acc_value, mIoU_value) in enumerate( zip(CLASSES_NAMES, Acc_class_by_class, mIoU_by_class)): self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) print(CLASSES_NAMES[i], "- acc:", acc_value, " mIoU:", mIoU_value) new_pred = mIoU is_best = True self.best_pred = new_pred self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, )
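
# --- Illustrative sketch (added, not part of the original trainer) ----------
# The validation loop above relies on Evaluator.add_batch followed by
# Pixel_Accuracy / Mean_Intersection_over_Union, i.e. standard confusion-matrix
# bookkeeping for segmentation. The minimal version below shows that pattern;
# the repository's Evaluator also splits metrics by seen/unseen classes and may
# handle the ignore label differently.
import numpy as np


class MiniEvaluator:
    """Confusion-matrix based segmentation metrics (sketch)."""

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion = np.zeros((num_class, num_class), dtype=np.int64)

    def add_batch(self, target, pred):
        # Keep only pixels whose label is a valid class (drops ignore=255)
        mask = (target >= 0) & (target < self.num_class)
        idx = self.num_class * target[mask].astype(int) + pred[mask]
        self.confusion += np.bincount(
            idx, minlength=self.num_class ** 2
        ).reshape(self.num_class, self.num_class)

    def pixel_accuracy(self):
        return np.diag(self.confusion).sum() / self.confusion.sum()

    def mean_iou(self):
        inter = np.diag(self.confusion)
        union = self.confusion.sum(1) + self.confusion.sum(0) - inter
        iou = inter / np.maximum(union, 1)
        return np.nanmean(iou), iou

    def reset(self):
        self.confusion[...] = 0


# Usage: ev.add_batch(target, pred) per batch, then ev.mean_iou() at epoch end.
# -----------------------------------------------------------------------------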
class Trainer: def __init__(self, args): self.args = args # Define Saver self.saver = Saver(args) self.saver.save_experiment_config() # Define Tensorboard Summary self.summary = TensorboardSummary(self.saver.experiment_dir) self.writer = self.summary.create_summary() """ Get dataLoader """ config = get_config(args.config) vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, visibility_mask, cls_map, cls_map_test = get_split( config) assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1) dataset = get_dataset(config['DATAMODE'])( train=train, test=None, root=config['ROOT'], split=config['SPLIT']['TRAIN'], base_size=513, crop_size=config['IMAGE']['SIZE']['TRAIN'], mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'], config['IMAGE']['MEAN']['R']), warp=config['WARP_IMAGE'], scale=(0.5, 1.5), flip=True, visibility_mask=visibility_mask) print('train dataset:', len(dataset)) loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=config['BATCH_SIZE']['TRAIN'], num_workers=config['NUM_WORKERS'], sampler=sampler) dataset_test = get_dataset(config['DATAMODE'])( train=None, test=val, root=config['ROOT'], split=config['SPLIT']['TEST'], base_size=513, crop_size=config['IMAGE']['SIZE']['TEST'], mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'], config['IMAGE']['MEAN']['R']), warp=config['WARP_IMAGE'], scale=None, flip=False) print('test dataset:', len(dataset_test)) loader_test = torch.utils.data.DataLoader( dataset=dataset_test, batch_size=config['BATCH_SIZE']['TEST'], num_workers=config['NUM_WORKERS'], shuffle=False) self.train_loader = loader self.val_loader = loader_test self.nclass = 21 # Define Dataloader kwargs = {"num_workers": args.workers, "pin_memory": True} ( self.train_loader, self.val_loader, _, self.nclass, ) = make_data_loader(args, load_embedding=args.load_embedding, w2c_size=args.w2c_size, **kwargs) print('self.nclass', self.nclass) # Define network model = DeepLab( num_classes=self.nclass, output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, global_avg_pool_bn=args.global_avg_pool_bn, imagenet_pretrained_path=args.imagenet_pretrained_path, ) train_params = [ { "params": model.get_1x_lr_params(), "lr": args.lr }, { "params": model.get_10x_lr_params(), "lr": args.lr * 10 }, ] # Define Optimizer optimizer = torch.optim.SGD( train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov, ) # Define Generator generator = GMMNnetwork(args.noise_dim, args.embed_dim, args.hidden_size, args.feature_dim) optimizer_generator = torch.optim.Adam(generator.parameters(), lr=args.lr_generator) class_weight = torch.ones(self.nclass) class_weight[args.unseen_classes_idx_metric] = args.unseen_weight if args.cuda: class_weight = class_weight.cuda() self.criterion = SegmentationLosses( weight=class_weight, cuda=args.cuda).build_loss(mode=args.loss_type) self.model, self.optimizer = model, optimizer self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80], cuda=args.cuda).build_loss() self.generator, self.optimizer_generator = generator, optimizer_generator # Define Evaluator self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric, args.unseen_classes_idx_metric) # Define lr scheduler self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) # Using cuda if args.cuda: self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) patch_replication_callback(self.model) self.model = 
self.model.cuda() self.generator = self.generator.cuda() # Resuming checkpoint self.best_pred = 0.0 if args.resume is not None: if not os.path.isfile(args.resume): raise RuntimeError( f"=> no checkpoint found at '{args.resume}'") checkpoint = torch.load(args.resume) # args.start_epoch = checkpoint['epoch'] if args.random_last_layer: checkpoint["state_dict"][ "decoder.pred_conv.weight"] = torch.rand(( self.nclass, checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[1], checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[2], checkpoint["state_dict"] ["decoder.pred_conv.weight"].shape[3], )) checkpoint["state_dict"][ "decoder.pred_conv.bias"] = torch.rand(self.nclass) if args.cuda: self.model.module.load_state_dict(checkpoint["state_dict"]) else: self.model.load_state_dict(checkpoint["state_dict"]) # self.best_pred = checkpoint['best_pred'] print( f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" ) # Clear start epoch if fine-tuning if args.ft: args.start_epoch = 0 def training(self, epoch, args): train_loss = 0.0 self.model.train() tbar = tqdm(self.train_loader) num_img_tr = len(self.train_loader) for i, sample in enumerate(tbar): if len(sample["image"]) > 1: image, target, embedding = ( sample["image"], sample["label"], sample["label_emb"], ) if self.args.cuda: image, target, embedding = ( image.cuda(), target.cuda(), embedding.cuda(), ) self.scheduler(self.optimizer, i, epoch, self.best_pred) # ===================real feature extraction===================== with torch.no_grad(): real_features = self.model.module.forward_before_class_prediction( image) # ===================fake feature generation===================== fake_features = torch.zeros(real_features.shape) if args.cuda: fake_features = fake_features.cuda() generator_loss_batch = 0.0 for ( count_sample_i, (real_features_i, target_i, embedding_i), ) in enumerate(zip(real_features, target, embedding)): generator_loss_sample = 0.0 ## reduce to real feature size real_features_i = (real_features_i.permute( 1, 2, 0).contiguous().view((-1, args.feature_dim))) target_i = nn.functional.interpolate( target_i.view(1, 1, target_i.shape[0], target_i.shape[1]), size=(real_features.shape[2], real_features.shape[3]), mode="nearest", ).view(-1) embedding_i = nn.functional.interpolate( embedding_i.view( 1, embedding_i.shape[0], embedding_i.shape[1], embedding_i.shape[2], ), size=(real_features.shape[2], real_features.shape[3]), mode="nearest", ) embedding_i = (embedding_i.permute(0, 2, 3, 1).contiguous().view( (-1, args.embed_dim))) fake_features_i = torch.zeros(real_features_i.shape) if args.cuda: fake_features_i = fake_features_i.cuda() unique_class = torch.unique(target_i) ## test if image has unseen class pixel, if yes means no training for generator and generated features for the whole image has_unseen_class = False for u_class in unique_class: if u_class in args.unseen_classes_idx_metric: has_unseen_class = True for idx_in in unique_class: if idx_in != 255: self.optimizer_generator.zero_grad() idx_class = target_i == idx_in real_features_class = real_features_i[idx_class] embedding_class = embedding_i[idx_class] z = torch.rand( (embedding_class.shape[0], args.noise_dim)) if args.cuda: z = z.cuda() fake_features_class = self.generator( embedding_class, z.float()) if (idx_in in args.seen_classes_idx_metric and not has_unseen_class): ## in order to avoid CUDA out of memory random_idx = torch.randint( low=0, high=fake_features_class.shape[0], size=(args.batch_size_generator, ), ) g_loss = 
self.criterion_generator( fake_features_class[random_idx], real_features_class[random_idx], ) generator_loss_sample += g_loss.item() g_loss.backward() self.optimizer_generator.step() fake_features_i[ idx_class] = fake_features_class.clone() generator_loss_batch += generator_loss_sample / len( unique_class) if args.real_seen_features and not has_unseen_class: fake_features[count_sample_i] = real_features_i.view(( fake_features.shape[2], fake_features.shape[3], args.feature_dim, )).permute(2, 0, 1) else: fake_features[count_sample_i] = fake_features_i.view(( fake_features.shape[2], fake_features.shape[3], args.feature_dim, )).permute(2, 0, 1) # ===================classification===================== self.optimizer.zero_grad() output = self.model.module.forward_class_prediction( fake_features.detach(), image.size()[2:]) loss = self.criterion(output, target) loss.backward() self.optimizer.step() train_loss += loss.item() # ===================log===================== tbar.set_description(f" G loss: {generator_loss_batch:.3f}" + " C loss: %.3f" % (train_loss / (i + 1))) self.writer.add_scalar("train/total_loss_iter", loss.item(), i + num_img_tr * epoch) self.writer.add_scalar("train/generator_loss", generator_loss_batch, i + num_img_tr * epoch) # Show 10 * 3 inference results each epoch if i % (num_img_tr // 10) == 0: global_step = i + num_img_tr * epoch self.summary.visualize_image( self.writer, self.args.dataset, image, target, output, global_step, ) self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch) print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Loss: {train_loss:.3f}") if self.args.no_val: # save checkpoint every epoch is_best = False self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, ) def validation(self, epoch, args): class_names = [ "background", # class 0 "aeroplane", # class 1 "bicycle", # class 2 "bird", # class 3 "boat", # class 4 "bottle", # class 5 "bus", # class 6 "car", # class 7 "cat", # class 8 "chair", # class 9 "cow", # class 10 "diningtable", # class 11 "dog", # class 12 "horse", # class 13 "motorbike", # class 14 "person", # class 15 "potted plant", # class 16 "sheep", # class 17 "sofa", # class 18 "train", # class 19 "tv/monitor", # class 20 ] self.model.eval() self.evaluator.reset() tbar = tqdm(self.val_loader, desc="\r") test_loss = 0.0 saved_images = {} saved_target = {} saved_prediction = {} for idx_unseen_class in args.unseen_classes_idx_metric: saved_images[idx_unseen_class] = [] saved_target[idx_unseen_class] = [] saved_prediction[idx_unseen_class] = [] for i, sample in enumerate(tbar): image, target, embedding = ( sample["image"], sample["label"], sample["label_emb"], ) if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) loss = self.criterion(output, target) test_loss += loss.item() tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) ## save image for tensorboard for idx_unseen_class in args.unseen_classes_idx_metric: if len((target.reshape(-1) == idx_unseen_class).nonzero()) > 0: if len(saved_images[idx_unseen_class] ) < args.saved_validation_images: saved_images[idx_unseen_class].append( image.clone().cpu()) saved_target[idx_unseen_class].append( target.clone().cpu()) saved_prediction[idx_unseen_class].append( output.clone().cpu()) pred = output.data.cpu().numpy() target = target.cpu().numpy() 
pred = np.argmax(pred, axis=1) # Add batch sample into evaluator self.evaluator.add_batch(target, pred) # Fast test during the training Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy() ( Acc_class, Acc_class_by_class, Acc_class_seen, Acc_class_unseen, ) = self.evaluator.Pixel_Accuracy_Class() ( mIoU, mIoU_by_class, mIoU_seen, mIoU_unseen, ) = self.evaluator.Mean_Intersection_over_Union() ( FWIoU, FWIoU_seen, FWIoU_unseen, ) = self.evaluator.Frequency_Weighted_Intersection_over_Union() self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch) self.writer.add_scalar("val_overall/mIoU", mIoU, epoch) self.writer.add_scalar("val_overall/Acc", Acc, epoch) self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch) self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch) self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch) self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch) self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch) self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch) self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch) self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch) self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch) self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch) print("Validation:") print("[Epoch: %d, numImages: %5d]" % (epoch, i * self.args.batch_size + image.data.shape[0])) print(f"Loss: {test_loss:.3f}") print( f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}" ) print("Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen)) print("Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen)) for class_name, acc_value, mIoU_value in zip(class_names, Acc_class_by_class, mIoU_by_class): self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value) new_pred = mIoU_unseen is_best = True self.best_pred = new_pred self.saver.save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, is_best, generator_state={ "epoch": epoch + 1, "state_dict": self.generator.state_dict(), "optimizer": self.optimizer.state_dict(), "best_pred": self.best_pred, }, ) global_step = epoch + 1 for idx_unseen_class in args.unseen_classes_idx_metric: if len(saved_images[idx_unseen_class]) > 0: nb_image = len(saved_images[idx_unseen_class]) if nb_image > args.saved_validation_images: nb_image = args.saved_validation_images for i in range(nb_image): self.summary.visualize_image_validation( self.writer, self.args.dataset, saved_images[idx_unseen_class][i], saved_target[idx_unseen_class][i], saved_prediction[idx_unseen_class][i], global_step, name="validation_" + class_names[idx_unseen_class] + "_" + str(i), nb_image=1, ) self.evaluator.reset()
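
# --- Illustrative sketch (added, not part of the original trainer) ----------
# GMMNnetwork(args.noise_dim, args.embed_dim, args.hidden_size, args.feature_dim)
# above maps a per-pixel semantic embedding plus a noise vector to a fake
# DeepLab feature and is called as self.generator(embedding_class, z). A
# plausible minimal version is a small MLP over the concatenated
# [embedding, noise] input, as sketched below; the layer sizes and activations
# are assumptions, only the input/output shapes follow from the training loop.
import torch
import torch.nn as nn


class ToyGMMNGenerator(nn.Module):
    """Sketch of a GMMN-style feature generator: (embedding, noise) -> feature."""

    def __init__(self, noise_dim, embed_dim, hidden_size, feature_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim + noise_dim, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, feature_dim),
            nn.ReLU(inplace=True),  # backbone features are post-ReLU, hence >= 0
        )

    def forward(self, embedding, noise):
        # One row per pixel of the target class: concatenate semantic embedding
        # and noise, then project to the feature dimension
        return self.net(torch.cat([embedding, noise], dim=1))


# Usage (mirrors the per-class loop in training(); dimensions are hypothetical):
#   gen = ToyGMMNGenerator(noise_dim=300, embed_dim=300, hidden_size=256, feature_dim=256)
#   fake_features_class = gen(embedding_class, z.float())
# -----------------------------------------------------------------------------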