Example #1
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # get valid dataset images and targets
        self.image_paths, self.mask_paths = _get_city_pairs(cfg["train"]["cityscapes_root"], "val")

        # create network
        self.model = ICNet(nclass=5, backbone='resnet50').to(self.device)
        
        # load ckpt
        pretrained_net = torch.load(cfg["test"]["ckpt_path"])
        self.model.load_state_dict(pretrained_net)
        
        # evaluation metrics
        self.metric = SegmentationMetric(5)
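A hedged usage sketch for the Evaluator snippets in Examples #1–#3: only __init__ (and, in Example #3, eval()) is shown, so the config layout below is inferred from the cfg lookups the code performs; the YAML path and the Evaluator import are assumptions.

# Hypothetical driver script; the key names mirror the lookups above.
import yaml

with open("configs/icnet.yaml") as f:      # assumed config file
    cfg = yaml.safe_load(f)
# Minimal layout the snippets expect:
#   train: {cityscapes_root: /path/to/cityscapes}
#   test:  {ckpt_path: /path/to/checkpoint.pth}
evaluator = Evaluator(cfg)                 # Evaluator as defined in Example #3
evaluator.eval()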
Example #2
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # get valid dataset images and targets
        self.loader = SUNRGBDLoader()

        # create network
        self.model = ICNet(nclass=19, backbone='resnet50').to(self.device)

        # load ckpt
        pretrained_net = torch.load(cfg["test"]["ckpt_path"])
        self.model.load_state_dict(pretrained_net)

        # evaluation metrics
        self.metric = SegmentationMetric(19)
Example #3
class Evaluator(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # get valid dataset images and targets
        self.image_paths, self.mask_paths = _get_city_pairs(
            cfg["train"]["cityscapes_root"], "val")

        # create network
        self.model = ICNet(nclass=19, backbone='resnet50').to(self.device)

        # load ckpt
        pretrained_net = torch.load(cfg["test"]["ckpt_path"])
        self.model.load_state_dict(pretrained_net)

        # evaluation metrics
        self.metric = SegmentationMetric(19)

    def eval(self):
        self.metric.reset()
        self.model.eval()
        model = self.model

        logger.info("Start validation, Total sample: {:d}".format(
            len(self.image_paths)))
        list_time = []
        list_pixAcc = []
        list_mIoU = []

        for i in range(len(self.image_paths)):

            image = Image.open(self.image_paths[i]).convert(
                'RGB')  # image shape: (W,H,3)
            mask = Image.open(self.mask_paths[i])  # mask shape: (W,H)

            image = self._img_transform(image)  # image shape: (3,H,W) [0,1]
            mask = self._mask_transform(mask)  # mask shape: (H,W)

            image = image.to(self.device)
            mask = mask.to(self.device)

            image = torch.unsqueeze(image, 0)  # image shape: (1,3,H,W) [0,1]

            with torch.no_grad():
                start_time = time.time()
                outputs = model(image)
                end_time = time.time()
                step_time = end_time - start_time
            self.metric.update(outputs[0], mask)
            pixAcc, mIoU = self.metric.get()
            list_time.append(step_time)
            list_pixAcc.append(pixAcc)
            list_mIoU.append(mIoU)
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}, time: {:.3f}s"
                .format(i + 1, pixAcc * 100, mIoU * 100, step_time))

            filename = os.path.basename(self.image_paths[i])
            prefix = filename.split('.')[0]

            # save pred
            pred = torch.argmax(outputs[0], 1)
            pred = pred.cpu().data.numpy()
            pred = pred.squeeze(0)
            pred = get_color_pallete(pred, "citys")
            pred.save(
                os.path.join(outdir, prefix + "_mIoU_{:.3f}.png".format(mIoU)))

            # save image
            image = Image.open(self.image_paths[i]).convert(
                'RGB')  # image shape: (W,H,3)
            image.save(os.path.join(outdir, prefix + '_src.png'))

            # save target
            mask = Image.open(self.mask_paths[i])  # mask shape: (W,H)
            mask = self._class_to_index(np.array(mask).astype('int32'))
            mask = get_color_pallete(mask, "citys")
            mask.save(os.path.join(outdir, prefix + '_label.png'))

        average_pixAcc = sum(list_pixAcc) / len(list_pixAcc)
        average_mIoU = sum(list_mIoU) / len(list_mIoU)
        average_time = sum(list_time) / len(list_time)
        self.current_mIoU = average_mIoU
        logger.info(
            "Evaluate: Average mIoU: {:.3f}, Average pixAcc: {:.3f}, Average time: {:.3f}"
            .format(average_mIoU, average_pixAcc, average_time))

    def _img_transform(self, image):
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        image = image_transform(image)
        return image

    def _mask_transform(self, mask):
        mask = self._class_to_index(np.array(mask).astype('int32'))
        return torch.LongTensor(np.array(mask).astype('int32'))

    def _class_to_index(self, mask):
        # assert that every pixel value in the mask is a known label id
        values = np.unique(mask)
        self._key = np.array([
            -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, -1, -1, 2, 3, 4, -1, -1, -1,
            5, -1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, 16, 17, 18
        ])
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
        for value in values:
            assert (value in self._mapping)
        # indices of each mask value within _mapping
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        # look the indices up in _key to obtain the remapped mask
        return self._key[index].reshape(mask.shape)
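In _class_to_index above, _mapping is simply range(-1, 34), so np.digitize(..., right=True) returns raw_id + 1 and the whole method reduces to a table lookup _key[raw_id + 1], remapping raw Cityscapes label ids to the 19 train ids (or -1 for ignored classes). A self-contained sketch of the same mapping on a toy mask; the sample values are chosen only for illustration.

import numpy as np

_key = np.array([
    -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, -1, -1, 2, 3, 4, -1, -1, -1,
    5, -1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, 16, 17, 18
])
_mapping = np.arange(-1, len(_key) - 1, dtype='int32')

def class_to_index(mask):
    # indices into _key; equals raw_id + 1 for every valid Cityscapes id
    index = np.digitize(mask.ravel(), _mapping, right=True)
    return _key[index].reshape(mask.shape)

toy = np.array([[7, 26], [33, 0]], dtype='int32')  # road, car, bicycle, unlabeled
print(class_to_index(toy))
# [[ 0 13]
#  [18 -1]]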
Example #4
    )
elif str(args.model) == 'DeepLabV3Plus':
    model = DeepLabV3Plus(
        backbone="resnet18",
        num_classes=2,
        pretrained_backbone=None
    )
elif str(args.model) == 'UNetPlus':
    model = UNetPlus(
        backbone="resnet18",
        num_classes=2,
    )
else:
    model = ICNet(
        backbone="resnet18",
        num_classes=2,
        pretrained_backbone=None
    )


if args.use_cuda:
    model = model.cuda()
trained_dict = torch.load(args.checkpoint, map_location="cpu")['state_dict']
model.load_state_dict(trained_dict, strict=False)
model.eval()

def path_leaf(path):
    import ntpath
    head, tail = ntpath.split(path)
    return tail or ntpath.basename(head)
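path_leaf above returns the last path component and, unlike a bare basename, still yields a name when the path ends with a separator; a quick illustration with made-up paths.

print(path_leaf("/data/videos/clip.mp4"))  # clip.mp4
print(path_leaf("/data/videos/"))          # videos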
Example #5
# Background
if args.bg is not None:
    BACKGROUND = cv2.imread(args.bg)[..., ::-1]
    BACKGROUND = cv2.resize(BACKGROUND, (W, H), interpolation=cv2.INTER_LINEAR)
    KERNEL_SZ = 25
    SIGMA = 0

# Alpha transparency
else:
    COLOR1 = [255, 0, 0]
    COLOR2 = [0, 0, 255]

#------------------------------------------------------------------------------
#	Create model and load weights
#------------------------------------------------------------------------------
model = ICNet(backbone="resnet18", num_classes=2, pretrained_backbone=None)
if args.use_cuda:
    model = model.cuda()
trained_dict = torch.load(args.checkpoint, map_location="cpu")['state_dict']
model.load_state_dict(trained_dict, strict=False)
model.eval()

#------------------------------------------------------------------------------
#   Predict frames
#------------------------------------------------------------------------------
i = 0
while cap.isOpened():
    # Read frame from camera
    start_time = time()
    ret, frame = cap.read()
    if not ret:
        break  # assumed: stop once the stream yields no frame
Example #6
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.dataparallel = torch.cuda.device_count() > 1

        # dataset and dataloader
        train_dataset = CityscapesDataset(root=cfg["train"]["cityscapes_root"],
                                          split='train',
                                          base_size=cfg["model"]["base_size"],
                                          crop_size=cfg["model"]["crop_size"])
        val_dataset = CityscapesDataset(root=cfg["train"]["cityscapes_root"],
                                        split='val',
                                        base_size=cfg["model"]["base_size"],
                                        crop_size=cfg["model"]["crop_size"])
        self.train_dataloader = data.DataLoader(
            dataset=train_dataset,
            batch_size=cfg["train"]["train_batch_size"],
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=False)
        self.val_dataloader = data.DataLoader(
            dataset=val_dataset,
            batch_size=cfg["train"]["valid_batch_size"],
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            drop_last=False)

        self.iters_per_epoch = len(self.train_dataloader)
        self.max_iters = cfg["train"]["epochs"] * self.iters_per_epoch

        # create network
        self.model = ICNet(nclass=train_dataset.NUM_CLASS,
                           backbone='resnet50').to(self.device)

        # create criterion
        self.criterion = ICNetLoss(ignore_index=train_dataset.IGNORE_INDEX).to(
            self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': cfg["optimizer"]["init_lr"]
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params':
                    getattr(self.model, module).parameters(),
                    'lr':
                    cfg["optimizer"]["init_lr"] * 10
                })
        self.optimizer = torch.optim.SGD(
            params=params_list,
            lr=cfg["optimizer"]["init_lr"],
            momentum=cfg["optimizer"]["momentum"],
            weight_decay=cfg["optimizer"]["weight_decay"])
        # self.optimizer = torch.optim.SGD(params = self.model.parameters(),
        #                                  lr = cfg["optimizer"]["init_lr"],
        #                                  momentum=cfg["optimizer"]["momentum"],
        #                                  weight_decay=cfg["optimizer"]["weight_decay"])

        # lr scheduler
        self.lr_scheduler = IterationPolyLR(self.optimizer,
                                            max_iters=self.max_iters,
                                            power=0.9)
        # dataparallel
        if self.dataparallel:
            self.model = nn.DataParallel(self.model)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.NUM_CLASS)

        self.current_mIoU = 0.0
        self.best_mIoU = 0.0

        self.epochs = cfg["train"]["epochs"]
        self.current_epoch = 0
        self.current_iteration = 0
Example #7
class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.dataparallel = torch.cuda.device_count() > 1

        # dataset and dataloader
        train_dataset = CityscapesDataset(root=cfg["train"]["cityscapes_root"],
                                          split='train',
                                          base_size=cfg["model"]["base_size"],
                                          crop_size=cfg["model"]["crop_size"])
        val_dataset = CityscapesDataset(root=cfg["train"]["cityscapes_root"],
                                        split='val',
                                        base_size=cfg["model"]["base_size"],
                                        crop_size=cfg["model"]["crop_size"])
        self.train_dataloader = data.DataLoader(
            dataset=train_dataset,
            batch_size=cfg["train"]["train_batch_size"],
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=False)
        self.val_dataloader = data.DataLoader(
            dataset=val_dataset,
            batch_size=cfg["train"]["valid_batch_size"],
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            drop_last=False)

        self.iters_per_epoch = len(self.train_dataloader)
        self.max_iters = cfg["train"]["epochs"] * self.iters_per_epoch

        # create network
        self.model = ICNet(nclass=train_dataset.NUM_CLASS,
                           backbone='resnet50').to(self.device)

        # create criterion
        self.criterion = ICNetLoss(ignore_index=train_dataset.IGNORE_INDEX).to(
            self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': cfg["optimizer"]["init_lr"]
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params':
                    getattr(self.model, module).parameters(),
                    'lr':
                    cfg["optimizer"]["init_lr"] * 10
                })
        self.optimizer = torch.optim.SGD(
            params=params_list,
            lr=cfg["optimizer"]["init_lr"],
            momentum=cfg["optimizer"]["momentum"],
            weight_decay=cfg["optimizer"]["weight_decay"])
        # self.optimizer = torch.optim.SGD(params = self.model.parameters(),
        #                                  lr = cfg["optimizer"]["init_lr"],
        #                                  momentum=cfg["optimizer"]["momentum"],
        #                                  weight_decay=cfg["optimizer"]["weight_decay"])

        # lr scheduler
        self.lr_scheduler = IterationPolyLR(self.optimizer,
                                            max_iters=self.max_iters,
                                            power=0.9)
        # dataparallel
        if self.dataparallel:
            self.model = nn.DataParallel(self.model)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.NUM_CLASS)

        self.current_mIoU = 0.0
        self.best_mIoU = 0.0

        self.epochs = cfg["train"]["epochs"]
        self.current_epoch = 0
        self.current_iteration = 0

    def train(self):
        epochs, max_iters = self.epochs, self.max_iters
        log_per_iters = self.cfg["train"]["log_iter"]
        val_per_iters = self.cfg["train"]["val_epoch"] * self.iters_per_epoch

        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()

        for _ in range(self.epochs):
            self.current_epoch += 1
            list_pixAcc = []
            list_mIoU = []
            list_loss = []
            self.metric.reset()
            for i, (images, targets, _) in enumerate(self.train_dataloader):
                self.current_iteration += 1

                self.lr_scheduler.step()

                images = images.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                self.metric.update(outputs[0], targets)
                pixAcc, mIoU = self.metric.get()
                list_pixAcc.append(pixAcc)
                list_mIoU.append(mIoU)
                list_loss.append(loss.item())

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                eta_seconds = (
                    (time.time() - start_time) / self.current_iteration) * (
                        max_iters - self.current_iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if self.current_iteration % log_per_iters == 0:
                    logger.info(
                        "Epochs: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || mIoU: {:.4f} || Cost Time: {} || Estimated Time: {}"
                        .format(
                            self.current_epoch, self.epochs,
                            self.current_iteration,
                            max_iters, self.optimizer.param_groups[0]['lr'],
                            loss.item(), mIoU,
                            str(
                                datetime.timedelta(seconds=int(time.time() -
                                                               start_time))),
                            eta_string))

            average_pixAcc = sum(list_pixAcc) / len(list_pixAcc)
            average_mIoU = sum(list_mIoU) / len(list_mIoU)
            average_loss = sum(list_loss) / len(list_loss)
            logger.info(
                "Epochs: {:d}/{:d}, Average loss: {:.3f}, Average mIoU: {:.3f}, Average pixAcc: {:.3f}"
                .format(self.current_epoch, self.epochs, average_loss,
                        average_mIoU, average_pixAcc))

            if self.current_iteration % val_per_iters == 0:
                self.validation()
                self.model.train()

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self):
        is_best = False
        self.metric.reset()
        if self.dataparallel:
            model = self.model.module
        else:
            model = self.model
        model.eval()
        list_pixAcc = []
        list_mIoU = []
        list_loss = []
        for i, (image, targets, filename) in enumerate(self.val_dataloader):
            image = image.to(self.device)
            targets = targets.to(self.device)

            with torch.no_grad():
                outputs = model(image)
                loss = self.criterion(outputs, targets)
            self.metric.update(outputs[0], targets)
            pixAcc, mIoU = self.metric.get()
            list_pixAcc.append(pixAcc)
            list_mIoU.append(mIoU)
            list_loss.append(loss.item())

        average_pixAcc = sum(list_pixAcc) / len(list_pixAcc)
        average_mIoU = sum(list_mIoU) / len(list_mIoU)
        average_loss = sum(list_loss) / len(list_loss)
        self.current_mIoU = average_mIoU
        logger.info(
            "Validation: Average loss: {:.3f}, Average mIoU: {:.3f}, Average pixAcc: {:.3f}"
            .format(average_loss, average_mIoU, average_pixAcc))

        if self.current_mIoU > self.best_mIoU:
            is_best = True
            self.best_mIoU = self.current_mIoU
        if is_best:
            save_checkpoint(self.model, self.cfg, self.current_epoch, is_best,
                            self.current_mIoU, self.dataparallel)
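A minimal sketch of launching the Trainer from Example #7; the YAML path is an assumption, while the keys listed are exactly the ones the class reads from cfg.

import yaml

# Config must provide: train.{cityscapes_root, epochs, train_batch_size,
# valid_batch_size, log_iter, val_epoch}, model.{base_size, crop_size},
# optimizer.{init_lr, momentum, weight_decay}
with open("configs/icnet_train.yaml") as f:  # assumed path
    cfg = yaml.safe_load(f)

trainer = Trainer(cfg)
trainer.train()  # validates every cfg["train"]["val_epoch"] epochs and keeps the best-mIoU checkpoint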
Example #8
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.dataparallel = torch.cuda.device_count() > 1

        # dataset and dataloader
        # train_dataset = CityscapesDataset(root = cfg["train"]["cityscapes_root"],
        #                                   split='train',
        #                                   base_size=cfg["model"]["base_size"],
        #                                   crop_size=cfg["model"]["crop_size"])
        # val_dataset = CityscapesDataset(root = cfg["train"]["cityscapes_root"],
        #                                 split='val',
        #                                 base_size=cfg["model"]["base_size"],
        #                                 crop_size=cfg["model"]["crop_size"])
        train_dataset = SUNRGBDLoader(root=cfg["train"]["data_path"],
                                      split="training",
                                      is_transform=True,
                                      img_size=(cfg['train']['img_rows'],
                                                cfg['train']['img_cols']),
                                      img_norm=True)
        val_dataset = SUNRGBDLoader(root=cfg["train"]["data_path"],
                                    split="val",
                                    is_transform=True,
                                    img_size=(cfg['train']['img_rows'],
                                              cfg['train']['img_cols']),
                                    img_norm=True)
        self.train_dataloader = data.DataLoader(
            dataset=train_dataset,
            batch_size=cfg["train"]["train_batch_size"],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=False)
        self.val_dataloader = data.DataLoader(
            dataset=val_dataset,
            batch_size=cfg["train"]["valid_batch_size"],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False)

        self.iters_per_epoch = len(self.train_dataloader)
        self.max_iters = cfg["train"]["epochs"] * self.iters_per_epoch

        # create network
        self.model = ICNet(nclass=train_dataset.n_classes,
                           backbone='resnet50').to(self.device)

        # create criterion
        # self.criterion = ICNetLoss(ignore_index=train_dataset.IGNORE_INDEX).to(self.device)
        self.criterion = ICNetLoss(ignore_index=-1).to(self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': cfg["optimizer"]["init_lr"]
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params':
                    getattr(self.model, module).parameters(),
                    'lr':
                    cfg["optimizer"]["init_lr"] * 10
                })
        self.optimizer = torch.optim.SGD(
            params=params_list,
            lr=cfg["optimizer"]["init_lr"],
            momentum=cfg["optimizer"]["momentum"],
            weight_decay=cfg["optimizer"]["weight_decay"])
        # self.optimizer = torch.optim.SGD(params = self.model.parameters(),
        #                                  lr = cfg["optimizer"]["init_lr"],
        #                                  momentum=cfg["optimizer"]["momentum"],
        #                                  weight_decay=cfg["optimizer"]["weight_decay"])

        # lr scheduler
        # self.lr_scheduler = IterationPolyLR(self.optimizer, max_iters=self.max_iters, power=0.9)
        self.lr_scheduler = PloyStepLR(self.optimizer, milestone=3500)
        # self.lr_scheduler = ConstantLR(self.optimizer)

        # dataparallel
        if self.dataparallel:
            self.model = nn.DataParallel(self.model)

        # evaluation metrics
        self.metric = runningScore(train_dataset.n_classes)

        self.current_mIoU = 0.0
        self.best_mIoU = 0.0

        self.epochs = cfg["train"]["epochs"]
        self.current_epoch = 0
        self.current_iteration = 0

        if cfg["train"]["resume"] is not None:
            if os.path.isfile(cfg["train"]["resume"]):
                logger.info(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["train"]["resume"]))
                checkpoint = torch.load(cfg["train"]["resume"])
                self.model.load_state_dict(checkpoint["model_state"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.lr_scheduler.load_state_dict(
                    checkpoint["scheduler_state"])
                self.current_epoch = checkpoint["epoch"]
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["train"]["resume"], checkpoint["epoch"]))
            else:
                logger.info("No checkpoint found at '{}'".format(
                    cfg["train"]["resume"]))
Example #9
class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.dataparallel = torch.cuda.device_count() > 1

        # dataset and dataloader
        # train_dataset = CityscapesDataset(root = cfg["train"]["cityscapes_root"],
        #                                   split='train',
        #                                   base_size=cfg["model"]["base_size"],
        #                                   crop_size=cfg["model"]["crop_size"])
        # val_dataset = CityscapesDataset(root = cfg["train"]["cityscapes_root"],
        #                                 split='val',
        #                                 base_size=cfg["model"]["base_size"],
        #                                 crop_size=cfg["model"]["crop_size"])
        train_dataset = SUNRGBDLoader(root=cfg["train"]["data_path"],
                                      split="training",
                                      is_transform=True,
                                      img_size=(cfg['train']['img_rows'],
                                                cfg['train']['img_cols']),
                                      img_norm=True)
        val_dataset = SUNRGBDLoader(root=cfg["train"]["data_path"],
                                    split="val",
                                    is_transform=True,
                                    img_size=(cfg['train']['img_rows'],
                                              cfg['train']['img_cols']),
                                    img_norm=True)
        self.train_dataloader = data.DataLoader(
            dataset=train_dataset,
            batch_size=cfg["train"]["train_batch_size"],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=False)
        self.val_dataloader = data.DataLoader(
            dataset=val_dataset,
            batch_size=cfg["train"]["valid_batch_size"],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False)

        self.iters_per_epoch = len(self.train_dataloader)
        self.max_iters = cfg["train"]["epochs"] * self.iters_per_epoch

        # create network
        self.model = ICNet(nclass=train_dataset.n_classes,
                           backbone='resnet50').to(self.device)

        # create criterion
        # self.criterion = ICNetLoss(ignore_index=train_dataset.IGNORE_INDEX).to(self.device)
        self.criterion = ICNetLoss(ignore_index=-1).to(self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': cfg["optimizer"]["init_lr"]
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params':
                    getattr(self.model, module).parameters(),
                    'lr':
                    cfg["optimizer"]["init_lr"] * 10
                })
        self.optimizer = torch.optim.SGD(
            params=params_list,
            lr=cfg["optimizer"]["init_lr"],
            momentum=cfg["optimizer"]["momentum"],
            weight_decay=cfg["optimizer"]["weight_decay"])
        # self.optimizer = torch.optim.SGD(params = self.model.parameters(),
        #                                  lr = cfg["optimizer"]["init_lr"],
        #                                  momentum=cfg["optimizer"]["momentum"],
        #                                  weight_decay=cfg["optimizer"]["weight_decay"])

        # lr scheduler
        # self.lr_scheduler = IterationPolyLR(self.optimizer, max_iters=self.max_iters, power=0.9)
        self.lr_scheduler = PloyStepLR(self.optimizer, milestone=3500)
        # self.lr_scheduler = ConstantLR(self.optimizer)

        # dataparallel
        if self.dataparallel:
            self.model = nn.DataParallel(self.model)

        # evaluation metrics
        self.metric = runningScore(train_dataset.n_classes)

        self.current_mIoU = 0.0
        self.best_mIoU = 0.0

        self.epochs = cfg["train"]["epochs"]
        self.current_epoch = 0
        self.current_iteration = 0

        if cfg["train"]["resume"] is not None:
            if os.path.isfile(cfg["train"]["resume"]):
                logger.info(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["train"]["resume"]))
                checkpoint = torch.load(cfg["train"]["resume"])
                self.model.load_state_dict(checkpoint["model_state"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.lr_scheduler.load_state_dict(
                    checkpoint["scheduler_state"])
                self.current_epoch = checkpoint["epoch"]
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["train"]["resume"], checkpoint["epoch"]))
            else:
                logger.info("No checkpoint found at '{}'".format(
                    cfg["train"]["resume"]))

    def train(self):
        epochs, max_iters = self.epochs, self.max_iters
        log_per_iters = self.cfg["train"]["log_iter"]
        val_per_iters = self.cfg["train"]["val_epoch"] * self.iters_per_epoch

        time_meter = averageMeter()
        train_loss_meter = averageMeter()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()

        # for _ in range(self.epochs):
        while self.current_epoch <= self.epochs:
            self.current_epoch += 1
            # self.lr_scheduler.step()
            # for i, (images, targets, _) in enumerate(self.train_dataloader):
            for i, (images, targets) in enumerate(self.train_dataloader):
                self.current_iteration += 1
                start_time = time.time()
                self.lr_scheduler.step()

                images = images.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)
                pred = outputs[0].data.max(1)[1].cpu().numpy()
                gt = targets.data.cpu().numpy()

                self.metric.update(gt, pred)
                train_loss_meter.update(loss.item())

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                time_meter.update(time.time() - start_time)

                if self.current_iteration % log_per_iters == 0:
                    eta_seconds = time_meter.avg * (max_iters -
                                                    self.current_iteration)
                    eta_string = str(
                        datetime.timedelta(seconds=int(eta_seconds)))
                    logger.info(
                        "Epochs: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                        .format(
                            self.current_epoch, self.epochs,
                            self.current_iteration, max_iters,
                            self.optimizer.param_groups[0]['lr'], loss.item(),
                            str(datetime.timedelta(
                                seconds=int(time_meter.val))), eta_string))
                    print(
                        "Epochs: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                        .format(
                            self.current_epoch, self.epochs,
                            self.current_iteration, max_iters,
                            self.optimizer.param_groups[0]['lr'], loss.item(),
                            str(datetime.timedelta(
                                seconds=int(time_meter.val))), eta_string))
                    time_meter.reset()

            writer.add_scalar("loss/train_loss", train_loss_meter.avg,
                              self.current_epoch)
            score, class_iou = self.metric.get_scores()
            for k, v in score.items():
                print(k, v)
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("train_metrics/{}".format(k), v,
                                  self.current_epoch)

            for k, v in class_iou.items():
                logger.info("{}: {}".format(k, v))

            self.metric.reset()
            train_loss_meter.reset()

            if self.current_iteration % val_per_iters == 0:
                self.validation()
                self.model.train()

    def validation(self):
        is_best = False
        if self.dataparallel:
            model = self.model.module
        else:
            model = self.model
        model.eval()
        val_loss_meter = averageMeter()
        # for i, (image, targets, filename) in enumerate(self.val_dataloader):
        for i, (image, targets) in tqdm(enumerate(self.val_dataloader)):
            image = image.to(self.device)
            targets = targets.to(self.device)

            with torch.no_grad():
                outputs = model(image)
                loss = self.criterion(outputs, targets)
                pred = outputs[0].data.max(1)[1].cpu().numpy()
                gt = targets.data.cpu().numpy()

            self.metric.update(gt, pred)
            val_loss_meter.update(loss.item())

        logger.info("epoch %d Loss: %.4f" %
                    (self.current_epoch, val_loss_meter.avg))
        writer.add_scalar("loss/val_loss", val_loss_meter.avg,
                          self.current_epoch)
        score, class_iou = self.metric.get_scores()
        for k, v in score.items():
            print(k, v)
            logger.info("{}: {}".format(k, v))
            writer.add_scalar("val_metrics/{}".format(k), v,
                              self.current_epoch)

        for k, v in class_iou.items():
            logger.info("{}: {}".format(k, v))
            writer.add_scalar("val_metrics/cls_{}".format(k), v,
                              self.current_epoch)
        self.current_mIoU = score["Mean IoU : \t"]
        logger.info(
            "Validation: Average loss: {:.3f}, mIoU: {:.3f}, mean pixAcc: {:.3f}"
            .format(val_loss_meter.avg, self.current_mIoU,
                    score["Mean Acc : \t"]))
        self.metric.reset()
        val_loss_meter.reset()

        if self.current_mIoU > self.best_mIoU:
            is_best = True
            self.best_mIoU = self.current_mIoU
        if is_best:
            self.save_checkpoint()

    def save_checkpoint(self):
        """Save Checkpoint"""
        # directory = os.path.expanduser(cfg["train"]["ckpt_dir"])
        directory = logdir
        if not os.path.exists(directory):
            os.makedirs(directory)
        filename = '{}_{}_{}_{:.3f}.pth'.format(self.cfg["model"]["name"],
                                                self.cfg["model"]["backbone"],
                                                self.current_epoch,
                                                self.current_mIoU)
        filename = os.path.join(directory, filename)
        if self.dataparallel:
            model = self.model.module
        else:
            model = self.model

        state = {
            "epoch": self.current_epoch,
            "model_state": model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.lr_scheduler.state_dict(),
            "best_iou": self.best_mIoU,
        }
        # best_filename = os.path.join(directory, filename)
        torch.save(state, filename)
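save_checkpoint above writes a dict keyed by epoch, model_state, optimizer_state, scheduler_state and best_iou, and the resume branch in __init__ reads the same keys back. A minimal stand-alone sketch of restoring such a file, assuming model, optimizer and lr_scheduler have already been constructed as in the example; the path is made up.

import torch

checkpoint = torch.load("runs/icnet_resnet50_12_0.650.pth", map_location="cpu")  # assumed filename
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
start_epoch = checkpoint["epoch"]
best_iou = checkpoint["best_iou"]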