Exemple #1
0
    def __init__(self, args, cuda=None, train_id="None", logger=None):
        """Build the training state: device, metric, loss, model, optimizer, data.

        Args:
            args: Run configuration object. Attributes read here: ``gpu``,
                ``checkpoint_dir``, ``num_classes``, ``optim``, ``momentum``,
                ``weight_decay``, ``dataset``, ``iter_max`` and ``iter_stop``.
            cuda: Truthy to request CUDA. CUDA is used only if
                ``torch.cuda.is_available()`` is also true.
            train_id: Run identifier, stored unchanged on ``self.train_id``.
            logger: Logger object, stored unchanged on ``self.logger``.

        Raises:
            ValueError: If ``args.optim`` is neither ``"SGD"`` nor ``"Adam"``.
        """
        self.args = args
        # Set before the CUDA availability check performed just below.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')
        self.train_id = train_id
        self.logger = logger

        # Progress counters and best-score bookkeeping, all starting at zero.
        self.current_MIoU = 0
        self.best_MIou = 0
        self.best_source_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.second_best_MIou = 0

        # set TensorboardX
        self.writer = SummaryWriter(self.args.checkpoint_dir)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition: targets equal to -1 are ignored by the loss
        self.loss = nn.CrossEntropyLoss(weight=None, ignore_index=-1)
        self.loss.to(self.device)

        # model
        self.model, params = get_model(self.args)
        self.model = nn.DataParallel(self.model, device_ids=[0])
        self.model.to(self.device)

        # optimizer
        # NOTE(review): no learning rate is passed to either optimizer;
        # presumably `params` holds parameter groups that carry their own
        # "lr" -- confirm against get_model.
        if self.args.optim == "SGD":
            self.optimizer = torch.optim.SGD(
                params=params,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )
        elif self.args.optim == "Adam":
            self.optimizer = torch.optim.Adam(
                params, betas=(0.9, 0.99), weight_decay=self.args.weight_decay)
        else:
            # Fail here instead of leaving self.optimizer undefined, which
            # would only surface later as an AttributeError.
            raise ValueError(
                "unsupported optimizer {!r}; expected 'SGD' or 'Adam'".format(
                    self.args.optim))

        # dataloader: any dataset name other than the two below uses SYNTHIA
        if self.args.dataset == "cityscapes":
            self.dataloader = City_DataLoader(self.args)
        elif self.args.dataset == "gta5":
            self.dataloader = GTA5_DataLoader(self.args)
        else:
            self.dataloader = SYNTHIA_DataLoader(self.args)

        # Cap the iterations per epoch at the module-level ITER_MAX.
        self.dataloader.num_iterations = min(self.dataloader.num_iterations, ITER_MAX)
        print(self.args.iter_max, self.dataloader.num_iterations)

        # args.iter_stop, when given, replaces args.iter_max as the iteration
        # budget used to derive the number of epochs.
        if self.args.iter_stop is None:
            iter_budget = self.args.iter_max
        else:
            iter_budget = self.args.iter_stop
        self.epoch_num = ceil(iter_budget / self.dataloader.num_iterations)
Exemple #2
0
    def __init__(self, args, cuda=None, train_id=None, logger=None):
        """Build the evaluation state and optionally restore a checkpoint.

        Args:
            args: Run configuration object. Attributes read here:
                ``checkpoint_dir``, ``num_classes``, ``pretrained_ckpt_file``,
                ``dataset`` and ``iter_max``.
            cuda: Truthy to request CUDA. CUDA is used only if
                ``torch.cuda.is_available()`` is also true.
            train_id: Run identifier. When not ``None`` it names the preferred
                checkpoint ``<train_id>best.pth``.
            logger: Logger object, stored unchanged on ``self.logger``.

        Raises:
            AssertionError: If ``args.pretrained_ckpt_file`` is set but no
                candidate checkpoint file exists.
        """
        self.args = args
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.train_id = train_id
        self.logger = logger

        # set TensorboardX
        self.writer = SummaryWriter(self.args.checkpoint_dir)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition: targets equal to -1 are ignored by the loss
        self.loss = nn.CrossEntropyLoss(ignore_index=-1)
        self.loss.to(self.device)

        # model
        self.model, params = get_model(self.args)
        self.model = nn.DataParallel(self.model, device_ids=[0])
        self.model.to(self.device)

        # load pretrained checkpoint
        if self.args.pretrained_ckpt_file is not None:
            candidates = []
            if self.train_id is not None:
                # Preferred: "<train_id>best.pth" in the parent of
                # checkpoint_dir. os.path.dirname keeps the leading "/" of an
                # absolute path, which re-joining split('/') pieces drops.
                candidates.append(os.path.join(
                    os.path.dirname(self.args.checkpoint_dir),
                    self.train_id + 'best.pth'))
            # Fallback: the explicitly configured checkpoint file.
            candidates.append(self.args.pretrained_ckpt_file)

            pretrained_ckpt_file = next(
                (path for path in candidates if os.path.exists(path)), None)
            if pretrained_ckpt_file is None:
                raise AssertionError("no pretrained_ckpt_file")
            self.load_checkpoint(pretrained_ckpt_file)

        # dataloader: any dataset name other than "cityscapes" uses GTA5
        if self.args.dataset == "cityscapes":
            self.dataloader = City_DataLoader(self.args)
        else:
            self.dataloader = GTA5_DataLoader(self.args)
        # Validate on the loader's own data_loader, at most 500 iterations.
        self.dataloader.val_loader = self.dataloader.data_loader
        self.dataloader.valid_iterations = min(self.dataloader.num_iterations, 500)
        self.epoch_num = ceil(self.args.iter_max / self.dataloader.num_iterations)
Exemple #3
0
    def __init__(self, args, cuda=None, train_id=None, logger=None):
        """Build the evaluation state, restore a checkpoint, pick the val set.

        Args:
            args: Run configuration object. Attributes read here: ``gpu``,
                ``checkpoint_dir``, ``num_classes``, ``pretrained_ckpt_file``,
                ``dataset``, ``city_name``, ``data_root_path``, ``list_path``,
                ``target_base_size``, ``target_crop_size``, ``class_13``,
                ``batch_size``, ``data_loader_workers``, ``pin_memory`` and
                ``iter_max``.
            cuda: Truthy to request CUDA. CUDA is used only if
                ``torch.cuda.is_available()`` is also true.
            train_id: Run identifier. When not ``None`` it names the preferred
                checkpoint ``<train_id>best.pth``.
            logger: Logger object, stored unchanged on ``self.logger``.

        Raises:
            AssertionError: If ``args.pretrained_ckpt_file`` is set but no
                candidate checkpoint file exists.
        """
        self.args = args
        # Set before the CUDA availability check performed just below.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.train_id = train_id
        self.logger = logger

        # set TensorboardX
        self.writer = SummaryWriter(self.args.checkpoint_dir)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition: targets equal to -1 are ignored by the loss
        self.loss = nn.CrossEntropyLoss(ignore_index=-1)
        self.loss.to(self.device)

        # model
        self.model, params = get_model(self.args)
        self.model = nn.DataParallel(self.model, device_ids=[0])
        self.model.to(self.device)

        # load pretrained checkpoint
        if self.args.pretrained_ckpt_file is not None:
            candidates = []
            if self.train_id is not None:
                # Preferred: "<train_id>best.pth" in the parent of
                # checkpoint_dir. os.path.dirname keeps the leading "/" of an
                # absolute path, which re-joining split('/') pieces drops.
                candidates.append(os.path.join(
                    os.path.dirname(self.args.checkpoint_dir),
                    self.train_id + 'best.pth'))
            # Fallback: the explicitly configured checkpoint file.
            candidates.append(self.args.pretrained_ckpt_file)

            pretrained_ckpt_file = next(
                (path for path in candidates if os.path.exists(path)), None)
            if pretrained_ckpt_file is None:
                raise AssertionError("no pretrained_ckpt_file")
            self.load_checkpoint(pretrained_ckpt_file)

        # dataloader: any dataset name other than "cityscapes" uses GTA5
        if self.args.dataset == "cityscapes":
            self.dataloader = City_DataLoader(self.args)
        else:
            self.dataloader = GTA5_DataLoader(self.args)

        if self.args.city_name != "None":
            # A city name was given (the literal string "None" means unset):
            # validate on the CrossCity 'val' split instead.
            target_data_set = CrossCity_Dataset(
                self.args,
                data_root_path=self.args.data_root_path,
                list_path=self.args.list_path,
                split='val',
                base_size=self.args.target_base_size,
                crop_size=self.args.target_crop_size,
                class_13=self.args.class_13)
            self.target_val_dataloader = data.DataLoader(
                target_data_set,
                batch_size=self.args.batch_size,
                shuffle=False,
                num_workers=self.args.data_loader_workers,
                pin_memory=self.args.pin_memory,
                drop_last=True)
            self.dataloader.val_loader = self.target_val_dataloader
            # NOTE(review): the loader above uses drop_last=True, so this
            # count is larger than the number of batches it yields -- confirm
            # how valid_iterations is consumed before changing it.
            self.dataloader.valid_iterations = (
                len(target_data_set) +
                self.args.batch_size) // self.args.batch_size
        else:
            # Validate on the loader's own data_loader, at most 500 iterations.
            self.dataloader.val_loader = self.dataloader.data_loader
            self.dataloader.valid_iterations = min(
                self.dataloader.num_iterations, 500)

        self.epoch_num = ceil(self.args.iter_max /
                              self.dataloader.num_iterations)
Exemple #4
0
    def __init__(self, args, cuda=None, train_id="None", logger=None):
        """Build the training state and validate on the Cityscapes target set.

        Args:
            args: Run configuration object. Attributes read here:
                ``checkpoint_dir``, ``num_classes``, ``backbone``, ``optim``
                (may be rewritten, see below), ``momentum``, ``weight_decay``,
                ``dataset``, ``target_base_size``, ``target_crop_size``,
                ``class_16``, ``batch_size``, ``data_loader_workers``,
                ``pin_memory``, ``iter_max`` and ``iter_stop``.
            cuda: Truthy to request CUDA. CUDA is used only if
                ``torch.cuda.is_available()`` is also true.
            train_id: Run identifier, stored unchanged on ``self.train_id``.
            logger: Logger object, stored unchanged on ``self.logger``.

        Raises:
            ValueError: If ``args.optim`` is neither ``"SGD"`` nor ``"Adam"``.

        Side effects:
            ``args.optim`` is overwritten with ``"Adam"`` when the backbone is
            ``"fcn8s_vgg"`` and ``"SGD"`` was requested.
        """
        self.args = args
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')
        self.train_id = train_id
        self.logger = logger

        # Progress counters and best-score bookkeeping, all starting at zero.
        self.current_MIoU = 0
        self.best_MIou = 0
        self.best_source_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.second_best_MIou = 0

        # set TensorboardX
        self.writer = SummaryWriter(self.args.checkpoint_dir)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition: targets equal to -1 are ignored by the loss
        self.loss = nn.CrossEntropyLoss(weight=None, ignore_index=-1)
        self.loss.to(self.device)

        # model
        self.model, params = get_model(self.args)
        self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        # support for FCN8s
        if self.args.backbone == "fcn8s_vgg" and self.args.optim == "SGD":
            self.args.optim = "Adam"
            print('WARNING: FCN8s requires Adam optimizer, but SGD was set. Switching to Adam.')

        # optimizer
        # NOTE(review): no learning rate is passed to either optimizer;
        # presumably `params` holds parameter groups that carry their own
        # "lr" -- confirm against get_model.
        if self.args.optim == "SGD":
            self.optimizer = torch.optim.SGD(
                params=params,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )
        elif self.args.optim == "Adam":
            self.optimizer = torch.optim.Adam(
                params, betas=(0.9, 0.99), weight_decay=self.args.weight_decay)
        else:
            # Fail here instead of leaving self.optimizer undefined, which
            # would only surface later as an AttributeError.
            raise ValueError(
                "unsupported optimizer {!r}; expected 'SGD' or 'Adam'".format(
                    self.args.optim))

        # dataloader: any dataset name other than the two below uses SYNTHIA
        if DEBUG:
            print('DEBUG: Loading training and validation datasets (for UDA only val one is used, but it is overwritten)')
        if self.args.dataset == "cityscapes":
            self.dataloader = City_DataLoader(self.args)
        elif self.args.dataset == "gta5":
            self.dataloader = GTA5_DataLoader(self.args)
        else:
            self.dataloader = SYNTHIA_DataLoader(self.args)

        # Manual switch, currently always on: replace the validation loader
        # with one built from the Cityscapes 'test' split.
        use_target_val = True
        if use_target_val:
            if DEBUG:
                print('DEBUG: Overwriting validation set, using target set instead of source one')
            cityscapes_paths = datasets_path['cityscapes']
            target_data_set = City_Dataset(self.args,
                                           data_root_path=cityscapes_paths['data_root_path'],
                                           list_path=cityscapes_paths['list_path'],
                                           split='test',
                                           base_size=self.args.target_base_size,
                                           crop_size=self.args.target_crop_size,
                                           class_16=self.args.class_16)
            self.target_val_dataloader = data.DataLoader(target_data_set,
                                                         batch_size=self.args.batch_size,
                                                         shuffle=False,
                                                         num_workers=self.args.data_loader_workers,
                                                         pin_memory=self.args.pin_memory,
                                                         drop_last=True)
            self.dataloader.val_loader = self.target_val_dataloader

        # Cap the iterations per epoch at the module-level ITER_MAX.
        self.dataloader.num_iterations = min(self.dataloader.num_iterations, ITER_MAX)

        # args.iter_stop, when given, replaces args.iter_max as the iteration
        # budget used to derive the number of epochs.
        if self.args.iter_stop is None:
            iter_budget = self.args.iter_max
        else:
            iter_budget = self.args.iter_stop
        self.epoch_num = ceil(iter_budget / self.dataloader.num_iterations)
Exemple #5
0
    def __init__(self, args, config, cuda=None):
        """Build the DeepLab training state: device, loss, model, optimizer, data.

        Args:
            args: Command-line style options. Attributes read here: ``gpu``,
                ``loss_weight``, ``dataset``, ``output_stride``,
                ``imagenet_pretrained``, ``bn_momentum``, ``freeze_bn``,
                ``lr`` and ``validation``.
            config: Configuration object. Attributes read here:
                ``epoch_num``, ``num_classes``, ``classes_weight``,
                ``momentum`` and ``weight_decay``.
            cuda: Truthy to request CUDA. CUDA is used only if
                ``torch.cuda.is_available()`` is also true.
        """
        self.args = args
        # Set before the CUDA availability check performed just below.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.config = config
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        # Progress counters; the epoch count comes straight from the config.
        self.best_MIou = 0
        self.current_epoch = 0
        self.epoch_num = self.config.epoch_num
        self.current_iter = 0

        # Created with no arguments, so the writer's default log location is used.
        self.writer = SummaryWriter()

        # Metric definition
        self.Eval = Eval(self.config.num_classes)

        # loss definition
        # With args.loss_weight set, per-class weights are loaded from
        # "<config.classes_weight>/<dataset>classes_weights_log.npy".
        if args.loss_weight:
            classes_weights_path = os.path.join(
                self.config.classes_weight,
                self.args.dataset + 'classes_weights_log.npy')
            print(classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                # `logger` is not a parameter of this method; presumably a
                # module-level logger -- confirm.
                logger.info('calculating class weights...')
                # Presumably writes the .npy file loaded below -- confirm.
                calculate_weigths_labels(self.config)
            class_weights = np.load(classes_weights_path)
            pprint.pprint(class_weights)
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        # Targets equal to 255 are ignored by the loss.
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.config.num_classes,
                             pretrained=self.args.imagenet_pretrained,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)
        # NOTE(review): device ids 0-3 are hard-coded here, independent of
        # args.gpu and of how many devices are actually visible -- confirm.
        self.model = nn.DataParallel(self.model, device_ids=range(4))
        # Opaque helper applied to the DataParallel wrapper; purpose is not
        # visible from this block.
        patch_replication_callback(self.model)
        self.model.to(self.device)

        # Two parameter groups selected by self.get_params: the "1x" group
        # trains at args.lr and the "10x" group at ten times that rate.
        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.config.momentum,
            # dampening=self.config.dampening,
            weight_decay=self.config.weight_decay,
            # nesterov=self.config.nesterov
        )
        # dataloader
        self.dataloader = City_DataLoader(self.args, self.config)
        # NOTE(review): `== True` rejects truthy values that do not compare
        # equal to True, and the second argument is the string 'val' here but
        # self.config on the line above -- confirm both are intended.
        if self.args.validation == True:
            self.val_data = City_DataLoader(self.args, 'val')