예제 #1
0
class Trainer:
    """Set up and validate a DeepLab segmentation model for one experiment.

    ``__init__`` builds the saver, TensorBoard writer, data loaders, model,
    SGD optimizer, segmentation loss, evaluator and LR scheduler from
    ``args`` and optionally resumes from a checkpoint. ``validation`` runs one
    pass over the validation loader and logs overall / seen / unseen metrics.
    """

    def __init__(self, args):
        """Build every training component from the parsed options ``args``.

        Args:
            args: parsed experiment options; ``args.start_epoch`` is written
                when resuming or fine-tuning.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader (the third returned value is not used here)
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )
        # Two parameter groups: one at the base LR, one at 10x the base LR.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion, optionally with class-balanced weights that are
        # read from a cached .npy file or computed from the training loader.
        if args.use_balanced_weights:
            classes_weights_path = self._classes_weights_path(
                DATASETS_DIRS[args.dataset], args.dataset)
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric,
                                   args.unseen_classes_idx_metric)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            state_dict = checkpoint["state_dict"]

            if args.random_last_layer:
                # Replace the final prediction conv with random values sized
                # for self.nclass outputs, keeping the checkpoint's other dims.
                pred_shape = state_dict["decoder.pred_conv.weight"].shape
                state_dict["decoder.pred_conv.weight"] = torch.rand((
                    self.nclass,
                    pred_shape[1],
                    pred_shape[2],
                    pred_shape[3],
                ))
                state_dict["decoder.pred_conv.bias"] = torch.rand(self.nclass)

            # With CUDA the network is wrapped by DataParallel (``.module``);
            # with a non-linear last layer the weights belong to ``.deeplab``.
            net = self.model.module if args.cuda else self.model
            if args.nonlinear_last_layer:
                net = net.deeplab
            net.load_state_dict(state_dict)

            # Optimizer state is restored only when not fine-tuning and when
            # the linear last layer is used.
            if not args.ft and not args.nonlinear_last_layer:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _classes_weights_path(dataset_dir, dataset):
        """Return the cached class-weights file path for ``dataset``.

        The file name must be joined as a whole: the former expression
        ``dataset_dir / dataset + "_classes_weights.npy"`` evaluated the
        ``/`` first and then tried ``Path + str`` (a TypeError).
        ``os.path.join`` accepts both ``str`` and ``pathlib.Path`` directories.
        """
        return os.path.join(dataset_dir, dataset + "_classes_weights.npy")

    def validation(self, epoch, args):
        """Evaluate on the validation loader and log metrics for ``epoch``.

        Logs the summed loss and overall / seen / unseen pixel accuracy,
        class accuracy, mIoU and frequency-weighted IoU to TensorBoard and
        stdout, plus per-class accuracy and mIoU for each ``CLASSES_NAMES``.

        Args:
            epoch: epoch index used as the TensorBoard step.
            args: parsed options; ``args.nonlinear_last_layer`` selects the
                model call signature.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        # Counted explicitly so the summary does not depend on leaked loop
        # variables (which do not exist for an empty loader).
        num_images = 0
        for batch_idx, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                if args.nonlinear_last_layer:
                    output = self.model(image, image.size()[2:])
                else:
                    output = self.model(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description("Test loss: %.3f" %
                                 (test_loss / (batch_idx + 1)))
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy()
        (
            Acc_class,
            Acc_class_by_class,
            Acc_class_seen,
            Acc_class_unseen,
        ) = self.evaluator.Pixel_Accuracy_Class()
        (
            mIoU,
            mIoU_by_class,
            mIoU_seen,
            mIoU_unseen,
        ) = self.evaluator.Mean_Intersection_over_Union()
        (
            FWIoU,
            FWIoU_seen,
            FWIoU_unseen,
        ) = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val_overall/total_loss_epoch", test_loss,
                               epoch)
        self.writer.add_scalar("val_overall/mIoU", mIoU, epoch)
        self.writer.add_scalar("val_overall/Acc", Acc, epoch)
        self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch)

        self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch)
        self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch)
        self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch)
        self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch)

        self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch)
        self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch)

        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Loss: {test_loss:.3f}")
        print(
            f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}"
        )
        print("Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen))
        print("Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen))

        for class_name, acc_value, mIoU_value in zip(CLASSES_NAMES,
                                                     Acc_class_by_class,
                                                     mIoU_by_class):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value,
                                   epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value,
                                   epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)
예제 #2
0
class Trainer:
    """Train a GMMN feature generator together with DeepLab's classifier.

    The DeepLab network extracts real features. A ``GMMNnetwork`` generator
    maps label embeddings plus noise to fake features and is trained with
    ``GMMNLoss`` on seen classes. The class-prediction part of the network
    (``forward_class_prediction``) is then trained on those features.
    ``validation`` reports ``scores_gzsl`` results and ``Evaluator`` metrics
    and saves a checkpoint of both the network and the generator.
    """

    def __init__(self, args):
        """Build loaders, DeepLab model, generator, optimizers and losses.

        Args:
            args: parsed experiment options; ``args.start_epoch`` is reset to
                0 when fine-tuning.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader. The seen/unseen split used for the GZSL scores
        # is loaded separately in ``validation`` (get_config / get_split).
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args,
                             load_embedding=args.load_embedding,
                             w2c_size=args.w2c_size,
                             **kwargs)
        print('self.nclass', self.nclass)

        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            global_avg_pool_bn=args.global_avg_pool_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )

        # Two parameter groups: one at the base LR, one at 10x the base LR.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Generator (trained by its own Adam optimizer)
        generator = GMMNnetwork(args.noise_dim, args.embed_dim,
                                args.hidden_size, args.feature_dim)
        optimizer_generator = torch.optim.Adam(generator.parameters(),
                                               lr=args.lr_generator)

        # Unseen classes weigh args.unseen_weight in the segmentation loss;
        # every other class weighs 1.
        class_weight = torch.ones(self.nclass)
        class_weight[args.unseen_classes_idx_metric] = args.unseen_weight
        if args.cuda:
            class_weight = class_weight.cuda()

        self.criterion = SegmentationLosses(
            weight=class_weight,
            cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80],
                                            cuda=args.cuda).build_loss()
        self.generator, self.optimizer_generator = generator, optimizer_generator

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric,
                                   args.unseen_classes_idx_metric)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.generator = self.generator.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            # Only network weights are restored; the checkpoint's epoch,
            # optimizer state and best_pred are not loaded here.
            state_dict = checkpoint["state_dict"]

            if args.random_last_layer:
                # Replace the final prediction conv with random values sized
                # for self.nclass outputs, keeping the checkpoint's other dims.
                pred_shape = state_dict["decoder.pred_conv.weight"].shape
                state_dict["decoder.pred_conv.weight"] = torch.rand((
                    self.nclass,
                    pred_shape[1],
                    pred_shape[2],
                    pred_shape[3],
                ))
                state_dict["decoder.pred_conv.bias"] = torch.rand(self.nclass)

            net = self.model.module if args.cuda else self.model
            net.load_state_dict(state_dict)

            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _bare_model(self):
        """Return the underlying network, unwrapping ``DataParallel``'s
        ``.module`` when present (the non-CUDA path keeps the bare model)."""
        return getattr(self.model, "module", self.model)

    @staticmethod
    def _visualization_interval(num_batches):
        """Return the batch stride used to log ~10 image grids per epoch.

        Clamped to at least 1: ``num_batches // 10`` is 0 for loaders with
        fewer than 10 batches, which made ``i % 0`` raise ZeroDivisionError.
        """
        return max(num_batches // 10, 1)

    def training(self, epoch, args):
        """Run one training epoch of the generator and the classifier.

        For every batch with more than one image:
          1. real features are extracted with
             ``forward_before_class_prediction`` under ``torch.no_grad``;
          2. for each image and each class (except label 255) in its label
             map resized to the feature resolution, fake features are
             generated from the class embeddings plus uniform noise; when the
             class is seen and the image has no unseen-class pixel, the
             generator takes one GMMN step on a random pixel subset;
          3. ``forward_class_prediction`` is trained on the detached fake
             features, or on the real ones when ``args.real_seen_features``
             is set and the image has no unseen class.

        Args:
            epoch: epoch index used for the LR schedule and TensorBoard steps.
            args: parsed experiment options.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        net = self._bare_model()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._visualization_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            # Single-image batches are skipped entirely.
            if len(sample["image"]) <= 1:
                continue
            image, target, embedding = (
                sample["image"],
                sample["label"],
                sample["label_emb"],
            )
            if self.args.cuda:
                image, target, embedding = (
                    image.cuda(),
                    target.cuda(),
                    embedding.cuda(),
                )
            num_images += image.shape[0]
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            # ===================real feature extraction=====================
            with torch.no_grad():
                real_features = net.forward_before_class_prediction(image)
            feat_h, feat_w = real_features.shape[2], real_features.shape[3]

            # ===================fake feature generation=====================
            fake_features = torch.zeros(real_features.shape)
            if args.cuda:
                fake_features = fake_features.cuda()
            generator_loss_batch = 0.0
            samples = zip(real_features, target, embedding)
            for count_sample_i, (real_features_i, target_i,
                                 embedding_i) in enumerate(samples):
                generator_loss_sample = 0.0
                # Flatten this image's features to (pixels, feature_dim) and
                # resize its labels / embeddings to the feature resolution.
                real_features_i = (real_features_i.permute(
                    1, 2, 0).contiguous().view((-1, args.feature_dim)))
                target_i = nn.functional.interpolate(
                    target_i.view(1, 1, target_i.shape[0],
                                  target_i.shape[1]),
                    size=(feat_h, feat_w),
                    mode="nearest",
                ).view(-1)
                embedding_i = nn.functional.interpolate(
                    embedding_i.view(
                        1,
                        embedding_i.shape[0],
                        embedding_i.shape[1],
                        embedding_i.shape[2],
                    ),
                    size=(feat_h, feat_w),
                    mode="nearest",
                )
                embedding_i = (embedding_i.permute(
                    0, 2, 3, 1).contiguous().view((-1, args.embed_dim)))

                fake_features_i = torch.zeros(real_features_i.shape)
                if args.cuda:
                    fake_features_i = fake_features_i.cuda()

                unique_class = torch.unique(target_i)

                # An image containing any unseen-class pixel does not train
                # the generator, and its classifier input stays generated.
                has_unseen_class = any(
                    u_class in args.unseen_classes_idx_metric
                    for u_class in unique_class)

                for idx_in in unique_class:
                    if idx_in == 255:  # label 255 is skipped (ignore label?)
                        continue
                    self.optimizer_generator.zero_grad()
                    idx_class = target_i == idx_in
                    real_features_class = real_features_i[idx_class]
                    embedding_class = embedding_i[idx_class]

                    z = torch.rand((embedding_class.shape[0], args.noise_dim))
                    if args.cuda:
                        z = z.cuda()

                    fake_features_class = self.generator(
                        embedding_class, z.float())

                    if (idx_in in args.seen_classes_idx_metric
                            and not has_unseen_class):
                        # Subsample pixels to avoid CUDA out of memory.
                        random_idx = torch.randint(
                            low=0,
                            high=fake_features_class.shape[0],
                            size=(args.batch_size_generator, ),
                        )
                        g_loss = self.criterion_generator(
                            fake_features_class[random_idx],
                            real_features_class[random_idx],
                        )
                        generator_loss_sample += g_loss.item()
                        g_loss.backward()
                        self.optimizer_generator.step()

                    # The classifier consumes these features detached, so
                    # keeping the generator's autograd graph only costs memory.
                    fake_features_i[idx_class] = fake_features_class.detach()
                generator_loss_batch += generator_loss_sample / len(
                    unique_class)
                if args.real_seen_features and not has_unseen_class:
                    chosen = real_features_i
                else:
                    chosen = fake_features_i
                fake_features[count_sample_i] = chosen.view(
                    (feat_h, feat_w, args.feature_dim)).permute(2, 0, 1)
            # ===================classification=====================
            self.optimizer.zero_grad()
            output = net.forward_class_prediction(fake_features.detach(),
                                                  image.size()[2:])
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # ===================log=====================
            global_step = i + num_img_tr * epoch
            tbar.set_description(f" G loss: {generator_loss_batch:.3f}" +
                                 " C loss: %.3f" % (train_loss / (i + 1)))
            self.writer.add_scalar("train/total_loss_iter", loss.item(),
                                   global_step)
            self.writer.add_scalar("train/generator_loss",
                                   generator_loss_batch, global_step)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                self.summary.visualize_image(
                    self.writer,
                    self.args.dataset,
                    image,
                    target,
                    output,
                    global_step,
                )

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Loss: {train_loss:.3f}")

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": net.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch, args):
        """Evaluate on the validation loader, log metrics, save checkpoints.

        Computes ``scores_gzsl`` over all per-image predictions (printed and
        written via ``logWritter`` to ``./logs_context_step_2_GMMN.txt``) and
        the ``Evaluator`` overall / seen / unseen metrics (TensorBoard and
        stdout). Saves network and generator checkpoints, then logs up to
        ``args.saved_validation_images`` batches per unseen class that
        contain that class to TensorBoard.

        Note that every per-image prediction and target of the validation set
        is kept in memory for ``scores_gzsl``.

        Args:
            epoch: epoch index used as the TensorBoard step.
            args: parsed experiment options.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        num_images = 0

        unseen_idx = args.unseen_classes_idx_metric
        saved_images = {idx: [] for idx in unseen_idx}
        saved_target = {idx: [] for idx in unseen_idx}
        saved_prediction = {idx: [] for idx in unseen_idx}

        targets, outputs = [], []
        log_file = './logs_context_step_2_GMMN.txt'
        logger = logWritter(log_file)

        # batch_idx (not ``i``) so later loops cannot clobber it.
        for batch_idx, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description("Test loss: %.3f" %
                                 (test_loss / (batch_idx + 1)))
            # Keep whole batches that contain an unseen class for TensorBoard.
            for idx_unseen_class in unseen_idx:
                if ((target == idx_unseen_class).any()
                        and len(saved_images[idx_unseen_class]) <
                        args.saved_validation_images):
                    saved_images[idx_unseen_class].append(
                        image.clone().cpu())
                    saved_target[idx_unseen_class].append(
                        target.clone().cpu())
                    saved_prediction[idx_unseen_class].append(
                        output.clone().cpu())

            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy().astype(np.int64)
            outputs.extend(pred)
            targets.extend(target)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # ===================GZSL scores=====================
        config = get_config(args.config)
        (vals_cls, valu_cls, all_labels, _visible_classes,
         visible_classes_test, _train, _val, _sampler, _, _cls_map,
         cls_map_test) = get_split(config)
        assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] -
                1)
        score, class_iou = scores_gzsl(targets,
                                       outputs,
                                       n_class=len(visible_classes_test),
                                       seen_cls=cls_map_test[vals_cls],
                                       unseen_cls=cls_map_test[valu_cls])

        print("Test results:")
        logger.write("Test results:")

        for k, v in score.items():
            print(k + ': ' + json.dumps(v))
            logger.write(k + ': ' + json.dumps(v))

        # class_iou is ordered like the sorted visible test classes.
        score["Class IoU"] = {
            all_labels[cls_idx]: class_iou[pos]
            for pos, cls_idx in enumerate(sorted(visible_classes_test))
        }
        print("Class IoU: " + json.dumps(score["Class IoU"]))
        logger.write("Class IoU: " + json.dumps(score["Class IoU"]))

        print("Test finished.\n\n")
        logger.write("Test finished.\n\n")

        # Fast test during the training
        Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy()
        (
            Acc_class,
            Acc_class_by_class,
            Acc_class_seen,
            Acc_class_unseen,
        ) = self.evaluator.Pixel_Accuracy_Class()
        (
            mIoU,
            mIoU_by_class,
            mIoU_seen,
            mIoU_unseen,
        ) = self.evaluator.Mean_Intersection_over_Union()
        (
            FWIoU,
            FWIoU_seen,
            FWIoU_unseen,
        ) = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val_overall/total_loss_epoch", test_loss,
                               epoch)
        self.writer.add_scalar("val_overall/mIoU", mIoU, epoch)
        self.writer.add_scalar("val_overall/Acc", Acc, epoch)
        self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch)

        self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch)
        self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch)
        self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch)
        self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch)

        self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch)
        self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch)
        self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch)

        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Loss: {test_loss:.3f}")
        print(
            f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}"
        )
        print("Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen))
        print("Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen))

        for class_name, acc_value, mIoU_value in zip(CLASSES_NAMES,
                                                     Acc_class_by_class,
                                                     mIoU_by_class):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value,
                                   epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value,
                                   epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)

        new_pred = mIoU_unseen

        # NOTE(review): best_pred is overwritten and is_best is True on every
        # call, whether or not mIoU_unseen improved — confirm this is intended.
        is_best = True
        self.best_pred = new_pred
        self.saver.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": self._bare_model().state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_pred": self.best_pred,
            },
            is_best,
            generator_state={
                "epoch": epoch + 1,
                "state_dict": self.generator.state_dict(),
                # The generator is updated by optimizer_generator (Adam), not
                # by the segmentation SGD optimizer.
                "optimizer": self.optimizer_generator.state_dict(),
                "best_pred": self.best_pred,
            },
        )

        global_step = epoch + 1
        for idx_unseen_class in unseen_idx:
            nb_image = min(len(saved_images[idx_unseen_class]),
                           args.saved_validation_images)
            for img_idx in range(nb_image):
                self.summary.visualize_image_validation(
                    self.writer,
                    self.args.dataset,
                    saved_images[idx_unseen_class][img_idx],
                    saved_target[idx_unseen_class][img_idx],
                    saved_prediction[idx_unseen_class][img_idx],
                    global_step,
                    name="validation_" + CLASSES_NAMES[idx_unseen_class] +
                    "_" + str(img_idx),
                    nb_image=1,
                )

        self.evaluator.reset()
예제 #3
0
class Trainer(BaseTrainer):
    """DeepLab trainer whose validation also reports GZSL (seen/unseen) scores.

    ``__init__`` builds the data loaders, model, optimizer, loss, evaluator and
    LR scheduler from ``args`` and optionally resumes from ``args.resume``.
    ``validation`` evaluates on ``self.val_loader``, logs metrics and saves a
    checkpoint.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Get dataLoader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args, **kwargs)
        print('self.nclass', self.nclass)

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=False,
            pretrained=args.imagenet_pretrained,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )

        # Two parameter groups: the model's "1x" parameters at the base LR and
        # its "10x" parameters at ten times the base LR.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # Build the file name first, then join it onto the dataset dir.
            # Without the parentheses, ``dir / dataset + suffix`` evaluates
            # ``dir / dataset`` first and then tries to add a str to it.
            classes_weights_path = DATASETS_DIRS[args.dataset] / (
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        if args.imagenet_pretrained_path is not None:
            state_dict = torch.load(args.imagenet_pretrained_path)
            if 'state_dict' in state_dict:
                self.model.load_state_dict(state_dict['state_dict'])
            else:
                # Drop a fixed 11-character prefix from every key so the names
                # line up with this model's parameters.
                # NOTE(review): the prefix itself is not visible here; confirm
                # it against the checkpoint's key names.
                new_dict = {k[11:]: v for k, v in state_dict.items()}
                # strict=False silently skips unmatched keys; switch to
                # strict=True to debug whether the checkpoint is really loaded
                # if performance is low.
                self.model.load_state_dict(new_dict, strict=False)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def validation(self, epoch, args):
        """Evaluate on ``self.val_loader``, log metrics and save a checkpoint.

        Besides the ``Evaluator`` metrics, this computes GZSL scores with
        ``scores_gzsl`` for the seen/unseen split described by ``args.config``
        and appends them to ``./logs_voc12_step_1.txt`` via ``logWritter``.

        Args:
            epoch: current epoch; used as the TensorBoard step and stored in
                the checkpoint as ``epoch + 1``.
            args: run arguments; ``args.config`` selects the class split.
        """
        class_names = CLASSES_NAMES[:20]
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        # Counted explicitly: deriving it from the loop index breaks when the
        # last batch is smaller or when ``i`` is reused further down.
        num_images = 0
        torch.set_printoptions(profile="full")
        targets, outputs = [], []
        log_file = './logs_voc12_step_1.txt'
        logger = logWritter(log_file)
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            # Put the resized target on the same device as the logits rather
            # than forcing ``.cuda()``, so validation also works on CPU.
            target = resize_target(target,
                                   s=output.size()[2:]).to(output.device)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Keep per-image predictions/targets for the GZSL scores below.
            for o, t in zip(pred, target):
                outputs.append(o)
                targets.append(t)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        config = get_config(args.config)
        vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, _, cls_map, cls_map_test = get_split(
            config)
        assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] -
                1)
        score, class_iou = scores_gzsl(targets,
                                       outputs,
                                       n_class=len(visible_classes_test),
                                       seen_cls=cls_map_test[vals_cls],
                                       unseen_cls=cls_map_test[valu_cls])

        print("Test results:")
        logger.write("Test results:")

        for k, v in score.items():
            print(k + ': ' + json.dumps(v))
            logger.write(k + ': ' + json.dumps(v))

        # Map each visible test class to its label name.
        score["Class IoU"] = {}
        for cls_pos in range(len(visible_classes_test)):
            score["Class IoU"][all_labels[
                visible_classes_test[cls_pos]]] = class_iou[cls_pos]
        print("Class IoU: " + json.dumps(score["Class IoU"]))
        logger.write("Class IoU: " + json.dumps(score["Class IoU"]))

        print("Test finished.\n\n")
        logger.write("Test finished.\n\n")

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}")
        print(f"Loss: {test_loss:.3f}")

        for class_name, acc_value, mIoU_value in zip(class_names,
                                                     Acc_class_by_class,
                                                     mIoU_by_class):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value,
                                   epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value,
                                   epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)

        new_pred = mIoU
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Unwrap DataParallel when present so saving also works without CUDA.
        network = getattr(self.model, "module", self.model)
        self.saver.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": network.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_pred": self.best_pred,
            },
            is_best,
        )
예제 #4
0
class Trainer(BaseTrainer):
    """DeepLab trainer whose data loaders are built directly from a split config.

    ``__init__`` reads ``args.config`` to construct the train/test datasets and
    loaders, then sets up the model, optimizer, loss, evaluator and LR
    scheduler, optionally resuming from ``args.resume``. ``validation``
    evaluates on the test loader, logs metrics and saves a checkpoint.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Get dataLoader (built from the split described by args.config)
        config = get_config(args.config)
        vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, visibility_mask, cls_map, cls_map_test = get_split(
            config)
        assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] -
                1)

        dataset = get_dataset(config['DATAMODE'])(
            train=train,
            test=None,
            root=config['ROOT'],
            split=config['SPLIT']['TRAIN'],
            base_size=513,
            crop_size=config['IMAGE']['SIZE']['TRAIN'],
            mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
                  config['IMAGE']['MEAN']['R']),
            warp=config['WARP_IMAGE'],
            scale=(0.5, 1.5),
            flip=True,
            visibility_mask=visibility_mask)
        print('train dataset:', len(dataset))

        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['BATCH_SIZE']['TRAIN'],
            num_workers=config['NUM_WORKERS'],
            sampler=sampler)

        dataset_test = get_dataset(config['DATAMODE'])(
            train=None,
            test=val,
            root=config['ROOT'],
            split=config['SPLIT']['TEST'],
            base_size=513,
            crop_size=config['IMAGE']['SIZE']['TEST'],
            mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
                  config['IMAGE']['MEAN']['R']),
            warp=config['WARP_IMAGE'],
            scale=None,
            flip=False)
        print('test dataset:', len(dataset_test))

        loader_test = torch.utils.data.DataLoader(
            dataset=dataset_test,
            batch_size=config['BATCH_SIZE']['TEST'],
            num_workers=config['NUM_WORKERS'],
            shuffle=False)

        self.train_loader = loader
        self.val_loader = loader_test
        # Hard-coded number of output classes.
        # NOTE(review): confirm it matches the dataset selected by args.config.
        self.nclass = 34

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            pretrained=args.imagenet_pretrained,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )

        # Two parameter groups: the model's "1x" parameters at the base LR and
        # its "10x" parameters at ten times the base LR.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # Build the file name first, then join it onto the dataset dir.
            # Without the parentheses, ``dir / dataset + suffix`` evaluates
            # ``dir / dataset`` first and then tries to add a str to it.
            classes_weights_path = DATASETS_DIRS[args.dataset] / (
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def validation(self, epoch):
        """Evaluate on ``self.val_loader``, log metrics and save a checkpoint.

        Args:
            epoch: current epoch; used as the TensorBoard step and stored in
                the checkpoint as ``epoch + 1``.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        # Counted explicitly: the val loader's batch size comes from the
        # config, not from args.batch_size, so it cannot be derived from ``i``.
        num_images = 0
        for i, sample in enumerate(tbar):
            # Samples are indexed positionally: [0] image, [1] label.
            image, target = sample[0], sample[1]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}")
        print(f"Loss: {test_loss:.3f}")

        for class_name, acc_value, mIoU_value in zip(CLASSES_NAMES,
                                                     Acc_class_by_class,
                                                     mIoU_by_class):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value,
                                   epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value,
                                   epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)

        # NOTE(review): every validation run is saved as "best", and best_pred
        # tracks the latest mIoU rather than the maximum; confirm this is
        # intended.
        new_pred = mIoU
        is_best = True
        self.best_pred = new_pred
        # Unwrap DataParallel when present so saving also works without CUDA.
        network = getattr(self.model, "module", self.model)
        self.saver.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": network.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_pred": self.best_pred,
            },
            is_best,
        )
예제 #5
0
class Trainer:
    def __init__(self, args):
        """Set up data, the DeepLab model, the GMMN generator, losses and optimizers.

        Args:
            args: run arguments. Besides the usual DeepLab/optimizer options,
                this reads the generator sizes (``noise_dim``, ``embed_dim``,
                ``hidden_size``, ``feature_dim``), ``lr_generator``, the
                seen/unseen class index lists and ``unseen_weight``.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Experiment bookkeeping: config snapshot and TensorBoard writer.
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # ---- Loaders built from the split config ----
        # NOTE(review): train_loader, val_loader and nclass assigned in this
        # section are overwritten by make_data_loader right after it; confirm
        # whether this section is still needed.
        config = get_config(args.config)
        (vals_cls, valu_cls, all_labels, visible_classes,
         visible_classes_test, train, val, sampler, visibility_mask, cls_map,
         cls_map_test) = get_split(config)
        assert (visible_classes_test.shape[0] ==
                config['dis']['out_dim_cls'] - 1)

        image_cfg = config['IMAGE']
        shared_dataset_kwargs = {
            "root": config['ROOT'],
            "base_size": 513,
            "mean": tuple(image_cfg['MEAN'][channel] for channel in "BGR"),
            "warp": config['WARP_IMAGE'],
        }

        train_dataset = get_dataset(config['DATAMODE'])(
            train=train,
            test=None,
            split=config['SPLIT']['TRAIN'],
            crop_size=image_cfg['SIZE']['TRAIN'],
            scale=(0.5, 1.5),
            flip=True,
            visibility_mask=visibility_mask,
            **shared_dataset_kwargs)
        print('train dataset:', len(train_dataset))
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=config['BATCH_SIZE']['TRAIN'],
            num_workers=config['NUM_WORKERS'],
            sampler=sampler)

        test_dataset = get_dataset(config['DATAMODE'])(
            train=None,
            test=val,
            split=config['SPLIT']['TEST'],
            crop_size=image_cfg['SIZE']['TEST'],
            scale=None,
            flip=False,
            **shared_dataset_kwargs)
        print('test dataset:', len(test_dataset))
        test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=config['BATCH_SIZE']['TEST'],
            num_workers=config['NUM_WORKERS'],
            shuffle=False)

        self.train_loader, self.val_loader, self.nclass = (train_loader,
                                                           test_loader, 21)

        # ---- Loaders actually used for training / validation ----
        loader_kwargs = {"num_workers": args.workers, "pin_memory": True}
        self.train_loader, self.val_loader, _, self.nclass = make_data_loader(
            args,
            load_embedding=args.load_embedding,
            w2c_size=args.w2c_size,
            **loader_kwargs)
        print('self.nclass', self.nclass)

        # ---- Segmentation network and its SGD optimizer ----
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            global_avg_pool_bn=args.global_avg_pool_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )
        # "1x" parameters at the base LR, "10x" parameters at ten times it.
        param_groups = [
            {"params": model.get_1x_lr_params(), "lr": args.lr},
            {"params": model.get_10x_lr_params(), "lr": args.lr * 10},
        ]
        optimizer = torch.optim.SGD(
            param_groups,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # ---- Feature generator (GMMN) and its Adam optimizer ----
        generator = GMMNnetwork(args.noise_dim, args.embed_dim,
                                args.hidden_size, args.feature_dim)
        optimizer_generator = torch.optim.Adam(generator.parameters(),
                                               lr=args.lr_generator)

        # ---- Losses ----
        # Unseen classes get args.unseen_weight in the segmentation loss;
        # every other class keeps weight 1.
        class_weight = torch.ones(self.nclass)
        class_weight[args.unseen_classes_idx_metric] = args.unseen_weight
        if args.cuda:
            class_weight = class_weight.cuda()
        self.criterion = SegmentationLosses(
            weight=class_weight,
            cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80],
                                            cuda=args.cuda).build_loss()
        self.generator, self.optimizer_generator = generator, optimizer_generator

        # ---- Evaluation and LR schedule ----
        self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric,
                                   args.unseen_classes_idx_metric)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # ---- Device placement ----
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.generator = self.generator.cuda()

        # ---- Optional resume (model weights only) ----
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            checkpoint = torch.load(args.resume)
            # start_epoch and best_pred are not restored from the checkpoint.
            state = checkpoint["state_dict"]
            if args.random_last_layer:
                # Replace the final prediction conv with uniform-random values
                # sized for self.nclass output channels.
                kernel_shape = state["decoder.pred_conv.weight"].shape
                state["decoder.pred_conv.weight"] = torch.rand((
                    self.nclass,
                    kernel_shape[1],
                    kernel_shape[2],
                    kernel_shape[3],
                ))
                state["decoder.pred_conv.bias"] = torch.rand(self.nclass)

            target_net = self.model.module if args.cuda else self.model
            target_net.load_state_dict(state)
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Fine-tuning always restarts the epoch counter.
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch, args):
        """Run one epoch of joint generator / classifier training.

        For every batch, features are extracted under ``torch.no_grad`` with
        ``forward_before_class_prediction``. The GMMN generator is then fitted
        per class on seen classes, but only for images without unseen-class
        pixels. Finally, the classification head
        (``forward_class_prediction``) is trained on the generated features.
        With ``args.real_seen_features`` and an image with no unseen pixels,
        that image's real features are used instead.

        Args:
            epoch: current epoch index, used for the LR schedule and logging.
            args: run arguments (feature/embedding sizes, class index lists,
                ``cuda`` flag, ``batch_size_generator``, ...).
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise about 10 batches per epoch. Clamped to at least 1 so that a
        # loader with fewer than 10 batches does not raise ZeroDivisionError in
        # the modulo below.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            # Batches containing a single image are skipped entirely.
            # NOTE(review): presumably to avoid BatchNorm on batch size 1.
            if len(sample["image"]) > 1:
                image, target, embedding = (
                    sample["image"],
                    sample["label"],
                    sample["label_emb"],
                )
                if self.args.cuda:
                    image, target, embedding = (
                        image.cuda(),
                        target.cuda(),
                        embedding.cuda(),
                    )
                self.scheduler(self.optimizer, i, epoch, self.best_pred)
                # ===================real feature extraction=====================
                with torch.no_grad():
                    real_features = self.model.module.forward_before_class_prediction(
                        image)
                feat_h, feat_w = real_features.shape[2], real_features.shape[3]

                # ===================fake feature generation=====================
                fake_features = torch.zeros(real_features.shape)
                if args.cuda:
                    fake_features = fake_features.cuda()
                generator_loss_batch = 0.0
                for count_sample_i, (real_features_i, target_i,
                                     embedding_i) in enumerate(
                                         zip(real_features, target,
                                             embedding)):
                    generator_loss_sample = 0.0
                    # Flatten features to (pixels, feature_dim) and bring the
                    # labels and embeddings down to the feature-map resolution.
                    real_features_i = (real_features_i.permute(
                        1, 2, 0).contiguous().view((-1, args.feature_dim)))
                    target_i = nn.functional.interpolate(
                        target_i.view(1, 1, target_i.shape[0],
                                      target_i.shape[1]),
                        size=(feat_h, feat_w),
                        mode="nearest",
                    ).view(-1)
                    embedding_i = nn.functional.interpolate(
                        embedding_i.view(
                            1,
                            embedding_i.shape[0],
                            embedding_i.shape[1],
                            embedding_i.shape[2],
                        ),
                        size=(feat_h, feat_w),
                        mode="nearest",
                    )
                    embedding_i = (embedding_i.permute(0, 2, 3,
                                                       1).contiguous().view(
                                                           (-1,
                                                            args.embed_dim)))

                    fake_features_i = torch.zeros(real_features_i.shape)
                    if args.cuda:
                        fake_features_i = fake_features_i.cuda()

                    unique_class = torch.unique(target_i)

                    # If the image has any unseen-class pixel, the generator is
                    # not trained on it and generated features are used for
                    # the whole image.
                    has_unseen_class = any(
                        u_class in args.unseen_classes_idx_metric
                        for u_class in unique_class)

                    for idx_in in unique_class:
                        # Label 255 is skipped (presumably the ignore label).
                        if idx_in == 255:
                            continue
                        self.optimizer_generator.zero_grad()
                        idx_class = target_i == idx_in
                        real_features_class = real_features_i[idx_class]
                        embedding_class = embedding_i[idx_class]

                        z = torch.rand(
                            (embedding_class.shape[0], args.noise_dim))
                        if args.cuda:
                            z = z.cuda()

                        fake_features_class = self.generator(
                            embedding_class, z.float())

                        if (idx_in in args.seen_classes_idx_metric
                                and not has_unseen_class):
                            # Subsample pixels (with replacement) to avoid
                            # CUDA out-of-memory in the GMMN loss.
                            random_idx = torch.randint(
                                low=0,
                                high=fake_features_class.shape[0],
                                size=(args.batch_size_generator, ),
                            )
                            g_loss = self.criterion_generator(
                                fake_features_class[random_idx],
                                real_features_class[random_idx],
                            )
                            generator_loss_sample += g_loss.item()
                            g_loss.backward()
                            self.optimizer_generator.step()

                        fake_features_i[idx_class] = fake_features_class.clone()
                    generator_loss_batch += generator_loss_sample / len(
                        unique_class)

                    use_real = args.real_seen_features and not has_unseen_class
                    chosen_features = real_features_i if use_real else fake_features_i
                    # Restore (feature_dim, H, W) layout for this sample.
                    fake_features[count_sample_i] = chosen_features.view(
                        (feat_h, feat_w, args.feature_dim)).permute(2, 0, 1)
                # ===================classification=====================
                self.optimizer.zero_grad()
                output = self.model.module.forward_class_prediction(
                    fake_features.detach(),
                    image.size()[2:])
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
                # ===================log=====================
                tbar.set_description(f" G loss: {generator_loss_batch:.3f}" +
                                     " C loss: %.3f" % (train_loss / (i + 1)))
                self.writer.add_scalar("train/total_loss_iter", loss.item(),
                                       i + num_img_tr * epoch)
                self.writer.add_scalar("train/generator_loss",
                                       generator_loss_batch,
                                       i + num_img_tr * epoch)

                # Show inference results about 10 times per epoch
                if i % vis_interval == 0:
                    global_step = i + num_img_tr * epoch
                    self.summary.visualize_image(
                        self.writer,
                        self.args.dataset,
                        image,
                        target,
                        output,
                        global_step,
                    )

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print(f"Loss: {train_loss:.3f}")

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch, args):
        """Evaluate the model on ``self.val_loader`` and checkpoint it.

        Runs the model in eval mode over every validation batch, accumulates
        the criterion loss and the batch predictions in ``self.evaluator``,
        logs overall / seen / unseen / per-class metrics to TensorBoard and
        stdout, saves a checkpoint of the model and the generator, and
        visualizes up to ``args.saved_validation_images`` batches that contain
        each class listed in ``args.unseen_classes_idx_metric``.

        Args:
            epoch: Current epoch index. Used as the TensorBoard step for the
                scalar metrics; ``epoch + 1`` is stored in the checkpoint and
                used as the step for the image summaries.
            args: Options namespace. This method reads
                ``unseen_classes_idx_metric`` and ``saved_validation_images``.
        """
        # Class names looked up by label index (index 0 = "background").
        # These are the Pascal VOC names; assumes the validation labels follow
        # this ordering — TODO confirm if used with another dataset.
        class_names = [
            "background",  # class 0
            "aeroplane",  # class 1
            "bicycle",  # class 2
            "bird",  # class 3
            "boat",  # class 4
            "bottle",  # class 5
            "bus",  # class 6
            "car",  # class 7
            "cat",  # class 8
            "chair",  # class 9
            "cow",  # class 10
            "diningtable",  # class 11
            "dog",  # class 12
            "horse",  # class 13
            "motorbike",  # class 14
            "person",  # class 15
            "potted plant",  # class 16
            "sheep",  # class 17
            "sofa",  # class 18
            "train",  # class 19
            "tv/monitor",  # class 20
        ]
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        # Count the images actually evaluated rather than deriving the number
        # from args.batch_size: that formula is wrong when the validation
        # loader uses another batch size, and it referenced loop variables
        # that are unbound when the loader yields no batches.
        num_images = 0

        # Per unseen class: up to args.saved_validation_images batches that
        # contain at least one pixel of that class, kept for visualization.
        saved_images = {idx: [] for idx in args.unseen_classes_idx_metric}
        saved_target = {idx: [] for idx in args.unseen_classes_idx_metric}
        saved_prediction = {idx: [] for idx in args.unseen_classes_idx_metric}

        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))

            ## save image for tensorboard
            for idx_unseen_class in args.unseen_classes_idx_metric:
                # Cheap cap check first, so no device sync is needed once the
                # quota for this class is full.
                if (len(saved_images[idx_unseen_class]) >=
                        args.saved_validation_images):
                    continue
                # Keep the batch only if some pixel carries this label.
                if (target == idx_unseen_class).any():
                    saved_images[idx_unseen_class].append(image.clone().cpu())
                    saved_target[idx_unseen_class].append(
                        target.clone().cpu())
                    saved_prediction[idx_unseen_class].append(
                        output.clone().cpu())

            # Per-pixel predicted class = argmax over the class axis (axis 1).
            pred = np.argmax(output.cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        # Fast test during the training
        Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy()
        (
            Acc_class,
            Acc_class_by_class,
            Acc_class_seen,
            Acc_class_unseen,
        ) = self.evaluator.Pixel_Accuracy_Class()
        (
            mIoU,
            mIoU_by_class,
            mIoU_seen,
            mIoU_unseen,
        ) = self.evaluator.Mean_Intersection_over_Union()
        (
            FWIoU,
            FWIoU_seen,
            FWIoU_unseen,
        ) = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        # test_loss is the sum over batches (not the mean), as in training.
        epoch_scalars = {
            "val_overall/total_loss_epoch": test_loss,
            "val_overall/mIoU": mIoU,
            "val_overall/Acc": Acc,
            "val_overall/Acc_class": Acc_class,
            "val_overall/fwIoU": FWIoU,
            "val_seen/mIoU": mIoU_seen,
            "val_seen/Acc": Acc_seen,
            "val_seen/Acc_class": Acc_class_seen,
            "val_seen/fwIoU": FWIoU_seen,
            "val_unseen/mIoU": mIoU_unseen,
            "val_unseen/Acc": Acc_unseen,
            "val_unseen/Acc_class": Acc_class_unseen,
            "val_unseen/fwIoU": FWIoU_unseen,
        }
        for tag, value in epoch_scalars.items():
            self.writer.add_scalar(tag, value, epoch)

        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" % (epoch, num_images))
        print(f"Loss: {test_loss:.3f}")
        print(
            f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}"
        )
        print("Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen))
        print("Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen))

        for class_name, acc_value, mIoU_value in zip(class_names,
                                                     Acc_class_by_class,
                                                     mIoU_by_class):
            self.writer.add_scalar("Acc_by_class/" + class_name, acc_value,
                                   epoch)
            self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value,
                                   epoch)
            print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)

        new_pred = mIoU_unseen

        # NOTE(review): every epoch is saved as "best" and best_pred is
        # overwritten with the latest unseen mIoU whether or not it improved —
        # confirm this is intended rather than a `new_pred > best_pred` check.
        is_best = True
        self.best_pred = new_pred
        self.saver.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": self.model.module.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_pred": self.best_pred,
            },
            is_best,
            # NOTE(review): the generator checkpoint stores self.optimizer's
            # state (the same optimizer as the model entry above) — confirm
            # whether a separate generator optimizer should be saved here.
            generator_state={
                "epoch": epoch + 1,
                "state_dict": self.generator.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_pred": self.best_pred,
            },
        )

        # One image summary per saved batch, named after the unseen class.
        global_step = epoch + 1
        cap = args.saved_validation_images
        for idx_unseen_class in args.unseen_classes_idx_metric:
            saved_batches = zip(
                saved_images[idx_unseen_class][:cap],
                saved_target[idx_unseen_class][:cap],
                saved_prediction[idx_unseen_class][:cap],
            )
            for k, (img, tgt, prd) in enumerate(saved_batches):
                self.summary.visualize_image_validation(
                    self.writer,
                    self.args.dataset,
                    img,
                    tgt,
                    prd,
                    global_step,
                    name="validation_" + class_names[idx_unseen_class] +
                    "_" + str(k),
                    nb_image=1,
                )

        self.evaluator.reset()