def __init__(self, args, config, cuda=None):
        """Set up the loss, DeepLab model, SGD optimizer and VOC data loader.

        Args:
            args: parsed command-line options. Read here: ``gpu`` (a string,
                exported to ``CUDA_VISIBLE_DEVICES``, e.g. ``"0,1"``),
                ``data_root_path``, ``loss_weight``, ``dataset``,
                ``output_stride``, ``imagenet_pretrained``, ``bn_momentum``,
                ``freeze_bn`` and ``lr``.
            config: configuration object providing ``epoch_num``,
                ``num_classes``, ``classes_weight`` (directory holding the
                cached class-weight ``.npy`` file), ``momentum`` and
                ``weight_decay``.
            cuda: request GPU execution; honoured only when
                ``torch.cuda.is_available()`` is true.
        """
        self.args = args
        # Restrict the process to the requested GPUs. This only takes effect
        # if CUDA has not been initialised yet -- NOTE(review): confirm no
        # earlier import touches CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.config = config
        self.cuda = bool(cuda) and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        # training-progress bookkeeping
        self.best_MIou = 0
        self.current_epoch = 0
        self.epoch_num = self.config.epoch_num
        self.current_iter = 0

        self.writer = SummaryWriter()

        # path definition (VOC2012 layout under data_root_path)
        self.val_list_filepath = os.path.join(
            self.args.data_root_path, 'VOC2012/ImageSets/Segmentation/val.txt')
        self.gt_filepath = os.path.join(self.args.data_root_path,
                                        'VOC2012/SegmentationClass/')
        self.pre_filepath = os.path.join(self.args.data_root_path,
                                         'VOC2012/JPEGImages/')

        # Metric definition
        self.Eval = Eval(self.config.num_classes)

        # loss definition: optional per-class weights, cached as a .npy file
        # and computed on first use when the cache file is missing.
        if self.args.loss_weight:
            classes_weights_path = os.path.join(
                self.config.classes_weight,
                self.args.dataset + 'classes_weights_log.npy')
            logger.info('class weights file: %s', classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                logger.info('calculating class weights...')
                calculate_weigths_labels(self.config)
            class_weights = np.load(classes_weights_path)
            logger.info('class weights:\n%s', pprint.pformat(class_weights))
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        # Pixels labelled 255 are excluded from the loss (presumably the VOC
        # "void"/boundary label -- confirm against the data loader).
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.config.num_classes,
                             pretrained=self.args.imagenet_pretrained,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)
        # args.gpu is the comma-separated id list exported above. Once
        # CUDA_VISIBLE_DEVICES applies, those GPUs are renumbered 0..n-1, so
        # use one replica per listed id. The previous fixed range(4) referenced
        # non-existent devices whenever fewer than four GPUs were listed.
        num_gpus = len([gpu_id for gpu_id in self.args.gpu.split(',')
                        if gpu_id.strip()])
        self.model = nn.DataParallel(self.model, device_ids=range(num_gpus))
        patch_replication_callback(self.model)
        self.model.to(self.device)

        # Two parameter groups: the "1x" group trains at the base lr and the
        # "10x" group at ten times that (membership decided by get_params).
        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.config.momentum,
            weight_decay=self.config.weight_decay,
        )
        # dataloader
        self.dataloader = VOCDataLoader(self.args, self.config)
    def __init__(self, args, cuda=None):
        """Set up the loss, Unet_decoder model, SGD optimizer and VOC loader.

        Args:
            args: parsed command-line options. Read here: ``gpu`` (a string,
                exported to ``CUDA_VISIBLE_DEVICES``, e.g. ``"0,1"``),
                ``run_name``, ``num_classes``, ``loss_weight_file``,
                ``loss_weights_dir``, ``output_stride``,
                ``imagenet_pretrained``, ``pretrained_ckpt_file``,
                ``bn_momentum``, ``freeze_bn``, ``lr``, ``momentum``,
                ``weight_decay`` and ``iter_max``.
            cuda: request GPU execution; honoured only when
                ``torch.cuda.is_available()`` is true.
        """
        self.args = args
        # Restrict the process to the requested GPUs. This only takes effect
        # if CUDA has not been initialised yet -- NOTE(review): confirm no
        # earlier import touches CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.cuda = bool(cuda) and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        # training-progress bookkeeping
        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0

        # set TensorboardX
        self.writer = SummaryWriter(log_dir=self.args.run_name)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition: optional per-class weights, cached as a .npy file
        # and computed on first use when the cache file is missing.
        if self.args.loss_weight_file is not None:
            classes_weights_path = os.path.join(self.args.loss_weights_dir,
                                                self.args.loss_weight_file)
            logger.info('class weights file: %s', classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                logger.info('calculating class weights...')
                calculate_weigths_labels(self.args)
            class_weights = np.load(classes_weights_path)
            logger.info('class weights:\n%s', pprint.pformat(class_weights))
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        # Pixels labelled 255 are excluded from the loss (presumably the VOC
        # "void"/boundary label -- confirm against the data loader).
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model: ImageNet weights are requested only when no checkpoint file
        # is given (presumably that checkpoint is loaded elsewhere -- confirm).
        self.model = Unet_decoder(output_stride=self.args.output_stride,
                                  class_num=self.args.num_classes,
                                  pretrained=self.args.imagenet_pretrained
                                  and self.args.pretrained_ckpt_file is None,
                                  bn_momentum=self.args.bn_momentum,
                                  freeze_bn=self.args.freeze_bn)
        # args.gpu is the comma-separated id list exported above. Once
        # CUDA_VISIBLE_DEVICES applies, those GPUs are renumbered 0..n-1, so
        # use one replica per listed id. Counting entries (not characters, as
        # the old ceil(len(gpu) / 2) did) also handles multi-digit ids such as
        # "10,11".
        num_gpus = len([gpu_id for gpu_id in self.args.gpu.split(',')
                        if gpu_id.strip()])
        self.model = nn.DataParallel(self.model, device_ids=range(num_gpus))
        patch_replication_callback(self.model)
        self.model.to(self.device)

        # Two parameter groups: the "1x" group trains at the base lr and the
        # "10x" group at ten times that (membership decided by get_params).
        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )
        # dataloader
        self.dataloader = VOCDataLoader(self.args)
        # Epochs needed to reach iter_max iterations, assuming
        # train_iterations is the per-epoch iteration count (confirm).
        self.epoch_num = ceil(self.args.iter_max /
                              self.dataloader.train_iterations)