def __init__(self, args, config, cuda=None):
    """Initialise a VOC2012 DeepLab trainer (stray copy of ``Trainer.__init__``).

    Sets up the device, TensorBoard writer, dataset paths, evaluation metric,
    (optionally class-weighted) cross-entropy loss, the DataParallel-wrapped
    DeepLab model, an SGD optimizer with 1x / 10x learning-rate groups and the
    VOC data loader.

    :param args: parsed command-line arguments (``gpu``, ``data_root_path``,
        ``loss_weight``, ``dataset``, ``output_stride``,
        ``imagenet_pretrained``, ``bn_momentum``, ``freeze_bn``, ``lr`` ...).
    :param config: experiment configuration (``epoch_num``, ``num_classes``,
        ``classes_weight``, ``momentum``, ``weight_decay`` ...).
    :param cuda: request GPU execution; honoured only if CUDA is available.
    """
    self.args = args
    # Restrict the GPUs visible to this process (presumably set before any
    # CUDA initialisation so that the restriction takes effect).
    os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
    self.config = config
    self.cuda = cuda and torch.cuda.is_available()
    self.device = torch.device('cuda' if self.cuda else 'cpu')

    self.best_MIou = 0
    self.current_epoch = 0
    self.epoch_num = self.config.epoch_num
    self.current_iter = 0

    self.writer = SummaryWriter()

    # path definition
    self.val_list_filepath = os.path.join(
        args.data_root_path, 'VOC2012/ImageSets/Segmentation/val.txt')
    self.gt_filepath = os.path.join(args.data_root_path,
                                    'VOC2012/SegmentationClass/')
    self.pre_filepath = os.path.join(args.data_root_path,
                                     'VOC2012/JPEGImages/')

    # Metric definition
    self.Eval = Eval(self.config.num_classes)

    # Loss definition: optionally weight each class with weights loaded from
    # a .npy file, computing them first when the file is missing.
    if args.loss_weight:
        classes_weights_path = os.path.join(
            self.config.classes_weight,
            self.args.dataset + 'classes_weights_log.npy')
        print(classes_weights_path)
        if not os.path.isfile(classes_weights_path):
            logger.info('calculating class weights...')
            # presumably writes classes_weights_path for the np.load below
            calculate_weigths_labels(self.config)
        class_weights = np.load(classes_weights_path)
        pprint.pprint(class_weights)
        weight = torch.from_numpy(class_weights.astype(np.float32))
        logger.info('loading class weights successfully!')
    else:
        weight = None

    # Pixels labelled 255 are excluded from the loss.
    self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    self.loss.to(self.device)

    # model
    self.model = DeepLab(output_stride=self.args.output_stride,
                         class_num=self.config.num_classes,
                         pretrained=self.args.imagenet_pretrained,
                         bn_momentum=self.args.bn_momentum,
                         freeze_bn=self.args.freeze_bn)
    # Replicate over every GPU exposed through CUDA_VISIBLE_DEVICES. The former
    # hard-coded ``device_ids=range(4)`` cannot work with fewer than four GPUs.
    self.model = nn.DataParallel(self.model)
    patch_replication_callback(self.model)
    self.model.to(self.device)

    # "1x" parameters train at the base learning rate, "10x" ones at ten times it.
    self.optimizer = torch.optim.SGD(
        params=[
            {
                "params": self.get_params(self.model.module, key="1x"),
                "lr": self.args.lr,
            },
            {
                "params": self.get_params(self.model.module, key="10x"),
                "lr": 10 * self.args.lr,
            },
        ],
        momentum=self.config.momentum,
        weight_decay=self.config.weight_decay,
    )
    # dataloader
    self.dataloader = VOCDataLoader(self.args, self.config)
class Trainer():
    """Train and validate DeepLab on PASCAL VOC 2012 semantic segmentation."""

    def __init__(self, args, config, cuda=None):
        """Set up device, logging, loss, model, optimizer and data loader.

        :param args: parsed command-line arguments (``gpu``, ``data_root_path``,
            ``loss_weight``, ``dataset``, ``output_stride``,
            ``imagenet_pretrained``, ``bn_momentum``, ``freeze_bn``, ``lr`` ...).
        :param config: experiment configuration (``epoch_num``, ``num_classes``,
            ``classes_weight``, ``momentum``, ``weight_decay`` ...).
        :param cuda: request GPU execution; honoured only if CUDA is available.
        """
        self.args = args
        # Restrict the GPUs visible to this process (presumably set before any
        # CUDA initialisation so that the restriction takes effect).
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.config = config
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.best_MIou = 0
        self.current_epoch = 0
        self.epoch_num = self.config.epoch_num
        self.current_iter = 0

        self.writer = SummaryWriter()

        # path definition
        self.val_list_filepath = os.path.join(
            args.data_root_path, 'VOC2012/ImageSets/Segmentation/val.txt')
        self.gt_filepath = os.path.join(args.data_root_path,
                                        'VOC2012/SegmentationClass/')
        self.pre_filepath = os.path.join(args.data_root_path,
                                         'VOC2012/JPEGImages/')

        # Metric definition
        self.Eval = Eval(self.config.num_classes)

        # Loss definition: optionally weight each class with weights loaded
        # from a .npy file, computing them first when the file is missing.
        if args.loss_weight:
            classes_weights_path = os.path.join(
                self.config.classes_weight,
                self.args.dataset + 'classes_weights_log.npy')
            print(classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                logger.info('calculating class weights...')
                # presumably writes classes_weights_path for the np.load below
                calculate_weigths_labels(self.config)
            class_weights = np.load(classes_weights_path)
            pprint.pprint(class_weights)
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        # Pixels labelled 255 are excluded from the loss.
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.config.num_classes,
                             pretrained=self.args.imagenet_pretrained,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)
        # Replicate over every GPU exposed through CUDA_VISIBLE_DEVICES. The
        # former hard-coded ``device_ids=range(4)`` cannot work with fewer
        # than four GPUs.
        self.model = nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.model.to(self.device)

        # Backbone conv layers ("1x") train at the base learning rate,
        # encoder/decoder conv layers ("10x") at ten times that rate.
        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.config.momentum,
            weight_decay=self.config.weight_decay,
        )
        # dataloader
        self.dataloader = VOCDataLoader(self.args, self.config)

    def main(self):
        """Log the configuration, optionally resume from a checkpoint, then train."""
        logger.info("Global configuration as follows:")
        pprint.pprint(self.config)
        pprint.pprint(self.args)

        if self.cuda:
            current_device = torch.cuda.current_device()
            logger.info("This model will run on {}".format(
                torch.cuda.get_device_name(current_device)))
        else:
            logger.info("This model will run on CPU")

        if self.args.pretrained:
            self.load_checkpoint(self.args.saved_checkpoint_file)

        self.train()

        self.writer.close()

    def train(self):
        """Train epoch by epoch, validating and checkpointing after each one.

        Stops after ``epoch_num`` epochs or as soon as the iteration budget
        ``args.iter_max`` is exhausted, whichever comes first.
        """
        for epoch in tqdm(range(self.current_epoch, self.epoch_num),
                          desc="Total {} epochs".format(
                              self.config.epoch_num)):
            self.current_epoch = epoch
            self.train_one_epoch()

            PA, MPA, MIoU, FWIoU = self.validate()
            self.writer.add_scalar('PA', PA, self.current_epoch)
            self.writer.add_scalar('MPA', MPA, self.current_epoch)
            self.writer.add_scalar('MIoU', MIoU, self.current_epoch)
            self.writer.add_scalar('FWIoU', FWIoU, self.current_epoch)

            is_best = MIoU > self.best_MIou
            if is_best:
                self.best_MIou = MIoU
            self.save_checkpoint(is_best, self.args.store_checkpoint_name)

            # Any further epoch would break on its first batch without
            # training, so stop once the iteration budget is spent.
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration budget {} exhausted, stopping.".format(
                    self.args.iter_max))
                break

    def train_one_epoch(self):
        """Train for one pass over ``train_loader`` (or until ``args.iter_max``).

        Applies the poly learning-rate schedule each iteration, logs learning
        rates and losses to TensorBoard and accumulates training metrics.

        :raises ValueError: if the loss becomes NaN (before the weights are
            updated with it).
        """
        tqdm_epoch = tqdm(self.dataloader.train_loader,
                          total=self.dataloader.train_iterations,
                          desc="Train Epoch-{}-".format(self.current_epoch +
                                                        1))
        logger.info("Training one epoch...")
        self.Eval.reset()

        train_loss = []
        # Training mode enables dropout and batch-norm statistic updates.
        self.model.train()

        for batch_idx, (x, y, _) in enumerate(tqdm_epoch):
            self.poly_lr_scheduler(
                optimizer=self.optimizer,
                init_lr=self.args.lr,
                iter=self.current_iter,
                max_iter=self.args.iter_max,
                power=self.config.poly_power,
            )
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration arrive {}!".format(self.args.iter_max))
                break
            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]["lr"],
                                   self.current_iter)
            self.writer.add_scalar('learning_rate_10x',
                                   self.optimizer.param_groups[1]["lr"],
                                   self.current_iter)

            # Always cast labels to int64: CrossEntropyLoss requires long
            # targets on CPU as well as on GPU.
            x = x.to(self.device)
            y = y.to(device=self.device, dtype=torch.long)

            self.optimizer.zero_grad()
            pred = self.model(x)
            y = torch.squeeze(y, 1)
            cur_loss = self.loss(pred, y)
            loss_value = cur_loss.item()
            # Fail before back-propagating so a NaN never reaches the weights.
            if np.isnan(loss_value):
                raise ValueError('Loss is nan during training...')

            cur_loss.backward()
            self.optimizer.step()

            train_loss.append(loss_value)
            if batch_idx % self.config.batch_save == 0:
                logger.info("The train loss of epoch{}-batch-{}:{}".format(
                    self.current_epoch, batch_idx, loss_value))
            self.current_iter += 1

            argpred = np.argmax(pred.detach().cpu().numpy(), axis=1)
            self.Eval.add_batch(y.cpu().numpy(), argpred)

        tqdm_epoch.close()

        if not train_loss:
            # No batch was processed (iteration budget already spent), so
            # there is nothing to average or evaluate.
            logger.info("No training batches were run in epoch {}.".format(
                self.current_epoch))
            return

        PA = self.Eval.Pixel_Accuracy()
        MPA = self.Eval.Mean_Pixel_Accuracy()
        MIoU = self.Eval.Mean_Intersection_over_Union()
        FWIoU = self.Eval.Frequency_Weighted_Intersection_over_Union()

        logger.info(
            'Epoch:{}, train PA1:{}, MPA1:{}, MIoU1:{}, FWIoU1:{}'.format(
                self.current_epoch, PA, MPA, MIoU, FWIoU))

        tr_loss = sum(train_loss) / len(train_loss)
        self.writer.add_scalar('train_loss', tr_loss, self.current_epoch)
        tqdm.write("The average loss of train epoch-{}-:{}".format(
            self.current_epoch, tr_loss))

    def validate(self):
        """Evaluate on ``valid_loader``.

        :returns: ``(PA, MPA, MIoU, FWIoU)`` computed by ``self.Eval``.
        :raises ValueError: if the validation loss becomes NaN.
        """
        logger.info('validating one epoch...')
        self.Eval.reset()
        with torch.no_grad():
            tqdm_batch = tqdm(self.dataloader.valid_loader,
                              total=self.dataloader.valid_iterations,
                              desc="Val Epoch-{}-".format(self.current_epoch +
                                                          1))
            val_loss = []
            self.model.eval()

            for x, y, _ in tqdm_batch:
                # Same device/dtype handling as training: long targets always.
                x = x.to(self.device)
                y = y.to(device=self.device, dtype=torch.long)

                pred = self.model(x)
                y = torch.squeeze(y, 1)

                cur_loss = self.loss(pred, y)
                loss_value = cur_loss.item()
                if np.isnan(loss_value):
                    raise ValueError('Loss is nan during validating...')
                val_loss.append(loss_value)

                argpred = np.argmax(pred.cpu().numpy(), axis=1)
                self.Eval.add_batch(y.cpu().numpy(), argpred)

            PA = self.Eval.Pixel_Accuracy()
            MPA = self.Eval.Mean_Pixel_Accuracy()
            MIoU = self.Eval.Mean_Intersection_over_Union()
            FWIoU = self.Eval.Frequency_Weighted_Intersection_over_Union()

            logger.info(
                'Epoch:{}, validation PA1:{}, MPA1:{}, MIoU1:{}, FWIoU1:{}'.
                format(self.current_epoch, PA, MPA, MIoU, FWIoU))
            if val_loss:
                v_loss = sum(val_loss) / len(val_loss)
                logger.info("The average loss of val loss:{}".format(v_loss))
                self.writer.add_scalar('val_loss', v_loss, self.current_epoch)

            tqdm_batch.close()

        return PA, MPA, MIoU, FWIoU

    def save_checkpoint(self, is_best, filename=None):
        """Save a checkpoint to ``checkpoint_dir/filename`` when ``is_best``.

        :param is_best: whether the current validation MIoU is a new best.
        :param filename: checkpoint file name inside ``args.checkpoint_dir``.
        """
        filename = os.path.join(self.args.checkpoint_dir, filename)
        if not is_best:
            logger.info("=> The MIoU of val does't improve.")
            return

        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_MIou': self.best_MIou
        }
        logger.info("=>saving a new best checkpoint...")
        torch.save(state, filename)

    def load_checkpoint(self, filename):
        """Restore model/optimizer state and best MIoU from ``checkpoint_dir/filename``.

        Epoch and iteration counters are not restored. A missing file is
        logged and ignored so training starts from scratch.
        """
        filename = os.path.join(self.args.checkpoint_dir, filename)
        try:
            logger.info("Loading checkpoint '{}'".format(filename))
            # map_location allows a checkpoint saved on GPU to load on CPU.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(filename, map_location=self.device)
        except OSError:
            logger.info("No checkpoint exists from '{}'. Skipping...".format(
                self.args.checkpoint_dir))
            logger.info("**First time to train**")
            return

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_MIou = checkpoint['best_MIou']

        logger.info(
            "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {},MIoU:{})\n"
            .format(self.args.checkpoint_dir, checkpoint['epoch'],
                    checkpoint['iteration'], checkpoint['best_MIou']))

    def get_params(self, model, key):
        """Yield the Conv2d parameters of one learning-rate group.

        ``"1x"`` selects Conv2d modules whose qualified name contains
        ``"Resnet101"``; ``"10x"`` those whose name contains ``"encoder"`` or
        ``"decoder"``. Other module types (e.g. batch norm) are not yielded.
        """
        for name, module in model.named_modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if key == "1x" and "Resnet101" in name:
                yield from module.parameters()
            if key == "10x" and ("encoder" in name or "decoder" in name):
                yield from module.parameters()

    def poly_lr_scheduler(self, optimizer, init_lr, iter, max_iter, power):
        """Apply the poly schedule to both parameter groups.

        lr = init_lr * (1 - iter / max_iter) ** power for group 0 and ten times
        that for group 1. Progress is clamped at 1 so an ``iter`` past
        ``max_iter`` gives lr 0 instead of a complex number.
        """
        progress = min(float(iter) / max_iter, 1.0)
        new_lr = init_lr * (1 - progress)**power
        optimizer.param_groups[0]["lr"] = new_lr
        optimizer.param_groups[1]["lr"] = 10 * new_lr
Esempio n. 3
0
class Trainer():
    def __init__(self, args, cuda=None):
        """Build the ISIC DeepLab trainer from parsed command-line ``args``.

        Creates the TensorBoard writer, the metric, the loss (Tanimoto or
        BCE-with-logits), the DeepLab model (wrapped in DataParallel when more
        than one GPU is present), an SGD optimizer with 1x / 10x learning-rate
        groups and the ISIC data loader. It then picks the train / validate
        routines that match the input layout.

        :param args: parsed command-line arguments.
        :param cuda: request GPU execution; honoured only if CUDA is available.
        """
        self.args = args
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        # progress / best-score bookkeeping
        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.batch_idx = 0

        self.writer = SummaryWriter()
        self.Eval = Eval(self.args.num_classes)

        use_tanimoto = self.args.loss == 'tanimoto'
        self.loss = tanimoto_loss() if use_tanimoto else nn.BCEWithLogitsLoss()
        self.loss.to(self.device)

        # ImageNet weights are used only when no checkpoint will be loaded
        # later (main() loads pretrained_ckpt_file when it is set).
        start_from_imagenet = (self.args.imagenet_pretrained
                               and self.args.pretrained_ckpt_file is None)
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.args.num_classes,
                             num_input_channel=self.args.input_channels,
                             pretrained=start_from_imagenet,
                             bn_eps=self.args.bn_eps,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)

        multi_gpu = torch.cuda.device_count() > 1
        if multi_gpu:
            self.model = nn.DataParallel(self.model)
            patch_replication_callback(self.model)
        # self.m is always the bare network (used to build parameter groups).
        self.m = self.model.module if multi_gpu else self.model
        self.model.to(self.device)

        base_lr = self.args.lr
        param_groups = [
            {"params": self.get_params(self.m, key="1x"), "lr": base_lr},
            {"params": self.get_params(self.m, key="10x"), "lr": 10 * base_lr},
        ]
        self.optimizer = torch.optim.SGD(params=param_groups,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)

        self.dataloader = ISICDataLoader(self.args)
        self.epoch_num = ceil(self.args.iter_max /
                              self.dataloader.train_iterations)

        if self.args.input_channels != 3:
            # 4-channel input: image plus an extra target map.
            self.train_func = self.train_4ch
            self.validate_func = self.validate_4ch
        else:
            self.train_func = self.train_3ch
            if args.using_bb == 'none':
                self.validate_func = self.validate_3ch
            elif self.args.store_result:
                self.validate_func = self.validate_crop_store_result
            else:
                self.validate_func = self.validate_crop

        if self.args.store_result:
            self.validate_one_epoch = self.validate_one_epoch_store_result

    def main(self):
        """Log the run configuration, optionally load a checkpoint, then run
        validation only (``args.validate``) or training, and close the writer."""
        logger.info("Global configuration as follows:")
        for name, value in vars(self.args).items():
            logger.info("{:16} {}".format(name, value))

        if not self.cuda:
            logger.info("This model will run on CPU")
        else:
            device_name = torch.cuda.get_device_name(
                torch.cuda.current_device())
            logger.info("This model will run on {}".format(device_name))

        ckpt_file = self.args.pretrained_ckpt_file
        if ckpt_file is not None:
            self.load_checkpoint(ckpt_file)

        run = self.validate if self.args.validate else self.train
        run()

        self.writer.close()

    def train(self):
        """Train for ``epoch_num`` epochs, saving snapshots along the way.

        Every 10th epoch a snapshot is written to ``<train_id>_epoca_<epoch>``.
        When ``args.validation`` is set the model is validated after each
        epoch, and ``save_checkpoint`` receives whether the thresholded MIoU
        improved. A final snapshot ``<train_id>final.pth`` is always written.
        """
        # NOTE(review): ``train_id`` is not defined in this class; it is
        # presumably a module-level global set by the launching script.
        for epoch in tqdm(range(self.current_epoch, self.epoch_num),
                          desc="Total {} epochs".format(self.epoch_num)):
            self.current_epoch = epoch
            epoch_bar = tqdm(self.dataloader.train_loader,
                             total=self.dataloader.train_iterations,
                             desc="Train Epoch-{}-".format(
                                 self.current_epoch + 1))
            logger.info("Training one epoch...")
            self.Eval.reset()

            self.train_loss = []
            self.model.train()
            if self.args.freeze_bn:
                # keep synchronized BN layers in eval mode while training
                bn_layers = (layer for layer in self.model.modules()
                             if isinstance(layer, SynchronizedBatchNorm2d))
                for layer in bn_layers:
                    layer.eval()

            self.train_func(epoch_bar)

            _, train_miou = self.Eval.Mean_Intersection_over_Union()
            logger.info('Epoch:{}, train MIoU1:{}'.format(
                self.current_epoch, train_miou))

            mean_loss = sum(self.train_loss) / len(self.train_loss)
            self.writer.add_scalar('train_loss', mean_loss, self.current_epoch)
            tqdm.write("The average loss of train epoch-{}-:{}".format(
                self.current_epoch, mean_loss))
            epoch_bar.close()

            if self.current_epoch % 10 == 0:
                snapshot = dict(state_dict=self.model.state_dict(),
                                optimizer=self.optimizer.state_dict(),
                                best_MIou=self.current_MIoU)
                torch.save(snapshot,
                           train_id + '_epoca_' + str(self.current_epoch))

            if not self.args.validation:
                continue

            _, val_miou = self.validate()
            self.writer.add_scalar('MIoU', val_miou, self.current_epoch)
            self.current_MIoU = val_miou
            improved = val_miou > self.best_MIou
            if improved:
                self.best_MIou = val_miou
            self.save_checkpoint(improved, train_id + 'best.pth')

        final_state = dict(state_dict=self.model.state_dict(),
                           optimizer=self.optimizer.state_dict(),
                           best_MIou=self.current_MIoU)
        logger.info("=>saving the final checkpoint...")
        torch.save(final_state, train_id + 'final.pth')

    def train_3ch(self, tqdm_epoch):
        for x, y in tqdm_epoch:
            self.poly_lr_scheduler(
                optimizer=self.optimizer,
                init_lr=self.args.lr,
                iter=self.current_iter,
                max_iter=self.args.iter_max,
                power=self.args.poly_power,
            )
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration arrive {}!".format(self.args.iter_max))
                break
            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]["lr"],
                                   self.current_iter)
            self.writer.add_scalar('learning_rate_10x',
                                   self.optimizer.param_groups[1]["lr"],
                                   self.current_iter)
            self.train_one_epoch(x, y)

    def train_4ch(self, tqdm_epoch):
        for x, y, target in tqdm_epoch:
            self.poly_lr_scheduler(
                optimizer=self.optimizer,
                init_lr=self.args.lr,
                iter=self.current_iter,
                max_iter=self.args.iter_max,
                power=self.args.poly_power,
            )
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration arrive {}!".format(self.args.iter_max))
                break
            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]["lr"],
                                   self.current_iter)
            self.writer.add_scalar('learning_rate_10x',
                                   self.optimizer.param_groups[1]["lr"],
                                   self.current_iter)

            target = target.float()
            x = torch.cat((x, target), dim=1)
            self.train_one_epoch(x, y)

    def train_one_epoch(self, x, y):
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        y[y > 0] = 1.
        self.optimizer.zero_grad()

        # model
        pred = self.model(x)

        y = torch.squeeze(y, 1)
        if self.args.num_classes == 1:
            y = y.to(device=self.device, dtype=torch.float)
            pred = pred.squeeze()
        # loss
        cur_loss = self.loss(pred, y)

        # optimizer
        cur_loss.backward()
        self.optimizer.step()

        self.train_loss.append(cur_loss.item())

        if self.batch_idx % 50 == 0:
            logger.info("The train loss of epoch{}-batch-{}:{}".format(
                self.current_epoch, self.batch_idx, cur_loss.item()))
        self.batch_idx += 1

        self.current_iter += 1

        # print(cur_loss)
        if np.isnan(float(cur_loss.item())):
            raise ValueError('Loss is nan during training...')

    def validate(self):
        """Run the configured validation routine over ``valid_loader``.

        :returns: ``(MIoU, MIoU_thresh)`` as reported by
            ``Eval.Mean_Intersection_over_Union()``.
        """
        logger.info('validating one epoch...')
        self.Eval.reset()
        self.iter = 0

        with torch.no_grad():
            tqdm_batch = tqdm(self.dataloader.valid_loader,
                              total=self.dataloader.valid_iterations,
                              desc="Val Epoch-{}-".format(self.current_epoch +
                                                          1))
            self.val_loss = []
            self.model.eval()
            self.validate_func(tqdm_batch)

            MIoU, MIoU_thresh = self.Eval.Mean_Intersection_over_Union()

            logger.info('validation MIoU1:{}'.format(MIoU))
            print('Miou: ' + str(MIoU) + ' MIoU_thresh: ' + str(MIoU_thresh))

            # The store-result routines only write prediction images and never
            # record a loss, so val_loss may be empty; avoid dividing by zero.
            if self.val_loss:
                v_loss = sum(self.val_loss) / len(self.val_loss)
                self.writer.add_scalar('val_loss', v_loss, self.current_epoch)

            tqdm_batch.close()

        return MIoU, MIoU_thresh

    def validate_3ch(self, tqdm_batch):
        for x, y, w, h, name in tqdm_batch:
            self.validate_one_epoch(x, y, w, h, name)

    def validate_4ch(self, tqdm_batch):
        for x, y, target, w, h, name in tqdm_batch:
            target = target.float()
            x = torch.cat((x, target), dim=1)
            self.validate_one_epoch(x, y, w, h, name)

    def validate_crop(self, tqdm_batch):
        for i, (x, y, left, top, right, bottom, w, h,
                name) in enumerate(tqdm_batch):
            self.validate_one_epoch(x, y, w, h, name, left, top, right, bottom)

    def validate_crop_store_result(self, tqdm_batch):
        """Predict on cropped batches and save full-size probability maps.

        Each prediction is resized to its crop box and zero-padded back to the
        original ``w`` x ``h`` canvas. When ``args.store_result`` is set, it is
        written as an 8-bit PNG ``ISIC_<name>.png`` under
        ``args.result_filepath``. No loss or metric is recorded.
        """
        for x, y, left, top, right, bottom, w, h, name in tqdm_batch:
            if self.cuda:
                x, y = x.to(self.device), y.to(device=self.device,
                                               dtype=torch.long)

            pred = self.model(x)
            if self.args.loss == 'tanimoto':
                # min-max normalise the whole batch into [0, 1]
                pred = (pred - pred.min()) / (pred.max() - pred.min())
            else:
                pred = nn.Sigmoid()(pred)

            # Assumes a single-channel output (batch, 1, H, W). Drop only the
            # channel axis so a batch of one still iterates per image.
            pred = pred.squeeze(1).detach().cpu().numpy()
            for i, single_pred in enumerate(pred):
                crop_w = int(right[i] - left[i])
                crop_h = int(bottom[i] - top[i])
                img = np.array(Image.fromarray(single_pred).resize(
                    (crop_w, crop_h)))
                img_border = cv.copyMakeBorder(img,
                                               int(top[i]),
                                               int(h[i] - bottom[i]),
                                               int(left[i]),
                                               int(w[i] - right[i]),
                                               cv.BORDER_CONSTANT,
                                               value=[0, 0, 0])

                if self.args.store_result:
                    img_border *= 255
                    pil = Image.fromarray(img_border.astype('uint8'))
                    # use the trainer's own args, not a module-level global
                    pil.save(self.args.result_filepath +
                             'ISIC_{}.png'.format(name[i]))

                    self.iter += 1

    def validate_one_epoch_store_result(self, x, y, w, h, name):
        """Predict one batch and save each map resized to ``(w[i], h[i])``.

        Predictions are min-max normalised (Tanimoto loss) or passed through a
        sigmoid. When ``args.store_result`` is set, they are written as 8-bit
        PNGs ``ISIC_<name>.png`` under ``args.result_filepath``. No loss or
        metric is recorded.
        """
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        pred = self.model(x)
        if self.args.loss == 'tanimoto':
            # min-max normalise the whole batch into [0, 1]
            pred = (pred - pred.min()) / (pred.max() - pred.min())
        else:
            pred = nn.Sigmoid()(pred)

        # Assumes a single-channel output (batch, 1, H, W). Drop only the
        # channel axis so a batch of one still iterates per image.
        pred = pred.squeeze(1).detach().cpu().numpy()
        for i, single_pred in enumerate(pred):
            pil = Image.fromarray(single_pred).resize((int(w[i]), int(h[i])))
            img_border = np.array(pil)
            if self.args.store_result:
                img_border *= 255
                pil = Image.fromarray(img_border.astype('uint8'))
                # use the trainer's own args, not a module-level global
                pil.save(self.args.result_filepath +
                         'ISIC_{}.png'.format(name[i]))

                self.iter += 1

    # def validate_crop(self, tqdm_batch):
    #     for i, (x, y, left, top, right, bottom, w, h, name) in enumerate(tqdm_batch):
    #         if self.cuda:
    #             x, y = x.to(self.device), y.to(device=self.device, dtype=torch.long)
    #
    #         pred = self.model(x)
    #         y = torch.squeeze(y, 1)
    #         if self.args.num_classes == 1:
    #             y = y.to(device=self.device, dtype=torch.float)
    #             pred = pred.squeeze()
    #
    #         cur_loss = self.loss(pred, y)
    #         if np.isnan(float(cur_loss.item())):
    #             raise ValueError('Loss is nan during validating...')
    #         self.val_loss.append(cur_loss.item())
    #
    #         pred = pred.data.cpu().numpy()
    #
    #         pred[pred >= 0.5] = 1
    #         pred[pred < 0.5] = 0
    #         print('\n')
    #         for i, single_pred in enumerate(pred):
    #             gt = Image.open(self.args.data_root_path + "ground_truth/ISIC_" + name[i] + "_segmentation.png")
    #             pil = Image.fromarray(single_pred.astype('uint8'))
    #             pil = pil.resize((right[i] - left[i], bottom[i] - top[i]))
    #             img = np.array(pil)
    #             ground_border = np.array(gt)
    #             ground_border[ground_border == 255] = 1
    #             img_border = cv.copyMakeBorder(img, top[i].numpy(), h[i].numpy() - bottom[i].numpy(),
    #                                            left[i].numpy(),
    #                                            w[i].numpy() - right[i].numpy(), cv.BORDER_CONSTANT, value=[0, 0, 0])
    #
    #             iou = self.Eval.iou_numpy(img_border, ground_border)
    #             print(name[i] + ' iou: ' + str(iou))
    #
    #             if self.args.store_result:
    #                 img_border[img_border == 1] = 255
    #                 pil = Image.fromarray(img_border)
    #                 pil.save(args.result_filepath + 'ISIC_{}.png'.format(name[i]))
    #                 # gt.save(args.result_filepath + 'ISIC_ground_{}.png'.format(name[i]))
    #
    #                 self.iter += 1

    def validate_one_epoch(self, x, y, w, h, name, *ltrb):
        """Validate one batch: record the loss and print per-image IoU.

        :param x: input batch.
        :param y: ground-truth batch used for the loss; values > 0 count as
            foreground, as in training.
        :param w: per-image original widths.
        :param h: per-image original heights.
        :param name: per-image ISIC ids, used to locate ground-truth files.
        :param ltrb: optional ``(left, top, right, bottom)`` crop boxes, passed
            only by ``validate_crop``. When given, each prediction is resized
            into its box and zero-padded back to ``w`` x ``h``.
        :raises ValueError: if the validation loss is NaN.
        """
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        pred = self.model(x)
        y = torch.squeeze(y, 1)
        # Binarise like train_one_epoch so the loss compares against {0, 1}.
        y = (y > 0).to(y.dtype)
        if self.args.num_classes == 1:
            y = y.to(device=self.device, dtype=torch.float)
            # Drop only the channel axis so a batch of one keeps its batch axis.
            pred = pred.squeeze(1)

        cur_loss = self.loss(pred, y)
        loss_value = cur_loss.item()
        if np.isnan(loss_value):
            raise ValueError('Loss is nan during validating...')
        self.val_loss.append(loss_value)

        # NOTE(review): with BCEWithLogitsLoss the outputs are logits, so this
        # 0.5 cut-off is not the usual sigmoid(p) >= 0.5; confirm intent.
        pred = (pred.detach().cpu().numpy() >= 0.5).astype('uint8')
        print('\n')
        for i, single_pred in enumerate(pred):
            pil = Image.fromarray(single_pred)

            # Branch on whether crop boxes were actually supplied. The old test
            # on ``args.using_bb`` was true for the string 'none' and then
            # indexed an empty ``ltrb``.
            if ltrb:
                left, top, right, bottom = (int(c[i]) for c in ltrb)
                img = np.array(pil.resize((right - left, bottom - top)))
                img_border = cv.copyMakeBorder(img,
                                               top,
                                               int(h[i]) - bottom,
                                               left,
                                               int(w[i]) - right,
                                               cv.BORDER_CONSTANT,
                                               value=[0, 0, 0])
            else:
                img_border = np.array(pil.resize((int(w[i]), int(h[i]))))

            gt_path = (self.args.data_root_path + "ground_truth/ISIC_" +
                       name[i] + "_segmentation.png")
            with Image.open(gt_path) as gt:
                ground_border = np.array(gt)
            ground_border[ground_border == 255] = 1
            iou = self.Eval.IoU_one_class(img_border, ground_border)

            print(name[i] + ' iou: ' + str(iou))

            if self.args.store_result:
                img_border[img_border == 1] = 255
                pil = Image.fromarray(img_border)
                pil.save(self.args.result_filepath +
                         'ISIC_{}.png'.format(name[i]))

                self.iter += 1

    def save_checkpoint(self, is_best, filename=None):
        """Write a checkpoint to ``<checkpoint_dir>/<filename>`` when improved.

        Nothing is written unless ``is_best`` is true; otherwise the method
        only logs that validation MIoU did not improve. The saved dict holds
        every key ``load_checkpoint`` reads: ``epoch``, ``iteration``,
        ``state_dict``, ``optimizer`` and ``best_MIou``.

        Args:
            is_best: Whether the current model is the best seen so far.
            filename: File name inside ``self.args.checkpoint_dir``. It must
                be a string when ``is_best`` is true.
        """
        if not is_best:
            logger.info("=> The MIoU of val does't improve.")
            return

        # Only build the path and the (potentially large) state dicts when
        # a checkpoint is actually going to be written.
        filename = os.path.join(self.args.checkpoint_dir, filename)
        state = {
            # Progress counters are needed so that resuming can restore them.
            # NOTE(review): the counters are stored as-is; confirm whether
            # the training loop expects the *next* epoch here instead.
            'epoch': self.current_epoch,
            'iteration': self.current_iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_MIou': self.best_MIou
        }
        logger.info("=>saving a new best checkpoint...")
        torch.save(state, filename)

    def load_checkpoint(self, filename):
        """Restore model and optimizer state from ``filename`` if readable.

        If ``torch.load`` raises ``OSError`` (e.g. the file is missing), this
        is logged and the method returns, leaving the fresh state untouched.
        Keys saved from an ``nn.DataParallel`` wrapper (detected through
        ``module.Resnet101.bn1.weight``) lose their ``module.`` prefix before
        loading. Epoch/iteration counters are restored only when
        ``self.args.freeze_bn`` is false, and only if the checkpoint holds
        them.

        Args:
            filename: Path passed to ``torch.load``.
        """
        try:
            logger.info("Loading checkpoint '{}'".format(filename))
            # map_location lets a checkpoint written on GPU be restored on a
            # CPU-only host (and vice versa).
            checkpoint = torch.load(filename, map_location=self.device)
        except OSError:
            logger.info("No checkpoint exists from '{}'. Skipping...".format(
                self.args.checkpoint_dir))
            logger.info("**First time to train**")
            return

        state_dict = checkpoint['state_dict']
        if 'module.Resnet101.bn1.weight' in state_dict:
            # Drop the 7-character "module." prefix added by DataParallel.
            state_dict = collections.OrderedDict(
                (k[7:], v) for k, v in state_dict.items())
        self.model.load_state_dict(state_dict)

        if not self.args.freeze_bn:
            # Some checkpoints are written without these counters; keep the
            # current values instead of failing with KeyError.
            self.current_epoch = checkpoint.get('epoch', self.current_epoch)
            self.current_iter = checkpoint.get('iteration', self.current_iter)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_MIou = checkpoint['best_MIou']

        message = "Checkpoint loaded successfully from '{}', MIoU:{})\n".format(
            self.args.checkpoint_dir, checkpoint['best_MIou'])
        print(message)
        logger.info(message)

    def get_params(self, model, key):
        """Yield the ``nn.Conv2d`` parameters of one learning-rate group.

        ``key == "1x"`` selects convolutions whose qualified module name
        contains ``"Resnet101"``. ``key == "10x"`` selects convolutions whose
        name contains ``"encoder"`` or ``"decoder"``. Any other key yields
        nothing. Non-convolution modules (e.g. batch norm) are never yielded.
        """
        # For Dilated CNN
        if key == "1x":
            name_tags = ("Resnet101",)
        elif key == "10x":
            name_tags = ("encoder", "decoder")
        else:
            return

        for module_name, module in model.named_modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if any(tag in module_name for tag in name_tags):
                yield from module.parameters()

    def poly_lr_scheduler(self, optimizer, init_lr, iter, max_iter, power):
        """Apply the "poly" learning-rate policy to ``optimizer`` in place.

        Sets ``lr = init_lr * (1 - iter / max_iter) ** power`` on param
        group 0 and ``10 * lr`` on param group 1.

        Training progress ``iter / max_iter`` is clamped to at most 1.0.
        Callers may run past ``max_iter`` (e.g. when the epoch count is
        rounded up). Without the clamp, a negative base raised to a
        fractional ``power`` yields a complex number, and an integer power
        can yield a negative rate. Past ``max_iter`` the rate is 0.0.

        Args:
            optimizer: Object exposing ``param_groups``, a list of at least
                two dicts with an ``"lr"`` entry.
            init_lr: Base learning rate at iteration 0.
            iter: Current iteration (name kept for API compatibility even
                though it shadows the builtin).
            max_iter: Iteration at which the rate reaches zero.
            power: Exponent of the polynomial decay.
        """
        progress = min(float(iter) / max_iter, 1.0)
        new_lr = init_lr * (1 - progress) ** power
        optimizer.param_groups[0]["lr"] = new_lr
        optimizer.param_groups[1]["lr"] = 10 * new_lr
Esempio n. 4
0
    def __init__(self, args, cuda=None):
        """Build the ISIC trainer: metric, loss, model, optimizer and data.

        Args:
            args: Parsed options. Read here: ``num_classes``, ``loss``,
                ``output_stride``, ``input_channels``,
                ``imagenet_pretrained``, ``pretrained_ckpt_file``,
                ``bn_eps``, ``bn_momentum``, ``freeze_bn``, ``lr``,
                ``momentum``, ``weight_decay``, ``iter_max``, ``using_bb``
                and ``store_result``.
            cuda: Whether GPU use is requested. It is enabled only if
                ``torch.cuda.is_available()`` also holds.
        """
        self.args = args
        opts = self.args
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        # Progress counters and best-score tracking.
        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.batch_idx = 0

        # TensorboardX writer with its default log directory.
        self.writer = SummaryWriter()

        # Metric definition
        self.Eval = Eval(opts.num_classes)

        # Loss: Tanimoto on request, BCE-with-logits otherwise.
        if opts.loss != 'tanimoto':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            self.loss = tanimoto_loss()
        self.loss.to(self.device)

        # ImageNet initialisation only when no pretrained checkpoint file is
        # configured.
        use_imagenet = (opts.imagenet_pretrained
                        and opts.pretrained_ckpt_file is None)
        self.model = DeepLab(output_stride=opts.output_stride,
                             class_num=opts.num_classes,
                             num_input_channel=opts.input_channels,
                             pretrained=use_imagenet,
                             bn_eps=opts.bn_eps,
                             bn_momentum=opts.bn_momentum,
                             freeze_bn=opts.freeze_bn)

        # Wrap for multi-GPU; self.m always refers to the unwrapped network.
        parallel = torch.cuda.device_count() > 1
        if parallel:
            self.model = nn.DataParallel(self.model)
            patch_replication_callback(self.model)
        self.m = self.model.module if parallel else self.model
        self.model.to(self.device)

        # Two SGD groups: "Resnet101" convs at lr, encoder/decoder convs at
        # 10 * lr (selection done by get_params).
        base_lr = opts.lr
        groups = [
            {"params": self.get_params(self.m, key="1x"), "lr": base_lr},
            {"params": self.get_params(self.m, key="10x"), "lr": 10 * base_lr},
        ]
        self.optimizer = torch.optim.SGD(
            params=groups,
            momentum=opts.momentum,
            weight_decay=opts.weight_decay,
        )

        self.dataloader = ISICDataLoader(opts)
        # Rounded-up epoch count; presumably train_iterations is the number
        # of iterations per epoch (verify against ISICDataLoader).
        self.epoch_num = ceil(opts.iter_max / self.dataloader.train_iterations)

        # Pick train/validate routines for the input layout.
        if opts.input_channels != 3:
            self.train_func = self.train_4ch
            self.validate_func = self.validate_4ch
        else:
            self.train_func = self.train_3ch
            if args.using_bb == 'none':
                self.validate_func = self.validate_3ch
            elif opts.store_result:
                self.validate_func = self.validate_crop_store_result
            else:
                self.validate_func = self.validate_crop

        if opts.store_result:
            self.validate_one_epoch = self.validate_one_epoch_store_result
    def __init__(self, args, cuda=None):
        self.args = args
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.cuda = cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0

        # set TensorboardX
        self.writer = SummaryWriter(log_dir=self.args.run_name)

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        # loss definition
        if self.args.loss_weight_file is not None:
            classes_weights_path = os.path.join(self.args.loss_weights_dir, self.args.loss_weight_file)
            print(classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                logger.info('calculating class weights...')
                calculate_weigths_labels(self.args)
            class_weights = np.load(classes_weights_path)
            pprint.pprint(class_weights)
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.args.num_classes,
                             pretrained=self.args.imagenet_pretrained and self.args.pretrained_ckpt_file==None,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)
        self.model = nn.DataParallel(self.model, device_ids=range(ceil(len(self.args.gpu)/2)))
        patch_replication_callback(self.model)
        self.model.to(self.device)

        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.args.momentum,
            # dampening=self.args.dampening,
            weight_decay=self.args.weight_decay,
            # nesterov=self.args.nesterov
        )
        # dataloader
        self.dataloader = VOCDataLoader(self.args)
        self.epoch_num = ceil(self.args.iter_max / self.dataloader.train_iterations)