示例#1
0
    def __init__(self, args):
        """Build every component the trainer needs from the run options.

        In order, this creates the experiment saver and TensorBoard writer,
        the train/val data loaders, the ``DeepLab`` model with its SGD
        optimizer, the ``GMMNnetwork`` generator with its Adam optimizer,
        both loss criteria, the evaluator and the LR scheduler. It then
        optionally moves the networks to GPU and restores model weights from
        ``args.resume``.

        Args:
            args: Object exposing the run options as attributes. Besides
                being handed whole to ``Saver`` and ``make_data_loader``,
                the attributes read here are: ``workers``,
                ``load_embedding``, ``w2c_size``, ``out_stride``,
                ``sync_bn``, ``freeze_bn``, ``global_avg_pool_bn``,
                ``imagenet_pretrained_path``, ``lr``, ``momentum``,
                ``weight_decay``, ``nesterov``, ``noise_dim``,
                ``embed_dim``, ``hidden_size``, ``feature_dim``,
                ``lr_generator``, ``unseen_classes_idx_metric``,
                ``seen_classes_idx_metric``, ``unseen_weight``, ``cuda``,
                ``loss_type``, ``lr_scheduler``, ``epochs``, ``gpu_ids``,
                ``resume``, ``random_last_layer`` and ``ft``.
                ``args.start_epoch`` is overwritten with 0 when ``args.ft``
                is truthy.

        Raises:
            RuntimeError: If ``args.resume`` is set but does not point to an
                existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader. The third value returned by make_data_loader is
        # not used by this trainer and is discarded.
        loader_kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args,
                             load_embedding=args.load_embedding,
                             w2c_size=args.w2c_size,
                             **loader_kwargs)
        # The original author noted a value of 33 here for their setup.
        print('self.nclass', self.nclass)

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            global_avg_pool_bn=args.global_avg_pool_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )

        # Two parameter groups: one trained at the base learning rate and
        # one at ten times the base learning rate.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer (learning rates come from the groups above)
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Generator
        generator = GMMNnetwork(args.noise_dim, args.embed_dim,
                                args.hidden_size, args.feature_dim)
        optimizer_generator = torch.optim.Adam(generator.parameters(),
                                               lr=args.lr_generator)

        # Per-class loss weights: 1 everywhere except the unseen-class
        # indices, which get args.unseen_weight.
        class_weight = torch.ones(self.nclass)
        class_weight[args.unseen_classes_idx_metric] = args.unseen_weight
        if args.cuda:
            class_weight = class_weight.cuda()

        self.criterion = SegmentationLosses(
            weight=class_weight,
            cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80],
                                            cuda=args.cuda).build_loss()
        self.generator, self.optimizer_generator = generator, optimizer_generator

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric,
                                   args.unseen_classes_idx_metric)

        # Define lr scheduler. The last argument is the number of batches in
        # the training loader (presumably iterations per epoch — confirm
        # against LR_Scheduler).
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda: the segmentation model is wrapped in DataParallel, the
        # generator is moved to GPU unwrapped.
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.generator = self.generator.cuda()

        # Resuming checkpoint. Only model weights are restored; the stored
        # epoch, best_pred and optimizer state are not read back here.
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            # Deserialize onto the CPU so a checkpoint written from GPU
            # tensors can also be resumed on a CPU-only run or on a machine
            # without the GPU index it was saved from. load_state_dict()
            # copies the values into the existing parameters afterwards.
            # NOTE(review): torch.load unpickles the file — only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume, map_location="cpu")
            state_dict = checkpoint["state_dict"]

            if args.random_last_layer:
                # Replace the prediction layer with uniformly random values
                # sized for the current class count, keeping the remaining
                # weight dimensions of the checkpointed layer.
                old_weight = state_dict["decoder.pred_conv.weight"]
                state_dict["decoder.pred_conv.weight"] = torch.rand(
                    (self.nclass, *old_weight.shape[1:]))
                state_dict["decoder.pred_conv.bias"] = torch.rand(self.nclass)

            # Under DataParallel the real network lives in `.module`.
            target = self.model.module if args.cuda else self.model
            target.load_state_dict(state_dict)

            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
示例#2
0
    def __init__(self, args):
        """Build every component the trainer needs from the run options.

        In order, this creates the experiment saver and TensorBoard writer,
        a config-driven dataset/loader pair, the ``make_data_loader`` loaders
        (which replace the config-driven ones), the ``DeepLab`` model with
        its SGD optimizer, the ``GMMNnetwork`` generator with its Adam
        optimizer, both loss criteria, the evaluator and the LR scheduler.
        It then optionally moves the networks to GPU and restores model
        weights from ``args.resume``.

        Args:
            args: Object exposing the run options as attributes. Besides
                being handed whole to ``Saver`` and ``make_data_loader``,
                the attributes read here are: ``config``, ``workers``,
                ``load_embedding``, ``w2c_size``, ``out_stride``,
                ``sync_bn``, ``freeze_bn``, ``global_avg_pool_bn``,
                ``imagenet_pretrained_path``, ``lr``, ``momentum``,
                ``weight_decay``, ``nesterov``, ``noise_dim``,
                ``embed_dim``, ``hidden_size``, ``feature_dim``,
                ``lr_generator``, ``unseen_classes_idx_metric``,
                ``seen_classes_idx_metric``, ``unseen_weight``, ``cuda``,
                ``loss_type``, ``lr_scheduler``, ``epochs``, ``gpu_ids``,
                ``resume``, ``random_last_layer`` and ``ft``.
                ``args.start_epoch`` is overwritten with 0 when ``args.ft``
                is truthy.

        Raises:
            AssertionError: If the number of visible test classes does not
                equal ``config['dis']['out_dim_cls'] - 1``.
            RuntimeError: If ``args.resume`` is set but does not point to an
                existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Config-driven data pipeline.
        # NOTE(review): the loaders and class count produced in this section
        # are assigned to self.* and then immediately overwritten by the
        # make_data_loader() call further down, so none of it affects
        # training as written. It is left in place because the intended
        # source of truth is not clear from this method — confirm and delete
        # one of the two paths.
        config = get_config(args.config)
        vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, visibility_mask, cls_map, cls_map_test = get_split(
            config)
        assert (visible_classes_test.shape[0] == config['dis']['out_dim_cls'] -
                1)

        # Per-channel mean is passed in B, G, R order.
        dataset = get_dataset(config['DATAMODE'])(
            train=train,
            test=None,
            root=config['ROOT'],
            split=config['SPLIT']['TRAIN'],
            base_size=513,
            crop_size=config['IMAGE']['SIZE']['TRAIN'],
            mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
                  config['IMAGE']['MEAN']['R']),
            warp=config['WARP_IMAGE'],
            scale=(0.5, 1.5),
            flip=True,
            visibility_mask=visibility_mask)
        print('train dataset:', len(dataset))

        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['BATCH_SIZE']['TRAIN'],
            num_workers=config['NUM_WORKERS'],
            sampler=sampler)

        # Test split: no scaling, no flipping, no visibility mask.
        dataset_test = get_dataset(config['DATAMODE'])(
            train=None,
            test=val,
            root=config['ROOT'],
            split=config['SPLIT']['TEST'],
            base_size=513,
            crop_size=config['IMAGE']['SIZE']['TEST'],
            mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
                  config['IMAGE']['MEAN']['R']),
            warp=config['WARP_IMAGE'],
            scale=None,
            flip=False)
        print('test dataset:', len(dataset_test))

        loader_test = torch.utils.data.DataLoader(
            dataset=dataset_test,
            batch_size=config['BATCH_SIZE']['TEST'],
            num_workers=config['NUM_WORKERS'],
            shuffle=False)

        self.train_loader = loader
        self.val_loader = loader_test
        self.nclass = 21

        # Define Dataloader. This replaces the three attributes set just
        # above; the third value returned by make_data_loader is discarded.
        loader_kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            _,
            self.nclass,
        ) = make_data_loader(args,
                             load_embedding=args.load_embedding,
                             w2c_size=args.w2c_size,
                             **loader_kwargs)
        print('self.nclass', self.nclass)

        # Define network
        model = DeepLab(
            num_classes=self.nclass,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
            global_avg_pool_bn=args.global_avg_pool_bn,
            imagenet_pretrained_path=args.imagenet_pretrained_path,
        )
        # Two parameter groups: one trained at the base learning rate and
        # one at ten times the base learning rate.
        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer (learning rates come from the groups above)
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Generator
        generator = GMMNnetwork(args.noise_dim, args.embed_dim,
                                args.hidden_size, args.feature_dim)
        optimizer_generator = torch.optim.Adam(generator.parameters(),
                                               lr=args.lr_generator)

        # Per-class loss weights: 1 everywhere except the unseen-class
        # indices, which get args.unseen_weight.
        class_weight = torch.ones(self.nclass)
        class_weight[args.unseen_classes_idx_metric] = args.unseen_weight
        if args.cuda:
            class_weight = class_weight.cuda()

        self.criterion = SegmentationLosses(
            weight=class_weight,
            cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.criterion_generator = GMMNLoss(sigma=[2, 5, 10, 20, 40, 80],
                                            cuda=args.cuda).build_loss()
        self.generator, self.optimizer_generator = generator, optimizer_generator

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass, args.seen_classes_idx_metric,
                                   args.unseen_classes_idx_metric)

        # Define lr scheduler. The last argument is the number of batches in
        # the training loader (presumably iterations per epoch — confirm
        # against LR_Scheduler).
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda: the segmentation model is wrapped in DataParallel, the
        # generator is moved to GPU unwrapped.
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.generator = self.generator.cuda()

        # Resuming checkpoint. Only model weights are restored; the stored
        # epoch, best_pred and optimizer state are not read back here.
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    f"=> no checkpoint found at '{args.resume}'")
            # Deserialize onto the CPU so a checkpoint written from GPU
            # tensors can also be resumed on a CPU-only run or on a machine
            # without the GPU index it was saved from. load_state_dict()
            # copies the values into the existing parameters afterwards.
            # NOTE(review): torch.load unpickles the file — only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume, map_location="cpu")
            state_dict = checkpoint["state_dict"]

            if args.random_last_layer:
                # Replace the prediction layer with uniformly random values
                # sized for the current class count, keeping the remaining
                # weight dimensions of the checkpointed layer.
                old_weight = state_dict["decoder.pred_conv.weight"]
                state_dict["decoder.pred_conv.weight"] = torch.rand(
                    (self.nclass, *old_weight.shape[1:]))
                state_dict["decoder.pred_conv.bias"] = torch.rand(self.nclass)

            # Under DataParallel the real network lives in `.module`.
            target = self.model.module if args.cuda else self.model
            target.load_state_dict(state_dict)

            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0