Example #1
class Tester(object):
    def __init__(self, args):
        if not os.path.isfile(args.model):
            raise RuntimeError("no checkpoint found at '{}'".fromat(args.model))
        self.args = args
        self.color_map = get_pascal_labels()
        self.test_loader, self.ids, self.nclass = make_data_loader(args)

        #Define model
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False)
        
        self.model = model
        device = torch.device('cpu')
        checkpoint = torch.load(args.model, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.evaluator = Evaluator(self.nclass)

    def save_image(self, array, id, op):
        text = 'gt'
        if op == 0:
            text = 'pred'
        file_name = str(id)+'_'+text+'.png'
        r = array.copy()
        g = array.copy()
        b = array.copy()

        for i in range(self.nclass):
            r[array == i] = self.color_map[i][0]
            g[array == i] = self.color_map[i][1]
            b[array == i] = self.color_map[i][2]
    
        rgb = np.dstack((r, g, b))

        save_img = Image.fromarray(rgb.astype('uint8'))
        save_img.save(self.args.save_path+os.sep+file_name)


    def test(self):
        self.model.eval()
        self.evaluator.reset()
        # tbar = tqdm(self.test_loader, desc='\r')
        for i, sample in enumerate(self.test_loader):
            image, target = sample['image'], sample['label']
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.save_image(pred[0], self.ids[i], 0)
            self.save_image(target[0], self.ids[i], 1)
            self.evaluator.add_batch(target, pred)
    
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        print('Acc:{}, Acc_class:{}'.format(Acc, Acc_class))
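
# Hedged usage sketch for the Tester class above (illustrative, not taken from
# the original script): the argparse flags are assumed to mirror the attributes
# the class reads (model, backbone, out_stride, save_path), and make_data_loader
# will typically require further fields such as dataset and batch size.
def run_test(args):
    os.makedirs(args.save_path, exist_ok=True)  # make sure the output folder exists
    tester = Tester(args)
    tester.test()
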
def validation(epoch, model, args, criterion, nclass, test_tag=False):
    model.eval()

    losses = 0.0

    evaluator = Evaluator(nclass)
    evaluator.reset()
    # test_tag=True evaluates the test split; otherwise the validation split
    if test_tag:
        num_img = args.data_dict['num_test']
    else:
        num_img = args.data_dict['num_valid']
    for i in range(num_img):
        if test_tag:
            inputs = torch.FloatTensor(args.data_dict['test_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['test_mask'][i]).cuda()
        else:
            inputs = torch.FloatTensor(args.data_dict['valid_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['valid_mask'][i]).cuda()

        with torch.no_grad():
            output = model(inputs)
        loss_val = criterion(output, target)
        print('epoch: {0}\t'
              'iter: {1}/{2}\t'
              'loss: {loss:.4f}'.format(epoch + 1,
                                        i + 1,
                                        num_img,
                                        loss=loss_val.item()))
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        evaluator.add_batch(target, pred)

        losses += loss_val.item()

        if test_tag:
            #save input,target,pred
            pred_save_dir = './pred/'
            sitk.WriteImage(sitk.GetImageFromArray(inputs.cpu().numpy()),
                            pred_save_dir + 'input_{}.nii.gz'.format(i))
            sitk.WriteImage(sitk.GetImageFromArray(target),
                            pred_save_dir + 'target_{}.nii.gz'.format(i))
            sitk.WriteImage(
                sitk.GetImageFromArray(pred),
                pred_save_dir + 'pred_{}_{}.nii.gz'.format(i, epoch))

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    if test_tag:
        print('Test:')
    else:
        print('Validation:')
    print('[Epoch: %d, numImages: %5d]' % (epoch, num_img))
    print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
        Acc, Acc_class, mIoU, FWIoU))
    print('Loss: %.3f' % losses)
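
# Hedged sketch of how validation() above might be driven from a training loop;
# the criterion choice and the train_one_epoch helper are assumptions, not taken
# from the original repository.
def run_training(model, args, nclass, epochs):
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    for epoch in range(epochs):
        train_one_epoch(epoch, model, args, criterion)  # hypothetical training step
        validation(epoch, model, args, criterion, nclass, test_tag=False)
    # final pass on the test split, which also writes the .nii.gz predictions
    validation(epochs - 1, model, args, criterion, nclass, test_tag=True)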
Example #3
def forward_all(net_inference, dataloader, visualize=False, opt=None):
    evaluator = Evaluator(21)
    evaluator.reset()
    with torch.no_grad():
        for ii, sample in enumerate(dataloader):
            image, label = sample['image'].cuda(), sample['label'].cuda()

            activations = net_inference(image)

            image = image.cpu().numpy()
            label = label.cpu().numpy().astype(np.uint8)

            if isinstance(activations, torch.Tensor):
                logits = activations
            else:
                # networks returning an OrderedDict of outputs: take the last one
                logits = activations[list(activations.keys())[-1]]
            pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8)

            evaluator.add_batch(label, pred)

            # print(label.shape, pred.shape)
            if visualize:
                for jj in range(sample["image"].size()[0]):
                    segmap_label = decode_segmap(label[jj], dataset='pascal')
                    segmap_pred = decode_segmap(pred[jj], dataset='pascal')

                    img_tmp = np.transpose(image[jj], axes=[1, 2, 0])
                    img_tmp *= (0.229, 0.224, 0.225)
                    img_tmp += (0.485, 0.456, 0.406)
                    img_tmp *= 255.0
                    img_tmp = img_tmp.astype(np.uint8)

                    cv2.imshow('image', img_tmp[:, :, [2, 1, 0]])
                    cv2.imshow('gt', segmap_label)
                    cv2.imshow('pred', segmap_pred)
                    cv2.waitKey(0)

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    print("Acc: {}".format(Acc))
    print("Acc_class: {}".format(Acc_class))
    print("mIoU: {}".format(mIoU))
    print("FWIoU: {}".format(FWIoU))
    if opt is not None:
        with open("seg_result.txt", 'a+') as ww:
            ww.write(
                "{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n"
                .format(opt.dataset, opt.quantize, opt.relu, opt.equalize,
                        opt.absorption, opt.correction, opt.clip_weight,
                        opt.distill_range))
            ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(
                Acc, Acc_class, mIoU, FWIoU))
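
# Hedged usage sketch for forward_all() above: evaluating a segmentation network
# on a PASCAL-style loader. DeepLab and make_data_loader are assumed to be the
# same helpers used in the other snippets; the argument values are illustrative.
def evaluate_pascal(args):
    net = DeepLab(num_classes=21, backbone='resnet', output_stride=16,
                  sync_bn=False, freeze_bn=True).cuda().eval()
    _, val_loader, _, _ = make_data_loader(args, num_workers=args.workers, pin_memory=True)
    forward_all(net, val_loader, visualize=False, opt=None)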
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        if args.dataset == 'click':
            extract_hard_example(args, batch_size=32, recal=False)
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        sbox = DeepLabX(pretrain=False)
        sbox.load_state_dict(
            torch.load('run/sbox_513_8925.pth.tar',
                       map_location=torch.device('cuda:0'))['state_dict'])
        click = ClickNet()
        model = FusionNet(sbox=sbox, click=click, pos_limit=2, neg_limit=2)
        model.sbox_net.eval()
        for para in model.sbox_net.parameters():
            para.requires_grad = False

        train_params = [
            {
                'params': model.click_net.parameters(),
                'lr': args.lr
            },
            # {'params': model.sbox_net.get_1x_lr_params(), 'lr': args.lr*0.001}
            # {'params': model.sbox_net.get_train_click_params(), 'lr': args.lr*0.001}
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            # self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.model.sbox_net.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, gt = sample['crop_image'], sample['crop_gt']
            if self.args.cuda:
                image, gt = image.cuda(), gt.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            sbox_pred, click_pred, sum_pred = self.model(image, crop_gt=gt)
            sum_pred = F.interpolate(sum_pred,
                                     size=gt.size()[-2:],
                                     align_corners=True,
                                     mode='bilinear')
            sbox_pred = F.interpolate(sbox_pred,
                                      size=gt.size()[-2:],
                                      align_corners=True,
                                      mode='bilinear')
            loss1 = self.criterion(sum_pred, gt)
            # optionally: loss1 = loss1 + self.criterion(sbox_pred, gt)

            loss1.backward()
            self.optimizer.step()
            total_loss = loss1.item()
            train_loss += total_loss
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_steps', total_loss,
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                grid_image = make_grid(decode_seg_map_sequence(
                    torch.max(sbox_pred[:3], 1)[1].detach().cpu().numpy(),
                    dataset=self.args.dataset),
                                       3,
                                       normalize=False,
                                       range=(0, 255))
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, sample['crop_gt'],
                                             sum_pred, global_step)
                self.writer.add_image('sbox_pred', grid_image, global_step)

        self.writer.add_scalar('train/total_epochs', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        total_clicks = 0
        for i, sample in enumerate(tbar):
            image, gt = sample['crop_image'], sample['crop_gt']
            if self.args.cuda:
                image, gt = image.cuda(), gt.cuda()
            with torch.no_grad():
                sbox_pred, click_pred, sum_pred = self.model(image, crop_gt=gt)
                # sum_pred, clicks = self.model.click_eval(image, gt)
            # total_clicks += clicks
            sum_pred = F.interpolate(sum_pred,
                                     size=gt.size()[-2:],
                                     align_corners=True,
                                     mode='bilinear')
            loss1 = self.criterion(sum_pred, gt)
            total_loss = loss1.item()
            test_loss += total_loss
            pred = sum_pred.data.cpu().numpy()
            target = gt.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_epochs', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        # print('total clicks:' , total_clicks)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                is_best,
                prefix='click')
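
# Hedged driver sketch for the click/sbox Trainer above; args is assumed to be
# an argparse namespace carrying the fields referenced in __init__ plus sensible
# defaults for start_epoch, epochs and no_val.
def run_click_training(args):
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()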
Example #5
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.train_dir = './data_list/train_lite.csv'
        self.train_list = pd.read_csv(self.train_dir)
        self.val_dir = './data_list/val_lite.csv'
        self.val_list = pd.read_csv(self.val_dir)
        self.train_length = len(self.train_list)
        self.val_length = len(self.val_list)
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Approach 2: build the train/val/test generators
        self.train_gen, self.val_gen, self.test_gen, self.nclass = make_data_loader2(args)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)

        # Define Criterion
        # self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion1 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='ce')
        self.criterion2 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='dice')

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, self.train_length)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the model is not wrapped in DataParallel here, so load directly
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        prev_time = time.time()
        self.model.train()
        self.evaluator.reset()

        num_img_tr = self.train_length // self.args.batch_size

        for iteration in range(num_img_tr):
            samples = next(self.train_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, iteration, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)


            # print a log line every log_iters iterations (default: 4)
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: train loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            # Show 10 * 3 inference results each epoch
            if iteration % (num_img_tr // 10) == 0:
                global_step = iteration + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print("input image shape/iter:", image.shape)

        # train evaluate
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc_tr:{}, Acc_class_tr:{}, IoU_tr:{}, mIoU_tr:{}, fwIoU_tr: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)





    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        val_loss = 0.0
        prev_time = time.time()
        num_img_val = self.val_length // self.args.batch_size
        print("Validation: epoch", epoch)
        print("num validation batches:", num_img_val)
        for iteration in range(num_img_val):
            samples = next(self.val_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():  #
                output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            val_loss += loss.item()
            self.writer.add_scalar('val/total_loss_iter', loss.item(), iteration + num_img_val * epoch)

            # print a log line every log_iters iterations (default: 4)
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: validation loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()


            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print(image.shape)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', val_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print("Acc_val:{}, Acc_class_val:{}, IoU:val:{}, mIoU_val:{}, fwIoU_val: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))
        print('Loss: %.3f' % val_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
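
# A plausible sketch of the make_one_hot() helper used by the dice loss above
# (the original repository's version may differ): converts an integer label map
# of shape (N, H, W) or (N, 1, H, W) into a one-hot tensor of shape (N, C, H, W).
def make_one_hot(labels, num_classes):
    if labels.dim() == 3:
        labels = labels.unsqueeze(1)  # (N, 1, H, W)
    one_hot = torch.zeros(labels.size(0), num_classes,
                          labels.size(2), labels.size(3),
                          device=labels.device)
    return one_hot.scatter_(1, labels.long(), 1.0)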
Example #6
class trainNew(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        #                        output_stride=args.out_stride,
        #                        sync_bn=args.sync_bn,
        #                        freeze_bn=args.freeze_bn)
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: look into these
        # TODO: ALSO look into different param groups as done int deeplab below
        #        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        #
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  #TODO: use min_lr ?

        # TODO: Figure out if len(self.train_loader) should be devided by two ? in other module as well
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights were saved from a DataParallel model, strip the 'module.' prefix
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            encoder_output, low_level_feature = self.model(image)
            output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                encoder_output, low_level_feature = self.model(image)
                output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
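
# Hedged sketch (assumed, not from the original search code) of how the two
# architecture files read by trainNew.__init__ could be produced at the end of
# an architecture-search run.
def export_searched_arch(saved_arch_path, cell_arch, network_arch):
    os.makedirs(saved_arch_path, exist_ok=True)
    np.save(os.path.join(saved_arch_path, 'genotype.npy'), cell_arch)
    np.save(os.path.join(saved_arch_path, 'network_path_space.npy'), network_arch)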
Example #7
def main():
    args = parse_args()

    if args.dataset == 'CamVid':
        num_class = 32
    elif args.dataset == 'Cityscapes':
        num_class = 19
    else:
        raise RuntimeError("dataset {} not found.".format(args.dataset))

    if args.net == 'resnet101':
        blocks = [2, 4, 23, 3]
        model = FPN(blocks, num_class, back_bone=args.net)
    else:
        raise NotImplementedError("backbone {} is not supported".format(args.net))

    if args.checkname is None:
        args.checkname = 'fpn-' + str(args.net)

    evaluator = Evaluator(num_class)

    # Trained model path and name
    experiment_dir = args.experiment_dir
    load_name = os.path.join(experiment_dir, 'checkpoint.pth.tar')

    # Load trained model
    if not os.path.isfile(load_name):
        raise RuntimeError("=> no checkpoint found at '{}'".format(load_name))
    print('====>loading trained model from ' + load_name)
    checkpoint = torch.load(load_name)
    checkepoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])

    # Load image and save in test_imgs
    test_imgs = []
    test_label = []
    if args.dataset == "CamVid":
        root_dir = Path.db_root_dir('CamVid')
        test_file = os.path.join(root_dir, "val.csv")
        test_data = CamVidDataset(csv_file=test_file, phase='val')
        val_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    elif args.dataset == "Cityscapes":
        kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
        #_, test_loader, _, _ = make_data_loader(args, **kwargs)
        _, val_loader, test_loader, _ = make_data_loader(args, **kwargs)
    else:
        raise RuntimeError("dataset {} not found.".format(args.dataset))

    # test
    results = []
    for i, batch in enumerate(val_loader):
        if args.dataset == 'CamVid':
            image, target = batch['X'], batch['l']
        elif args.dataset == 'Cityscapes':
            image, target = batch['image'], batch['label']
        else:
            raise NotImplementedError

        if args.cuda:
            image, target, model = image.cuda(), target.cuda(), model.cuda()
        with torch.no_grad():
            output = model(image)
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        target = target.cpu().numpy()
        evaluator.add_batch(target, pred)

        # show result
        pred_rgb = decode_seg_map_sequence(pred, args.dataset, args.plot)
        results.append(pred_rgb)

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

    print('Mean evaluate result on dataset {}'.format(args.dataset))
    print('Acc:{:.3f}\tAcc_class:{:.3f}\nmIoU:{:.3f}\tFWIoU:{:.3f}'.format(Acc, Acc_class, mIoU, FWIoU))
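
# Minimal entry-point sketch for the evaluation script above (assumed; the
# original file may already end with an equivalent guard).
if __name__ == '__main__':
    main()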
Example #8
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset)[0], args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if os.path.isfile(args.resume):
                checkpoint = torch.load(args.resume,map_location=torch.device('cpu'))
                args.start_epoch = checkpoint['epoch']
                if args.cuda:
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])
                if not args.ft:
                    self.optimizer.load_state_dict(checkpoint['optimizer'])
                # self.best_pred = checkpoint['best_pred']-0.3
                print("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target, weight = sample['image'], sample['label'], sample['weight']
            if self.args.cuda:
                image, target, weight = image.cuda(), target.cuda(), weight.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = 0
            for index in range(output.shape[0]):
                temp1 = output[index].unsqueeze(0)
                temp2 = target[index].unsqueeze(0)
                loss = loss + weight[index,0,0]*self.criterion(temp1,temp2)
            loss.backward() 
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        # self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)



    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']  # , sample['weight']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)


            # for channels in range(target.shape[0]):
            #     imagex = image[channels].cpu().numpy()
            #     imagex = np.transpose(imagex,(1,2,0))
            #     pre = pred[channels]
            #     targ = target[channels]

            #     plt.subplot(131)
            #     plt.imshow(imagex)

            #     plt.subplot(132)
            #     image1 = imagex.copy()
            #     for i in [0,1] :
            #         g = image1[:,:,i]
            #         g[pre>0.5] = 255
            #         image1[:,:,i] = g
            #     for i in [2]:
            #         g = image1[:,:,i]
            #         g[pre>0.5] = 0
            #         image1[:,:,i] = g
            #     plt.imshow(image1)

            #     plt.subplot(133)
            #     image2 = imagex.copy()
            #     for i in [0,1] :
            #         g = image2[:,:,i]
            #         g[targ>0.5] = 255
            #         image2[:,:,i] = g
            #     for i in [2]:
            #         g = image2[:,:,i]
            #         g[targ>0.5] = 0
            #         image2[:,:,i] = g
            #     plt.imshow(image2)

            #     plt.show()


            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        xy_mIoU = self.evaluator.xy_Mean_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print("min_mIoU{}".format(xy_mIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = xy_mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
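
# A vectorized sketch of the per-sample weighted loss computed inside
# Trainer.training() above, written with F.cross_entropy instead of the repo's
# SegmentationLosses. It illustrates what the explicit Python loop does; with
# ignore_index the plain spatial mean differs slightly, so it is not a drop-in
# replacement.
import torch.nn.functional as F

def weighted_batch_loss(output, target, weight):
    # output: (N, C, H, W) logits, target: (N, H, W) labels, weight: (N, 1, 1)
    per_sample = F.cross_entropy(output, target.long(), ignore_index=255,
                                 reduction='none').mean(dim=(1, 2))  # (N,)
    return (weight[:, 0, 0] * per_sample).sum()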
class PrunningFineTuner_DEEPLAB:
    def __init__(self, backbone, num_classes, train_data_loader, test_data_loader, use_cuda , model):

        self.train_data_loader = train_data_loader
        self.test_data_loader = test_data_loader

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
        self.prunner = FilterPrunner(self.model,use_cuda)
        self.model.train()
        self.evaluator = Evaluator(num_classes)
        self.use_cuda = use_cuda



    def test(self):
        # return
        print('Testing model:')
        self.model.eval()
        self.evaluator.reset()
        mean_infer_time = []


        for i, sample in enumerate(self.test_data_loader):
            if i % 50 == 0:
                print('Processing batch {}/{}'.format(i,int(len(self.test_data_loader))))
            t0 = time.time()
            batch, target = sample['image'], sample['label']
            if self.use_cuda:
                batch = batch.cuda()
            input = Variable(batch)
            output = self.model(input)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            t_total = (time.time() - t0)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            mean_infer_time.append(t_total)

        m_time = np.mean(mean_infer_time)
        print("Mean inference after pruning took {} ms per epoch.".format(m_time*1000))

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:{0:.3f}, Acc_class:{1:.3f}, mIoU:{2:.3f}, fwIoU: {3:.3f}".format(Acc*100, Acc_class*100, mIoU*100, FWIoU*100))

        self.model.train()

        return m_time, Acc, Acc_class, mIoU, FWIoU

    def train(self, optimizer=None, epoches=10):
        if optimizer is None:
            optimizer = optim.SGD(self.model.classifier.parameters(), lr=0.0001, momentum=0.9)


        for i in range(epoches):
            print("Epoch: ", i)
            self.train_epoch(optimizer)
            # self.test()
        print("Finished fine tuning.")

        return self.test()

    def train_batch(self, optimizer, batch, label, rank_filters):

        if self.use_cuda:
            batch = batch.cuda()
            label = label.cuda()

        self.model.zero_grad()
        input = Variable(batch)

        if rank_filters:
            output = self.prunner.forward(input)
            loss = self.criterion(output, label.long())
            loss.backward()
        else:
            output = self.model(input)
            loss = self.criterion(output, label.long())
            loss.backward()
            optimizer.step()

    def train_epoch(self, optimizer=None, rank_filters=False):

        if optimizer is None:
            optimizer = optim.SGD(self.model.parameters(), lr=0.0001, momentum=0.9)

        for i, sample in enumerate(self.train_data_loader):
            if i % 100 == 0:
                print('Processing batch {}/{}'.format(i,int(len(self.train_data_loader))))
            # if i>20:
            #     break;
            batch, label = sample['image'], sample['label']
            self.train_batch(optimizer, batch, label, rank_filters)

    def get_candidates_to_prune(self, num_filters_to_prune):
        self.prunner.reset()
        self.train_epoch(rank_filters=True)
        self.prunner.normalize_ranks_per_layer()
        return self.prunner.get_prunning_plan(num_filters_to_prune)

    def total_num_filters(self):

        filters = 0

        def count_conv_layers(network,filters=0):

            for layer in network.children():
                # print(type(layer))
                if isinstance(layer, torch.nn.modules.conv.Conv2d):
                    filters = filters + layer.out_channels
                elif 'Block' in str(type(layer)):  # if sequential layer, apply recursively to layers in sequential layer
                    filters += count_conv_layers(layer)
                elif 'SeparableConv2d' in str(type(layer)):  # if sequential layer, apply recursively to layers in sequential layer
                    filters += count_conv_layers(layer)
                elif type(layer) == nn.Sequential:
                    filters += count_conv_layers(layer)

            return filters

        # for name, module in self.model.backbone._modules.items():
        #     if isinstance(module, torch.nn.modules.conv.Conv2d):
        #         filters = filters + module.out_channels
        #     elif 'block' in name:
        #         for block_name, block_module in module._modules.items():
        #             if isinstance(block_module, torch.nn.modules.conv.Conv2d):
        #                 filters = filters + block_module.out_channels
        return count_conv_layers(self.model.backbone,filters)

    def prune(self):
        # Get the accuracy before prunning
        self.test()
        self.model.train()

        epoch_times = []
        # Make sure all the layers are trainable
        for param in self.model.backbone.parameters():
            param.requires_grad = True

        number_of_filters = self.total_num_filters()
        num_filters_to_prune_per_iteration = 256
        iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration)

        iterations = int(iterations * 2.0 / 4)

        print("Number of prunning iterations to reduce 50% filters", iterations)

        for n in range(iterations):
            print("Ranking filters.. ")
            prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration)
            layers_prunned = {}
            for layer_index, filter_index in prune_targets:
                if layer_index not in layers_prunned:
                    layers_prunned[layer_index] = 0
                layers_prunned[layer_index] = layers_prunned[layer_index] + 1

            print("Layers that will be prunned", layers_prunned)
            print("Prunning filters.. ")
            model = self.model.cpu()

            skip = []
            for i, (layer_index, filter_index) in enumerate(prune_targets):
                print('[{}] - Pruning layer {} and filter_index {}'.format(i,layer_index,filter_index))
                if i in skip or filter_index < 0:
                    print('skipped pruning layer ', i)
                    continue

                model, update_pruned_layers = prune_xception_layer(model, layer_index, filter_index, self.prunner.flat_backbone, self.prunner.model_dict)

                # fix filters' indices
                for l, (l_index, f_index) in enumerate(prune_targets):
                    if l_index in update_pruned_layers and f_index >= filter_index:
                        if f_index == filter_index:
                            skip.append(l_index)
                        prune_targets[l] = (l_index, f_index-1)


            self.model = model
            if self.use_cuda:
                self.model = self.model.cuda()

            message = str(100 * float(self.total_num_filters()) / number_of_filters) + "%"
            print("Filters prunned", str(message))
            print("Fine tuning to recover from prunning iteration.")
            optimizer = optim.SGD(self.model.parameters(), lr=0.0001, momentum=0.9)
            m_time, Acc, Acc_class, mIoU, FWIoU = self.train(optimizer, epoches=10)
            epoch_times.append(m_time)
            # self.test()

            torch.save(self.model, "/home/ido/Deep/Pytorch/pytorch-deeplab-xception/pruned_models/iter:{0}_time:{1:.2f}_Acc:{2:.3f}_Acc_class:{3:.3f}_mIoU:{4:.3f}_fwIoU:{5:.3f}.pth".format(n+221,m_time*1000,Acc*100, Acc_class*100, mIoU*100, FWIoU*100))

        print(epoch_times)
        print("Finished. Going to fine tune the model a bit more")
        self.train(optimizer, epoches=15)
        torch.save(model.state_dict(), "model_prunned")
Example #10
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        if args.loss_type == 'depth_loss_two_distributions':
            self.nclass = args.num_class + args.num_class2 + 1
        if args.loss_type == 'depth_avg_sigmoid_class':
            self.nclass = args.num_class + args.num_class2
        # Define network

        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)


        print("\nDefine models...\n")
        self.model_aprox_depth = DeepLab(num_classes=1,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)

        self.input_conv = nn.Conv2d(4, 3, 3, padding=1)
        # Using cuda
        if args.cuda:
            self.model_aprox_depth = torch.nn.DataParallel(self.model_aprox_depth, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model_aprox_depth)
            self.model_aprox_depth = self.model_aprox_depth.cuda()
            self.input_conv = self.input_conv.cuda()


        print("\nLoad checkpoints...\n")
        if not args.cuda:
            ckpt_aprox_depth = torch.load(args.ckpt, map_location='cpu')
            self.model_aprox_depth.load_state_dict(ckpt_aprox_depth['state_dict'])
        else:
            ckpt_aprox_depth = torch.load(args.ckpt)
            self.model_aprox_depth.module.load_state_dict(ckpt_aprox_depth['state_dict'])


        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        # set optimizer
        optimizer = torch.optim.Adam(train_params, args.lr)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        if 'depth' in args.loss_type:
            self.criterion = DepthLosses(weight=weight,
                                         cuda=args.cuda,
                                         min_depth=args.min_depth,
                                         max_depth=args.max_depth,
                                         num_class=args.num_class,
                                         cut_point=args.cut_point,
                                         num_class2=args.num_class2).build_loss(mode=args.loss_type)
            self.infer = DepthLosses(weight=weight,
                                     cuda=args.cuda,
                                     min_depth=args.min_depth,
                                     max_depth=args.max_depth,
                                     num_class=args.num_class,
                                     cut_point=args.cut_point,
                                     num_class2=args.num_class2)
            self.evaluator_depth = EvaluatorDepth(args.batch_size)
        else:
            self.criterion = SegmentationLosses(cuda=args.cuda, weight=weight).build_loss(mode=args.loss_type)
            self.evaluator = Evaluator(self.nclass)

        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        if 'depth' in args.loss_type:
            self.best_pred = 1e6
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            if not args.cuda:
                checkpoint = torch.load(args.resume, map_location='cpu')
            else:
                checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                state_dict = checkpoint['state_dict']
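                # drop the last two entries (final classifier weight and bias) so a checkpoint
                # trained with a different number of classes still loads with strict=False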
                state_dict.popitem(last=True)
                state_dict.popitem(last=True)
                self.model.module.load_state_dict(state_dict, strict=False)
            else:
                state_dict = checkpoint['state_dict']
                state_dict.popitem(last=True)
                state_dict.popitem(last=True)
                self.model.load_state_dict(state_dict, strict=False)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            if 'depth' in args.loss_type:
                self.best_pred = 1e6
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # add input layer to the model
        self.model = nn.Sequential(
            self.input_conv,
            self.model
        )
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.model_aprox_depth.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.dataset == 'apollo_seg' or self.args.dataset == 'farsight_seg':
                target[target <= self.args.cut_point] = 0
                target[target > self.args.cut_point] = 1
            if image.shape[0] == 1:
                target = torch.cat([target, target], dim=0)
                image = torch.cat([image, image], dim=0)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
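            # predict an approximate depth map, squash it to [0, 1] with a sigmoid and
            # append it to the RGB image as a fourth input channel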
            aprox_depth = self.model_aprox_depth(image)
            aprox_depth = self.infer.sigmoid(aprox_depth)
            input = torch.cat([image, aprox_depth], dim=1)
            output = self.model(input)
            if self.args.loss_type == 'depth_sigmoid_loss_inverse':
                loss = self.criterion(output, target, inverse=True)
            else:
                loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                # change NaN values to zero for display (avoids a TensorBoard warning)
                target[torch.isnan(target)] = 0
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step,
                                             n_class=self.args.num_class)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.model_aprox_depth.eval()
        if 'depth' in self.args.loss_type:
            self.evaluator_depth.reset()
        else:
            softmax = nn.Softmax(1)
            self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                aprox_depth = self.model_aprox_depth(image)
                aprox_depth = self.infer.sigmoid(aprox_depth)
                input = torch.cat([image, aprox_depth], dim=1)
                output = self.model(input)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            if 'depth' in self.args.loss_type:
                if self.args.loss_type == 'depth_loss':
                    pred = self.infer.pred_to_continous_depth(output)
                elif self.args.loss_type == 'depth_avg_sigmoid_class':
                    pred = self.infer.pred_to_continous_depth_avg(output)
                elif self.args.loss_type == 'depth_loss_combination':
                    pred = self.infer.pred_to_continous_combination(output)
                elif self.args.loss_type == 'depth_loss_two_distributions':
                    pred = self.infer.pred_to_continous_depth_two_distributions(output)
                elif 'depth_sigmoid_loss' in self.args.loss_type:
                    output = self.infer.sigmoid(output.squeeze(1))
                    pred = self.infer.depth01_to_depth(output)
                # Add batch sample into evaluator
                self.evaluator_depth.evaluateError(pred, target)
            else:
                output = softmax(output)
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)
        if 'depth' in self.args.loss_type:
            # Fast test during the training
            MSE = self.evaluator_depth.averageError['MSE']
            RMSE = self.evaluator_depth.averageError['RMSE']
            ABS_REL = self.evaluator_depth.averageError['ABS_REL']
            LG10 = self.evaluator_depth.averageError['LG10']
            MAE = self.evaluator_depth.averageError['MAE']
            DELTA1 = self.evaluator_depth.averageError['DELTA1']
            DELTA2 = self.evaluator_depth.averageError['DELTA2']
            DELTA3 = self.evaluator_depth.averageError['DELTA3']

            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/MSE', MSE, epoch)
            self.writer.add_scalar('val/RMSE', RMSE, epoch)
            self.writer.add_scalar('val/ABS_REL', ABS_REL, epoch)
            self.writer.add_scalar('val/LG10', LG10, epoch)

            self.writer.add_scalar('val/MAE', MAE, epoch)
            self.writer.add_scalar('val/DELTA1', DELTA1, epoch)
            self.writer.add_scalar('val/DELTA2', DELTA2, epoch)
            self.writer.add_scalar('val/DELTA3', DELTA3, epoch)

            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print(
                "MSE:{}, RMSE:{}, ABS_REL:{}, LG10: {}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3: {}".format(MSE, RMSE,
                                                                                                         ABS_REL,
                                                                                                         LG10, MAE,
                                                                                                         DELTA1,
                                                                                                         DELTA2,
                                                                                                         DELTA3))
            new_pred = RMSE
            if new_pred < self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
        else:
            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
            self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
            print('Loss: %.3f' % test_loss)

            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        print('Loss: %.3f' % test_loss)
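
A minimal driver sketch for the depth-aware trainer above (not part of the original snippet): it assumes an argparse-style `args` namespace carrying the fields used in `__init__` has already been built, and that `start_epoch` is only present when the resume / fine-tune branches set it.

trainer = Trainer(args)  # args: hypothetical parsed namespace with the fields used above
start = getattr(trainer.args, 'start_epoch', 0)  # set by the resume / fine-tune branches
for epoch in range(start, trainer.args.epochs):
    trainer.training(epoch)
    if not trainer.args.no_val:
        trainer.validation(epoch)
trainer.writer.close()  # assuming create_summary() returned a TensorBoard SummaryWriter
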
Exemple #11
class Deeplab(object):
    def __init__(self, cfgfile):
        self.args = parse_cfg(cfgfile)

        self.nclass = int(self.args['nclass'])

        model = DeepLab(num_classes=self.nclass,
                        backbone=self.args['backbone'],
                        output_stride=int(self.args['out_stride']),
                        sync_bn=bool(self.args['sync_bn']),
                        freeze_bn=bool(self.args['freeze_bn']))

        weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=True).build_loss(mode=self.args['loss_type'])

        self.model = model
        self.evaluator = Evaluator(self.nclass)

        # Using cuda

        self.model = self.model.cuda()
        self.model = torch.nn.DataParallel(self.model, device_ids=[0])
        patch_replication_callback(self.model)
        self.resume = self.args['resume']

        # Resuming checkpoint
        if self.resume is not None:
            if not os.path.isfile(self.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    self.resume))
            checkpoint = torch.load(self.resume)

            self.model.module.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                self.resume, checkpoint['epoch']))

    def untransform(self, img, lbl=None):
        mean_bgr = np.array([0.485, 0.456, 0.406])
        std_bgr = np.array([0.229, 0.224, 0.225])

        img = img.numpy()
        img = img.transpose(1, 2, 0)
        img *= std_bgr
        img += mean_bgr
        img *= 255
        img = img.astype(np.uint8)

        lbl = lbl.numpy()
        return img, lbl

    def transform(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor()
        ])

        return composed_transforms(sample)

    def validation(self, src, data, ClassName, tar=None):

        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0

        w, h = src.size
        sample = {'image': src, 'label': tar}
        sample = self.transform(sample)
        image, target = sample['image'], sample['label']

        # for the dimension
        image = torch.unsqueeze(image, 0)
        target = torch.unsqueeze(target, 0)
        image, target = image.cuda(), target.cuda()

        with torch.no_grad():
            output = self.model(image)

        imgs = image.data.cpu()
        lbl_pred = output.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu()  #cpu()

        for img, lt, lp in zip(imgs, lbl_true, lbl_pred):

            img, lt = self.untransform(img, lt)

            viz = fcn.utils.visualize_segmentation(
                lbl_pred=lp,
                lbl_true=None,
                img=img,
                n_class=10,
                label_names=[
                    ' ', '11_pforceps', '12_mbforceps', '13_mcscissors',
                    '15_pcapplier', '18_pclip', '20_sxir', '19_mtclip',
                    '17_mtcapplier', '14_graspers'
                ])

            width = int(viz.shape[1] / 3 * 2)

            #Image.fromarray(viz[:,width:,:]).save(str(data)+'.jpg')
            def hide(plt):
                ax = plt.gca()
                ax.axes.xaxis.set_visible(False)
                ax.axes.yaxis.set_visible(False)

            plt.figure(figsize=(20, 10))
            plt.subplot(2, 2, 1)
            plt.imshow(src)
            plt.title('Image')
            hide(plt)

            plt.subplot(2, 2, 2)
            plt.imshow(tar)
            plt.title('SegmentationClass')
            hide(plt)

            plt.subplot(2, 2, 3)
            plt.imshow(
                Image.open('./dataset/output/SegmentationClassVisualization/' +
                           data + '.jpg'))
            plt.title('SegmentationVisualization')
            hide(plt)

            plt.subplot(2, 2, 4)
            plt.imshow(
                cv2.resize(np.float32(Image.fromarray(viz[:, width:, :])) /
                           255, (w, h),
                           interpolation=cv2.INTER_LINEAR))
            plt.title('Prediction')
            hide(plt)
            if not os.path.isdir('result'):
                os.makedirs('result')
            plt.savefig('result/%s.png' % (data))

        if tar is not None:
            loss = self.criterion(output, target)
            test_loss += loss.item()
            #tbar.set_description('Test loss: %.3f' % (test_loss))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target[:, 145:815, :], pred[:, 145:815, :])

            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            print('Validation:')
            #print('[Epoch: %d, numImages: %5d]' % (epoch, self.args.batch_size + image.data.shape[0]))
            print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU))
            print('Loss: %.3f' % test_loss)

            plot_confusion_matrix(self.evaluator.confusion_matrix, ClassName)
            plt.savefig('result/%s_CM.png' % (data))

        #return viz[:,width:, :]

    def validation_matrix(self, src, data, tar):

        self.model.eval()
        #self.evaluator.reset()
        sample = {'image': src, 'label': tar}
        sample = self.transform(sample)
        image, target = sample['image'], sample['label']
        # for the dimension
        image = torch.unsqueeze(image, 0)
        target = torch.unsqueeze(target, 0)
        image = image.cuda()

        with torch.no_grad():
            output = self.model(image)

        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        # Add batch sample into evaluator
        self.evaluator.add_batch(target[:, 145:815, :], pred[:, 145:815, :])
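
The `untransform` helper above simply inverts the `tr.Normalize` step (the `mean_bgr` / `std_bgr` arrays hold the usual ImageNet RGB statistics). A minimal round-trip sketch of that normalization, independent of the class above:

import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize(img):
    # img: float HxWx3 array scaled to [0, 1]
    return (img - mean) / std

def denormalize(img):
    # exact inverse of normalize, back to a displayable uint8 image
    return ((img * std + mean) * 255).astype(np.uint8)
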
Exemple #12
class Trainer(object):

    def __init__(self, args, dataloaders):
        self.args = args
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloaders

    def setup_saver_and_summary(self, num_current_labeled_samples, samples, experiment_group=None, regions=None):

        self.saver = ActiveSaver(self.args, num_current_labeled_samples, experiment_group=experiment_group)
        self.saver.save_experiment_config()
        self.saver.save_active_selections(samples, regions)
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.num_current_labeled_samples = num_current_labeled_samples

    def initialize(self):

        args = self.args
        model = DeepLabAccuracyPredictor(num_classes=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                                         sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, mc_dropout=False, enet=args.architecture == 'enet', symmetry=args.symmetry)

        train_params = model.get_param_list(args.lr, args.architecture == 'enet', args.symmetry)

        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        if args.use_balanced_weights:
            weight = calculate_weights_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion_deeplab = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion_unet = SegmentationLosses(weight=torch.FloatTensor(
            [args.weight_wrong_label_unet, 1 - args.weight_wrong_label_unet]), cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.deeplab_evaluator = Evaluator(self.nclass)
        self.unet_evaluator = Evaluator(2)

        if args.use_lr_scheduler:
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))
        else:
            self.scheduler = None

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

    def training(self, epoch, w_dl, w_un):

        train_loss = 0.0
        train_loss_unet = 0.0
        train_loss_deeplab = 0.0

        self.model.train()
        num_img_tr = len(self.train_loader)
        tbar = tqdm(self.train_loader, desc='\r')

        visualization_index = int(random.random() * len(self.train_loader))
        vis_img = None
        vis_tgt_dl = None
        vis_tgt_un = None
        vis_out_dl = None
        vis_out_un = None

        for i, sample in enumerate(tbar):
            image, deeplab_target = sample['image'], sample['label']

            if self.args.cuda:
                image, deeplab_target = image.cuda(), deeplab_target.cuda()
            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_pred)
                self.writer.add_scalar('train/learning_rate', self.scheduler.current_lr, i + num_img_tr * epoch)

            self.optimizer.zero_grad()
            deeplab_output, unet_output = self.model(image)
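            # the UNet head learns to predict where DeepLab is right: the target is 1 where the
            # argmax prediction matches the label, 0 otherwise, with the ignore index 255 kept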
            unet_target = deeplab_output.argmax(1).squeeze() == deeplab_target.long()
            unet_target[deeplab_target == 255] = 255

            if i == visualization_index:
                vis_img = image.cpu()
                vis_tgt_dl = deeplab_target.cpu()
                vis_out_dl = deeplab_output.cpu()
                vis_tgt_un = unet_target.cpu()
                vis_out_un = unet_output.cpu()

            loss_deeplab = self.criterion_deeplab(deeplab_output, deeplab_target)
            loss_unet = self.criterion_unet(unet_output, unet_target)
            loss = w_dl * loss_deeplab + w_un * loss_unet
            loss.backward()
            self.optimizer.step()
            train_loss_deeplab += loss_deeplab.item()
            train_loss_unet += loss_unet.item()
            train_loss += loss.item()
            tbar.set_description('Train losses: %.2f(dl) + %.2f(un) = %.3f' %
                                 (train_loss_deeplab / (i + 1), train_loss_unet / (i + 1), train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter_dl', loss_deeplab.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_loss_iter_un', loss_unet.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

        self.summary.create_single_visualization(self.writer, f'train/run_{self.num_current_labeled_samples:04d}', self.args.dataset, vis_img, vis_tgt_dl, vis_out_dl, vis_tgt_un, vis_out_un, epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.writer.add_scalar('train/total_loss_epoch_dl', train_loss_deeplab, epoch)
        self.writer.add_scalar('train/total_loss_epoch_un', train_loss_unet, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f (DeepLab) + %.3f (UNet) = %.3f' % (train_loss_deeplab, train_loss_unet, train_loss))
        print('BestPred: %.3f' % self.best_pred)

        self.writer.add_scalar('train/w_dl', w_dl, i + num_img_tr * epoch)
        self.writer.add_scalar('train/w_un', w_un, i + num_img_tr * epoch)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        return train_loss

    def validation(self, epoch, w_dl, w_un):

        self.model.eval()
        self.deeplab_evaluator.reset()
        self.unet_evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        test_loss_unet = 0.0
        test_loss_deeplab = 0.0

        visualization_index = int(random.random() * len(self.val_loader))
        vis_img = None
        vis_tgt_dl = None
        vis_tgt_un = None
        vis_out_dl = None
        vis_out_un = None

        for i, sample in enumerate(tbar):
            image, deeplab_target = sample['image'], sample['label']

            if self.args.cuda:
                image, deeplab_target = image.cuda(), deeplab_target.cuda()

            with torch.no_grad():
                deeplab_output, unet_output = self.model(image)

            unet_target = deeplab_output.argmax(1).squeeze() == deeplab_target.long()
            unet_target[deeplab_target == 255] = 255

            if i == visualization_index:
                vis_img = image.cpu()
                vis_tgt_dl = deeplab_target.cpu()
                vis_out_dl = deeplab_output.cpu()
                vis_tgt_un = unet_target.cpu()
                vis_out_un = unet_output.cpu()

            loss_deeplab = self.criterion_deeplab(deeplab_output, deeplab_target)
            loss_unet = self.criterion_unet(unet_output, unet_target)
            loss = w_dl * loss_deeplab + w_un * loss_unet

            test_loss += loss.item()
            test_loss_unet += loss_unet.item()
            test_loss_deeplab += loss_deeplab.item()

            tbar.set_description('Test losses: %.2f(dl) + %.2f(un) = %.3f' %
                                 (test_loss_deeplab / (i + 1), test_loss_unet / (i + 1), test_loss / (i + 1)))
            pred = deeplab_output.data.cpu().numpy()
            deeplab_target = deeplab_target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.deeplab_evaluator.add_batch(deeplab_target, pred)
            self.unet_evaluator.add_batch(unet_target.cpu().numpy(), np.argmax(unet_output.cpu().numpy(), axis=1))

        # Fast test during the training
        Acc = self.deeplab_evaluator.Pixel_Accuracy()
        Acc_class = self.deeplab_evaluator.Pixel_Accuracy_Class()
        mIoU = self.deeplab_evaluator.Mean_Intersection_over_Union()
        FWIoU = self.deeplab_evaluator.Frequency_Weighted_Intersection_over_Union()
        UNetAcc = self.unet_evaluator.Pixel_Accuracy()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('val/UNetAcc', UNetAcc, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, UNetAcc: {}".format(Acc, Acc_class, mIoU, FWIoU, UNetAcc))
        print('Loss: %.3f (DeepLab) + %.3f (UNet) = %.3f' % (test_loss_deeplab, test_loss_unet, test_loss))

        new_pred = mIoU
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

        # save every validation model (overwrites)
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

        return test_loss, mIoU, Acc, Acc_class, FWIoU, [vis_img, vis_tgt_dl, vis_out_dl, vis_tgt_un, vis_out_un]
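
The checkpoints above are written from `self.model.module.state_dict()`, i.e. with the `module.` prefix that `torch.nn.DataParallel` adds already stripped. A small self-contained sketch of saving and reloading in that convention, using a stand-in model and an arbitrary file name rather than the real DeepLab:

import torch
import torch.nn as nn

model = nn.Conv2d(3, 21, kernel_size=1)        # stand-in for the real network
parallel = nn.DataParallel(model)

# save the inner module's weights, as the trainers above do
torch.save({'state_dict': parallel.module.state_dict()}, 'checkpoint.pth.tar')

# reload on CPU: the keys carry no 'module.' prefix, so a plain model accepts them
ckpt = torch.load('checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
# a DataParallel wrapper loads the same dict through its .module attribute
nn.DataParallel(model).module.load_state_dict(ckpt['state_dict'])
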
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        # self.saver = Saver(args)
        # Recoder the running processing
        self.saver = Saver(args)
        sys.stdout = Logger(
            os.path.join(
                self.saver.experiment_dir,
                'log_train-%s.txt' % time.strftime("%Y-%m-%d-%H-%M-%S")))
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)
        if args.dataset == 'pairwise_lits':
            proxy_nclasses = self.nclass = 3
        elif args.dataset == 'pairwise_chaos':
            proxy_nclasses = 2 * self.nclass
        else:
            raise NotImplementedError

        # Define network
        model = ConsistentDeepLab(in_channels=3,
                                  num_classes=proxy_nclasses,
                                  pretrained=args.pretrained,
                                  backbone=args.backbone,
                                  output_stride=args.out_stride,
                                  sync_bn=args.sync_bn,
                                  freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(train_params,
                                     weight_decay=args.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            weights = calculate_weigths_labels(args.dataset, self.train_loader,
                                               proxy_nclasses)
        else:
            weights = None

        # Initializing loss
        print("Initializing loss: {}".format(args.loss_type))
        self.criterion = losses.init_loss(args.loss_type, weights=weights)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1, target1 = sample1['image'], sample1['label']
            image2, target2 = sample2['image'], sample2['label']
            if self.args.cuda:
                image1, target1 = image1.cuda(), target1.cuda()
                image2, target2 = image2.cuda(), target2.cuda()
                proxy_label = proxy_label.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image1, image2)
            loss = self.criterion(output, proxy_label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                image = torch.cat((image1, image2), dim=-2)
                if len(proxy_label.shape) > 3:
                    output = output[:, 0:self.nclass]
                    proxy_label = torch.argmax(proxy_label[:, 0:self.nclass],
                                               dim=1)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, proxy_label, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image1.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        val_time = 0
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1, target1 = sample1['image'], sample1['label']
            image2, target2 = sample2['image'], sample2['label']
            if self.args.cuda:
                image1, target1 = image1.cuda(), target1.cuda()
                image2, target2 = image2.cuda(), target2.cuda()
                proxy_label = proxy_label.cuda()

            with torch.no_grad():
                start = time.time()
                output = self.model(image1, image2, is_val=True)
                end = time.time()
            val_time += end - start
            loss = self.criterion(output, proxy_label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            proxy_label = proxy_label.cpu().numpy()

            # Add batch sample into evaluator
            if len(proxy_label.shape) > 3:
                pred = np.argmax(pred[:, 0:self.nclass], axis=1)
                proxy_label = np.argmax(proxy_label[:, 0:self.nclass], axis=1)
            else:
                pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(proxy_label, pred)

            if self.args.save_predict:
                self.saver.save_predict_mask(
                    pred, sample_indices, self.val_loader.dataset.data1_files)

        print("Val time: {}".format(val_time))
        print("Total paramerters: {}".format(
            sum(x.numel() for x in self.model.parameters())))
        if self.args.save_predict:
            namelist = []
            for fname in self.val_loader.dataset.data1_files:
                # namelist.append(fname.split('/')[-1].split('.')[0])
                _, name = os.path.split(fname)
                name = name.split('.')[0]
                namelist.append(name)
            file = gzip.open(
                os.path.join(self.saver.save_dir, 'namelist.pkl.gz'), 'wb')
            pickle.dump(namelist, file, protocol=-1)
            file.close()

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image1.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        dice = self.evaluator.Dice()
        # self.writer.add_scalar('val/Dice_1', dice[1], epoch)
        self.writer.add_scalar('val/Dice_2', dice[2], epoch)
        print("Dice:{}".format(dice))

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
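
`self.evaluator.Dice()` is not shown in this snippet; a common way to compute per-class Dice scores from an accumulated confusion matrix, sketched here under the assumption that the evaluator stores an nclass x nclass matrix with ground-truth rows and prediction columns:

import numpy as np

def dice_from_confusion_matrix(cm):
    # per-class Dice = 2*TP / (2*TP + FP + FN)
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c but labelled as something else
    fn = cm.sum(axis=1) - tp   # labelled as class c but predicted as something else
    return 2 * tp / np.maximum(2 * tp + fp + fn, 1e-12)
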
Exemple #14
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}

        parameters.set_saved_parafile_path(args.para)
        patch_w = parameters.get_digit_parameters("", "train_patch_width", None, 'int')
        patch_h = parameters.get_digit_parameters("", "train_patch_height", None, 'int')
        overlay_x = parameters.get_digit_parameters("", "train_pixel_overlay_x", None, 'int')
        overlay_y = parameters.get_digit_parameters("", "train_pixel_overlay_y", None, 'int')
        crop_height = parameters.get_digit_parameters("", "crop_height", None, 'int')
        crop_width = parameters.get_digit_parameters("", "crop_width", None, 'int')

        dataset = RemoteSensingImg(args.dataroot, args.list, patch_w, patch_h, overlay_x, overlay_y)

        # train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
        #                                            num_workers=args.workers, shuffle=True)
        train_length = int(len(dataset) * 0.9)
        validation_length = len(dataset) - train_length
        # print("total data len is %d, train_length is %d" % (len(dataset), train_length))
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, (train_length, validation_length))
        print("len of train dataset is %d and val dataset is %d and total data len is %d"
              % (len(self.train_dataset), len(self.val_dataset), len(dataset)))
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=args.batch_size,
                                                        num_workers=args.workers, shuffle=True, drop_last=True)
        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=args.batch_size,
                                                      num_workers=args.workers, shuffle=True, drop_last=True)
        print("len of train loader is %d and val loader is %d" % (len(self.train_loader), len(self.val_loader)))
        # self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=1,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)




        # whether to use class balanced weights
        # if args.use_balanced_weights:
        #     classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
        #     if os.path.isfile(classes_weights_path):
        #         weight = np.load(classes_weights_path)
        #     else:
        #         weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        #     weight = torch.from_numpy(weight.astype(np.float32))
        # else:
        #     weight = None





        # Define Criterion
        #self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)


        self.criterion=nn.BCELoss()

        if args.cuda:
            self.criterion=self.criterion.cuda()


        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(2)
        # Define lr scheduler
        print("lenght of train_loader is %d"%(len(self.train_loader)))
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            #self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}) with best mIoU {}"
                  .format(args.resume, checkpoint['epoch'], checkpoint['best_pred']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
            self.best_pred=0

    def training(self, epoch):
        train_start_time=time.time()
        train_loss = 0.0
        self.model.train()
        #tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        print("start training at epoch %d, with the training length of %d"%(epoch,num_img_tr))
        for i, (x, y) in enumerate(self.train_loader):
            start_time=time.time()
            image, target = x, y
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            end_time=time.time()
            #tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            print('[The loss for iteration %d is %.3f and the time used is %.3f]'%(i+num_img_tr*epoch,loss.item(),end_time-start_time))
            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 10) == 0:
            #     global_step = i + num_img_tr * epoch
            #     self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        train_end_time=time.time()
        print('[Epoch: %d, numImages: %5d, time used : %.3f hour]' % (epoch, i * self.args.batch_size + image.data.shape[0],(train_end_time-train_start_time)/3600))
        print('Loss: %.3f' % (train_loss / len(self.train_loader)))

        with open(self.args.checkname + ".train_out.txt", 'a') as log:
            out_message = '[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])
            log.writelines(out_message + '\n')
            out_message = 'Loss: %.3f' % (train_loss / len(self.train_loader))
            log.writelines(out_message + '\n')
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        time_val_start=time.time()
        self.model.eval()
        self.evaluator.reset()
        #tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (x, y) in enumerate(self.val_loader):
            image, target = x,y
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            #tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            #pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            print("validate on the %d patch of total %d patch"%(i,len(self.val_loader)))
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        time_val_end=time.time()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d, time used: %.3f hour]' % (epoch, len(self.val_loader), (time_val_end-time_val_start)/3600))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Validation Loss: %.3f' % (test_loss / len(self.val_loader)))

        with open(self.args.checkname + ".train_out.txt", 'a') as log:
            out_message = 'Validation:'
            log.writelines(out_message + '\n')
            out_message = "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)
            log.writelines(out_message + '\n')
            out_message = 'Validation Loss: %.3f' % (test_loss / len(self.val_loader))
            log.writelines(out_message + '\n')
        new_pred = mIoU

        if new_pred > self.best_pred:
            print("saveing model")
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, self.args.checkname)
            return False
        else:
            return True
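
This trainer works with a single-channel output and `nn.BCELoss`, so there is no argmax step before evaluation; a minimal illustration (an assumption, not part of the original code) of how such an output is typically turned into a binary mask for a two-class evaluator:

import torch

def to_binary_mask(output, threshold=0.5):
    # output: (N, 1, H, W) probabilities in [0, 1]; returns an (N, H, W) integer mask
    return (output.squeeze(1) > threshold).long()
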
class MyTrainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}


        if (args.dataset == "fashion_person"):

            train_set = fashion.FashionDataset(args, Path.db_root_dir("fashion_person"), mode='train',type = 'person')
            val_set = fashion.FashionDataset(args, Path.db_root_dir("fashion_person"), mode='test', type='person')
            self.nclass = train_set.nclass



            print("Train size {}, val size {}".format(len(train_set), len(val_set)))


            self.train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True,
                                   **kwargs)
            self.val_loader = DataLoader(dataset=val_set, batch_size=args.batch_size, shuffle=False,
                                   **kwargs)
            self.test_loader = None

            assert self.nclass == 2

        elif (args.dataset == "fashion_clothes"):
            train_set = fashion.FashionDataset(args, Path.db_root_dir("fashion_clothes"), mode='train', type='clothes')
            val_set = fashion.FashionDataset(args, Path.db_root_dir("fashion_clothes"), mode='test', type='clothes')
            self.nclass = train_set.nclass

            print("Train size {}, val size {}".format(len(train_set), len(val_set)))

            self.train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True,
                                           **kwargs)
            self.val_loader = DataLoader(dataset=val_set, batch_size=args.batch_size, shuffle=False,
                                         **kwargs)
            self.test_loader = None

            assert self.nclass == 7



        #self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        # model = DeepLab(num_classes=self.nclass,
        #                 backbone=args.backbone,
        #                 output_stride=args.out_stride,
        #                 sync_bn=args.sync_bn,
        #                 freeze_bn=args.freeze_bn)
        # Using original network to load pretrained and do fine tuning


        self.best_pred = 0.0

        if args.model == 'deeplabv3+':
            model = DeepLab(backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

            # Loading pretrained VOC model
            if args.resume is not None:
                if not os.path.isfile(args.resume):
                    raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
                if args.cuda:
                    checkpoint = torch.load(args.resume)
                else:
                    checkpoint = torch.load(args.resume,map_location='cpu')
                args.start_epoch = checkpoint['epoch']

                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))

            # Freeze the backbone
            if args.freeze_backbone:
                set_parameter_requires_grad(model.backbone, False)

            ###### NEW DECODER ######
            # Different types of fine-tuning
            if args.ft_type == 'decoder':
                set_parameter_requires_grad(model, False)
                model.decoder = build_decoder(self.nclass, 'resnet', nn.BatchNorm2d)
                train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                                {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

            elif args.ft_type == 'last_layer':
                set_parameter_requires_grad(model, False)
                model.decoder.last_conv[8] = nn.Conv2d(in_channels=256, out_channels=self.nclass, kernel_size=1)
                model.decoder.last_conv[8].reset_parameters()
                train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                                {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
            elif args.ft_type == 'all':
                # Reset the last layer to produce the number of output classes we want
                model.decoder.last_conv[8] = nn.Conv2d(in_channels=256, out_channels=self.nclass, kernel_size=1)
                model.decoder.last_conv[8].reset_parameters()

                train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                            {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]



            # Define Optimizer
            optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                        weight_decay=args.weight_decay, nesterov=args.nesterov)


        elif args.model == "unet":
            model = UNet(num_categories=self.nclass, num_filters=args.num_filters)

            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                        weight_decay=args.weight_decay)

        elif args.model == 'mydeeplab':

            model = My_DeepLab(num_classes=self.nclass, in_channels=3)
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                        weight_decay=args.weight_decay, nesterov=args.nesterov)


        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
            print("weight is {}".format(weight))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer




        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            # TODO, ADD PARALLEL SUPPORT (NEED SYNC BATCH)
            # self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)
            self.model = self.model.cuda()

        args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def visualize_validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            #current_index_val_set
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            #we have image, target, output on GPU
            #j, index of image in batch

            self.summary.visualize_pregt(self.writer, self.args.dataset, image, target, output, i)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Visualizing:')
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Final Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

    def output_validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            #current_index_val_set
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            #we have image, target, output on GPU
            #j, index of image in batch

            #image save
            self.summary.save_pred(self.args.dataset, output, i)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Saving predictions:')
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Final Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)


    def _load_model(self, path):
        if self.args.cuda:
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location='cpu')

        self.model.load_state_dict(checkpoint['state_dict'])


    def train_loop(self):
        try:
            for epoch in range(self.args.start_epoch, self.args.epochs):
                self.training(epoch)
                if not self.args.no_val and epoch % self.args.eval_interval == (self.args.eval_interval - 1):
                    self.validation(epoch)
        except KeyboardInterrupt:
            print('Early Stopping')
        finally:
            self.visualize_validation()
            self.writer.close()
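The validation and visualization passes above summarize results with pixel accuracy, class accuracy, mIoU and FWIoU from the project's Evaluator, whose implementation is not shown here. As a rough, hypothetical sketch (not the project's API), such confusion-matrix based metrics can be computed like this:

import numpy as np

def confusion_matrix(target, pred, nclass):
    # Count (gt, pred) pairs only for pixels whose label is a valid class id,
    # so ignore indices such as 255 are skipped automatically.
    mask = (target >= 0) & (target < nclass)
    return np.bincount(nclass * target[mask].astype(int) + pred[mask],
                       minlength=nclass ** 2).reshape(nclass, nclass)

def pixel_accuracy(hist):
    return np.diag(hist).sum() / hist.sum()

def mean_iou(hist):
    # Per-class IoU = TP / (TP + FP + FN); classes absent from both the ground
    # truth and the prediction yield NaN and are ignored by nanmean.
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    return np.nanmean(iou)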
class Trainer(object):
    def __init__(self, args, dataloaders, mc_dropout):
        self.args = args
        self.mc_dropout = mc_dropout
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloaders

    def setup_saver_and_summary(self,
                                num_current_labeled_samples,
                                samples,
                                experiment_group=None,
                                regions=None):

        self.saver = ActiveSaver(self.args,
                                 num_current_labeled_samples,
                                 experiment_group=experiment_group)
        self.saver.save_experiment_config()
        self.saver.save_active_selections(samples, regions)
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    def initialize(self):

        args = self.args

        if args.architecture == 'deeplab':
            print('Using Deeplab')
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]
        elif args.architecture == 'enet':
            print('Using ENet')
            model = ENet(num_classes=self.nclass,
                         encoder_relu=True,
                         decoder_relu=True)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        elif args.architecture == 'fastscnn':
            print('Using FastSCNN')
            model = FastSCNN(3, self.nclass)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(train_params,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            optimizer = torch.optim.Adam(train_params,
                                         weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        if args.use_balanced_weights:
            weight = calculate_weights_labels(args.dataset, self.train_loader,
                                              self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.evaluator = Evaluator(self.nclass)

        if args.use_lr_scheduler:
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                          args.epochs, len(self.train_loader))
        else:
            self.scheduler = None

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

    def training(self, epoch):

        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        tbar = tqdm(self.train_loader, desc='\r')

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_pred)
                self.writer.add_scalar('train/learning_rate',
                                       self.scheduler.current_lr,
                                       i + num_img_tr * epoch)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        print('BestPred: %.3f' % self.best_pred)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        return train_loss

    def validation(self, epoch):

        self.model.eval()
        self.evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        visualization_index = int(random.random() * len(self.val_loader))
        vis_img = None
        vis_tgt = None
        vis_out = None

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            if i == visualization_index:
                vis_img = image
                vis_tgt = target
                vis_out = output

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)

            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

        # save every validation model (overwrites)
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        return test_loss, mIoU, Acc, Acc_class, FWIoU, [
            vis_img, vis_tgt, vis_out
        ]
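This Trainer wraps the network in DataParallel but checkpoints self.model.module.state_dict(), so the saved weights carry no 'module.' prefix. A minimal loading sketch under that assumption (the helper name and checkpoint keys follow the save_checkpoint calls above; this is not the project's own loader):

import torch

def load_checkpoint(model, path, cuda=False):
    # Weights were saved from model.module, so they load directly into a
    # bare (non-DataParallel) model.
    map_location = None if cuda else torch.device('cpu')
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    return checkpoint.get('epoch', 0), checkpoint.get('best_pred', 0.0)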
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        if args.densecrfloss > 0:
            self.densecrflosslayer = DenseCRFLoss(weight=args.densecrfloss, sigma_rgb=args.sigma_rgb, sigma_xy=args.sigma_xy, scale_factor=args.rloss_scale)
            print(self.densecrflosslayer)
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_celoss = 0.0
        train_crfloss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        softmax = nn.Softmax(dim=1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            croppings = (target != 254).float()
            target[target == 254] = 255
            # Pixels labeled 255 are unlabeled pixels; padded regions are labeled 254.
            # See RandomScaleCrop in dataloaders/custom_transforms.py for details of the data preprocessing.
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            
            celoss = self.criterion(output, target)
            
            if self.args.densecrfloss == 0:
                loss = celoss
            else:
                max_output = (max(torch.abs(torch.max(output)), 
                                  torch.abs(torch.min(output))))
                mean_output = torch.mean(torch.abs(output)).item()
                # std_output = torch.std(output).item()
                probs = softmax(output) # /max_output*4
                denormalized_image = denormalizeimage(sample['image'], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                densecrfloss = self.densecrflosslayer(denormalized_image, probs, croppings)
                if self.args.cuda:
                    densecrfloss = densecrfloss.cuda()
                loss = celoss + densecrfloss
                train_crfloss += densecrfloss.item()

                logits_copy = output.detach().clone().requires_grad_(True)
                max_output_copy = (max(torch.abs(torch.max(logits_copy)), 
                                  torch.abs(torch.min(logits_copy))))
                probs_copy = softmax(logits_copy) # /max_output_copy*4
                denormalized_image_copy = denormalized_image.detach().clone()
                croppings_copy = croppings.detach().clone()
                densecrfloss_copy = self.densecrflosslayer(denormalized_image_copy, probs_copy, croppings_copy)

                @torch.no_grad()
                def add_grad_map(grad, plot_name):
                    if i % (num_img_tr // 10) == 0:
                        global_step = i + num_img_tr * epoch
                        batch_grads = torch.max(torch.abs(grad), dim=1)[0].detach().cpu().numpy()
                        color_imgs = []
                        for grad_img in batch_grads:
                            grad_img[0, 0] = 0
                            img = colorize(grad_img)[:, :, :3]
                            color_imgs.append(img)
                        color_imgs = torch.from_numpy(np.array(color_imgs).transpose([0, 3, 1, 2]))
                        grid_image = make_grid(color_imgs[:3], 3, normalize=False, range=(0, 255))
                        self.writer.add_image(plot_name, grid_image, global_step)

                output.register_hook(lambda grad: add_grad_map(grad, 'Grad Logits')) 
                probs.register_hook(lambda grad: add_grad_map(grad, 'Grad Probs')) 
                
                logits_copy.register_hook(lambda grad: add_grad_map(grad, 'Grad Logits Rloss')) 
                densecrfloss_copy.backward()

                if i % (num_img_tr // 10) == 0:
                    global_step = i + num_img_tr * epoch
                    img_entropy = torch.sum(-probs * torch.log(probs + 1e-9), dim=1).detach().cpu().numpy()
                    color_imgs = []
                    for e in img_entropy:
                        e[0, 0] = 0
                        img = colorize(e)[:, :, :3]
                        color_imgs.append(img)
                    color_imgs = torch.from_numpy(np.array(color_imgs).transpose([0, 3, 1, 2]))
                    grid_image = make_grid(color_imgs[:3], 3, normalize=False, range=(0, 255))
                    self.writer.add_image('Entropy', grid_image, global_step)

                    self.writer.add_histogram('train/total_loss_iter/logit_histogram', output, i + num_img_tr * epoch)
                    self.writer.add_histogram('train/total_loss_iter/probs_histogram', probs, i + num_img_tr * epoch)

                self.writer.add_scalar('train/total_loss_iter/rloss', densecrfloss.item(), i + num_img_tr * epoch)
                self.writer.add_scalar('train/total_loss_iter/max_output', max_output.item(), i + num_img_tr * epoch)
                self.writer.add_scalar('train/total_loss_iter/mean_output', mean_output, i + num_img_tr * epoch)


            loss.backward()
        
            self.optimizer.step()
            train_loss += loss.item()
            train_celoss += celoss.item()
            
            tbar.set_description('Train loss: %.3f = CE loss %.3f + CRF loss: %.3f' 
                             % (train_loss / (i + 1),train_celoss / (i + 1),train_crfloss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_loss_iter/ce', celoss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        #if self.args.no_val:
        if self.args.save_interval:
            # save checkpoint every interval epoch
            is_best = False
            if (epoch + 1) % self.args.save_interval == 0:
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best, filename='checkpoint_epoch_{}.pth.tar'.format(str(epoch+1)))


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            target[target == 254] = 255
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
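The rloss Trainer above uses Tensor.register_hook to capture gradients of intermediate tensors (logits, probabilities) during backward and turn them into TensorBoard images. A self-contained sketch of the mechanism with toy shapes (the names below are illustrative only):

import torch

logits = torch.randn(2, 21, 8, 8, requires_grad=True)
probs = torch.softmax(logits, dim=1)

def report(grad):
    # Called during backward with d(loss)/d(probs).
    print('max |grad wrt probs|:', grad.abs().max().item())

probs.register_hook(report)
probs.sum().backward()  # triggers the hook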
Exemple #18
0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary

        # Define Dataloader
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if DEBUG:
            print("get device: ", self.device)
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args)
        # Define network

        modelDeeplab = DeepLab3d(num_classes=self.nclass,
                                 backbone=args.backbone,
                                 output_stride=args.out_stride,
                                 sync_bn=args.sync_bn,
                                 freeze_bn=args.freeze_bn).cuda()
        Bilstm = BiLSTM(cube_D * cube_D * cube_D * 3,
                        cube_D * cube_D * cube_D * 3, 1).cuda()
        train_params = [{
            'params': modelDeeplab.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': modelDeeplab.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizerSGD = torch.optim.SGD(train_params,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay,
                                       nesterov=args.nesterov)
        optimizerADAM = torch.optim.Adam(Bilstm.parameters())
        # Define Criterion
        # whether to use class balanced weights

        #if args.use_balanced_weights:
        #    classes_weights_path = os.path.join(ROOT_PATH, args.dataset+'_classes_weights.npy')

        #    if os.path.isfile(classes_weights_path):
        #        weight = np.load(classes_weights_path)
        #    else:
        #        weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        #    weight = torch.from_numpy(weight.astype(np.float32)) ##########weight not cuda

        #else:
        #    weight = None

        self.deeplabCriterion = DiceCELoss().cuda()
        self.lstmCost = torch.nn.BCELoss().cuda()
        self.deeplab, self.Bilstm, self.optimizerSGD, self.optimizerADAM = modelDeeplab, Bilstm, optimizerSGD, optimizerADAM

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler

        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda

        #if args.cuda: -------- commenting this out should not cause any problems, right?
        #    self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        #    patch_replication_callback(self.model)
        #    self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # DataParallel wrapping is commented out above, so the model has no .module attribute.
            self.deeplab.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizerSGD.load_state_dict(checkpoint['optimizerSGD'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):

        train_loss = 0.0
        dice_loss_count = 0.0
        ce_loss_count = 0.0
        num_count = 0

        self.deeplab.eval()

        tbar = tqdm(self.train_loader)

        num_img_tr = len(self.train_loader)

        for i, sample in enumerate(tbar):
            image, target = sample
            target_sque = target.squeeze()  # expect a volume without batch/channel dims, e.g. 50*384*384
            img_sque = image.squeeze()
            if DEBUG:
                print("image, target ,sque size feed in model,", image.size(),
                      target.size(), target_sque.size())
            image, target = image.cuda(), target.cuda()

            self.scheduler(self.optimizerSGD, i, epoch, self.best_pred)
            self.optimizerSGD.zero_grad()

            output = self.deeplab(image)
            if DEBUG:
                print(output.size())
            n, c, d, w, h = output.shape
            output2 = torch.tensor((np.zeros(
                (n, c, d, w, h))).astype(np.float32))
            if output.is_cuda:
                output2 = output2.to(self.device)
            for mk1 in range(0, n):
                for mk2 in range(0, c):  # min-max normalize each (n, c) volume
                    output2[mk1, mk2, :, :, :] = (
                        output[mk1, mk2, :, :, :] -
                        torch.min(output[mk1, mk2, :, :, :])) / (
                            torch.max(output[mk1, mk2, :, :, :]) -
                            torch.min(output[mk1, mk2, :, :, :]))

            loss, dice_loss, ce_loss = self.deeplabCriterion(
                output, output2, target, self.device)

            loss.backward()

            self.optimizerSGD.step()
            #####---------------------------------lstm part---------------------
            aro = output2[0][0]
            aro = aro.detach().cpu().numpy()
            gro = output2[0][1]
            gro = gro.detach().cpu().numpy()  # requires the batch size to be 1
            orig_vol_dim, bbx_loc = get_bounding_box_loc(img=target_sque,
                                                         bbx_ext=10)
            aux_grid_list = load_nii2grid(grid_D,
                                          grid_ita,
                                          bbx_loc=bbx_loc,
                                          img=target_sque)  # load label
            aux_grid_list_c0 = load_nii2grid(grid_D,
                                             grid_ita,
                                             img=gro,
                                             bbx_loc=bbx_loc)  #ground
            aux_grid_list_c1 = load_nii2grid(grid_D,
                                             grid_ita,
                                             img=aro,
                                             bbx_loc=bbx_loc)  # aorta

            us_grid_list = load_nii2grid(grid_D,
                                         grid_ita,
                                         img=img_sque,
                                         bbx_loc=bbx_loc)  #rawimage
            label_grid_list = []
            for g in range(len(us_grid_list)):
                us_grid_vol = us_grid_list[g]  #rawimage

                aux_grid_vol = aux_grid_list[g]  #label
                aux_grid_vol_c0 = aux_grid_list_c0[g]  #ground
                aux_grid_vol_c1 = aux_grid_list_c1[g]  # aorta
                # serialization grid to sequence
                us_mat = partition_vol2grid2seq(
                    us_grid_vol, cube_D, cube_ita,
                    norm_fact=255.0)  # normalize the raw image and partition it

                aux_mat = partition_vol2grid2seq(aux_grid_vol,
                                                 cube_D,
                                                 cube_ita,
                                                 norm_fact=1.0)  # partition the label
                aux_mat_c0 = partition_vol2grid2seq(aux_grid_vol_c0,
                                                    cube_D,
                                                    cube_ita,
                                                    norm_fact=1.0)  # partition the ground channel
                aux_mat_c1 = partition_vol2grid2seq(aux_grid_vol_c1,
                                                    cube_D,
                                                    cube_ita,
                                                    norm_fact=1.0)  # partition the aorta channel
                feat_mat = np.concatenate((us_mat, aux_mat_c0, aux_mat_c1),
                                          axis=1)  # concatenate raw image, ground and aorta channels
                #print(feat_mat.shape)
                feat_mat = torch.from_numpy(feat_mat)  # convert to a torch tensor
                #feat_map=feat_mat.float()   # cast to float
                feat_mat = feat_mat.unsqueeze(0)  # add a batch dimension to match the LSTM input
                feat_mat = Variable(
                    feat_mat).float().cuda()  # cast to float; maybe double would also work?
                #feat_mat.unsqueeze(0)
                y_label_seq = self.Bilstm(feat_mat)  # feed it into the network
                #print(y_label_seq.shape)
                self.optimizerADAM.zero_grad()
                aux_mat = torch.from_numpy(aux_mat)  # convert the label to a tensor
                aux_mat = aux_mat.float().cuda()  # cast the label to float
                lstmloss = self.lstmCost(y_label_seq, aux_mat)  # compute the LSTM loss
                lstmloss.backward()
                self.optimizerADAM.step()
            #######------------------------------------------------------------------------------
            train_loss += loss.item() + lstmloss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            dice_loss_count = dice_loss_count + dice_loss.item()
            ce_loss_count = ce_loss_count + ce_loss.item()
            num_count = num_count + 1

            # Show 10 * 3 inference results each epoch

            if i % (num_img_tr // 5) == 0:

                global_step = i + num_img_tr * epoch

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))

        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' %
              (train_loss, dice_loss_count / num_count,
               ce_loss_count / num_count))  #maybe here is something wrong

        if self.args.no_val:

            # save checkpoint every epoch

            is_best = False

            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.deeplab.state_dict(),  # model is not wrapped in DataParallel above
                    'optimizerSGD': self.optimizerSGD.state_dict(),
                    'optimizerADAM': self.optimizerADAM.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):

        self.deeplab.eval()

        self.evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')

        test_loss = 0.0
        dice_loss = 0.0
        ce_loss = 0.0
        num_count = 0
        for i, sample in enumerate(tbar):
            image, target = sample
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.deeplab(image)
            n, c, d, w, h = output.shape
            output2 = torch.tensor((np.zeros(
                (n, c, d, w, h))).astype(np.float32))
            if output.is_cuda:
                output2 = output2.to(self.device)
            for mk1 in range(0, n):
                for mk2 in range(0, c):  # min-max normalize each (n, c) volume
                    output2[mk1, mk2, :, :, :] = (
                        output[mk1, mk2, :, :, :] -
                        torch.min(output[mk1, mk2, :, :, :])) / (
                            torch.max(output[mk1, mk2, :, :, :]) -
                            torch.min(output[mk1, mk2, :, :, :]))

            loss, dice, ce = self.deeplabCriterion(output, output2, target,
                                                   self.device)
            test_loss += loss.item()
            dice_loss += dice.item()
            ce_loss += ce.item()
            num_count += 1
            tbar.set_description(
                'Test loss: %.3f, dice loss: %.3f, ce loss: %.3f' %
                (test_loss /
                 (i + 1), dice_loss / num_count, ce_loss / num_count))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()

            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            #            if self.args.cuda:
            #                target, pred = torch.from_numpy(target).cuda(), torch.from_numpy(pred).cuda()
            self.evaluator.add_batch(np.squeeze(target), pred)

        # Fast test during the training

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))

        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))

        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (test_loss, dice_loss, ce_loss))

        new_pred = mIoU

        if new_pred > self.best_pred:

            is_best = True

            self.best_pred = new_pred

            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.deeplab.state_dict(),
                    'optimizerSGD': self.optimizerSGD.state_dict(),
                    'optimizerADAM': self.optimizerADAM.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
            print("ltt save ckpt!")
Exemple #19
0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        if self.args.model == "deeplabv3+":
            model = DeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        elif self.args.model == "deeplabv3":
            model = DeepLabv3(
                Norm=args.norm,
                backbone=args.backbone,
                output_stride=args.out_stride,
                num_classes=self.nclass,
                freeze_bn=args.freeze_bn,
            )
        elif self.args.model == "fpn":
            model = FPN(args=args, num_classes=self.nclass)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self):
        self.model.eval()
        self.args.examine = False
        tbar = tqdm(self.test_loader, desc="\r")
        if self.args.color:
            __image = True
        else:
            __image = False
        for i, sample in enumerate(tbar):
            images = sample["image"]
            names = sample["name"]
            if self.args.cuda:
                images = images.cuda()
            with torch.no_grad():
                output = self.model(images)
            preds = output.data.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            if __image:
                images = images.cpu().numpy()
            if not self.args.color:
                self.vs.predict_id(preds, names, self.args.save_dir)
            else:
                self.vs.predict_color(preds, images, names, self.args.save_dir)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        if self.args.color or self.args.examine:
            __image = True
        else:
            __image = False
        for i, sample in enumerate(tbar):
            images, targets = sample["image"], sample["label"]
            names = sample["name"]
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()
            with torch.no_grad():
                output = self.model(images)
            loss = self.criterion(output, targets)
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            preds = output.data.cpu().numpy()
            targets = targets.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(targets, preds)
            if __image:
                images = images.cpu().numpy()
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names, self.args.save_dir)
            if self.args.examine:
                self.vs.predict_examine(preds, targets, images, names,
                                        self.args.save_dir)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Validation:")
        # print(
        #     "[Epoch: %d, numImages: %5d]"
        #     % (epoch, i * self.args.batch_size + image.data.shape[0])
        # )
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: %.3f" % test_loss)
Exemple #20
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        #self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.printer = args.printer

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass, self.class_names = make_data_loader(
            args, **kwargs)

        # Define network
        self.model = self.get_net()
        if args.net in {
                'deeplabv3p', 'wdeeplabv3p', 'wsegnet', 'segnet', 'unet'
        }:
            train_params = [{
                'params': self.model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': self.model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]
        elif args.net in {'segnet', 'waveunet', 'unet', 'waveunet_v2'}:
            weight_p, bias_p = [], []
            for name, p in self.model.named_parameters():
                if 'bias' in name:
                    bias_p.append(p)
                else:
                    weight_p.append(p)
            train_params = [{
                'params': weight_p,
                'weight_decay': args.weight_decay,
                'lr': args.lr
            }, {
                'params': bias_p,
                'weight_decay': 0,
                'lr': args.lr
            }]
        else:
            train_params = None
            assert args.net in {
                'deeplabv3p', 'wdeeplabv3p', 'wsegnet', 'segnet', 'unet'
            }

        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    nesterov=args.nesterov)
        self.optimizer = optimizer
        # Define Optimizer

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        #self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion = SegmentationLosses(
            weight=weight,
            cuda=args.cuda,
            batch_average=self.args.batch_average).build_loss(
                mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.args.printer.pprint(
            'Using {} LR Scheduler!, initialization lr = {}'.format(
                args.lr_scheduler, args.lr))
        if self.args.net.startswith('deeplab'):
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                          args.epochs, len(self.train_loader))
        else:
            self.scheduler = LR_Scheduler(args.lr_scheduler,
                                          args.lr,
                                          args.epochs,
                                          len(self.train_loader),
                                          net=self.args.net)

        for key, value in self.args.__dict__.items():
            if not key.startswith('_'):
                self.printer.pprint('{} ==> {}'.format(key.rjust(24), value))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu,
                                               output_device=args.out_gpu)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if args.dataset in ['pascal', 'cityscapes']:
                #self.load_pretrained_model()
                #elif args.dataset == 'cityscapes':
                self.load_pretrained_model_cityscape()

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def get_net(self):
        model = None
        if self.args.net == 'deeplabv3p':
            model = DeepLabV3P(num_classes=self.nclass,
                               backbone=self.args.backbone,
                               output_stride=self.args.out_stride,
                               sync_bn=self.args.sync_bn,
                               freeze_bn=self.args.freeze_bn,
                               p_dropout=self.args.p_dropout)
        elif self.args.net == 'wdeeplabv3p':
            model = WDeepLabV3P(num_classes=self.nclass,
                                backbone=self.args.backbone,
                                output_stride=self.args.out_stride,
                                sync_bn=self.args.sync_bn,
                                freeze_bn=self.args.freeze_bn,
                                wavename=self.args.wn,
                                p_dropout=self.args.p_dropout)
        elif self.args.net == 'segnet':
            model = SegNet(num_classes=self.nclass, wavename=self.args.wn)
        elif self.args.net == 'unet':
            model = UNet(num_classes=self.nclass, wavename=self.args.wn)
        elif self.args.net == 'wsegnet':
            model = WSegNet(num_classes=self.nclass, wavename=self.args.wn)
        return model

    def load_pretrained_model(self):
        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume,
                                map_location=self.args.gpu_map)
        try:
            self.args.start_epoch = checkpoint['epoch']
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            pre_model = checkpoint['state_dict']
        except:
            self.printer.pprint('What happened ?!')
            self.args.start_epoch = 0
            if self.args.net in {
                    'deeplabv3p', 'wdeeplabv3p', 'wsegnet', 'segnet', 'unet'
            }:
                pre_model = checkpoint
            elif self.args.net in {'waveunet', 'waveunet_v2'}:
                pre_model = checkpoint['state_dict']
        model_dict = self.model.state_dict()
        self.printer.pprint("=> loaded checkpoint '{}' (epoch {})".format(
            self.args.resume, self.args.start_epoch))
        for key in model_dict:
            self.printer.pprint('AAAA - key in model --> {}'.format(key))
        for key in pre_model:
            self.printer.pprint('BBBB - key in pre_model --> {}'.format(key))
        if self.args.net in {'deeplabv3p', 'wdeeplabv3p'}:
            pre_layers = [('module.' + k, v) for k, v in pre_model.items()
                          if 'module.' + k in model_dict]
            for key in pre_layers:
                self.printer.pprint('CCCC - key in pre_model --> {}'.format(
                    key[0]))
            model_dict.update(pre_layers)
            self.model.load_state_dict(model_dict)
        elif self.args.net in {'segnet', 'unet'}:
            pre_layers = [('module.' + k, v) for k, v in pre_model.items()
                          if 'module.' + k in model_dict]
            for key in pre_layers:
                self.printer.pprint('CCCC - key in pre_model --> {}'.format(
                    key[0]))
            model_dict.update(pre_layers)
            self.model.load_state_dict(model_dict)
        elif self.args.net in {'wsegnet'}:
            pre_layers = [('module.features.' + k[16:], v)
                          for k, v in pre_model.items()
                          if 'module.features.' + k[16:] in model_dict]
            for key in pre_layers:
                self.printer.pprint('CCCC - key in pre_model --> {}'.format(
                    key[0]))
            model_dict.update(pre_layers)
            self.model.load_state_dict(model_dict)
        elif self.args.net == 'wdeeplabv3p':
            pre_layers = [
                ('module.backbone.' + k[7:], v) for k, v in pre_model.items()
                if 'module.backbone.' + k[7:] in model_dict and (
                    v.shape == model_dict['module.backbone.' + k[7:]].shape)
            ]
            for key in pre_layers:
                self.printer.pprint('CCCC - key in pre_model --> {}'.format(
                    key[0]))
            model_dict.update(pre_layers)
            self.model.load_state_dict(model_dict)

    def load_pretrained_model_cityscape(self):
        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume,
                                map_location=self.args.gpu_map)
        try:
            self.args.start_epoch = 0
            pre_model = checkpoint['state_dict']
        except:
            self.printer.pprint('What happened ?!')
            self.args.start_epoch = 0
            pre_model = checkpoint
        self.printer.pprint("=> loaded checkpoint '{}' (epoch {})".format(
            self.args.resume, self.args.start_epoch))
        if self.args.net in ('deeplabv3p', 'wdeeplabv3p_per'):
            model_dict = self.model.state_dict()
            for key in model_dict:
                self.printer.pprint('AAAA - key in model --> {}'.format(key))
            for key in pre_model:
                self.printer.pprint(
                    'BBBB - key in pre_model --> {}'.format(key))
            pre_layers = [
                ('module.backbone.' + k, v) for k, v in pre_model.items()
                if 'module.backbone.' +
                k in model_dict and model_dict['module.backbone.' +
                                               k].shape == pre_model[k].shape
            ]
            model_dict.update(pre_layers)
            self.model.load_state_dict(model_dict)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        time_epoch_begin = datetime.now()
        for i, sample in enumerate(self.train_loader):
            time_iter_begin = datetime.now()
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            time_iter_end = datetime.now()
            time_iter = time_iter_end - time_iter_begin
            time_iter_during = time_iter_end - self.args.time_begin
            if i % 10 == 0:
                self.printer.pprint('train: epoch = {:3d} / {:3d}, '
                                    'iter = {:4d} / {:5d}, '
                                    'loss = {:.3f} / {:.3f}, '
                                    'time = {} / {}, '
                                    'lr = {:.6f}'.format(
                                        epoch, self.args.epochs, i, num_img_tr,
                                        loss.item(), train_loss / (i + 1),
                                        time_iter, time_iter_during,
                                        self.optimizer.param_groups[0]['lr']))
        self.printer.pprint(
            '------------ Train_total_loss = {}, epoch = {}, Time = {}'.format(
                train_loss, epoch,
                datetime.now() - time_epoch_begin))
        self.printer.pprint(' ')
        if epoch % 10 == 0:
            filename = os.path.join(self.args.weight_root,
                                    'epoch_{}.pth.tar'.format(epoch))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, filename)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        num_img_val = len(self.val_loader)
        test_loss = 0.0
        time_epoch_begin = datetime.now()
        for i, sample in enumerate(self.val_loader):
            time_iter_begin = datetime.now()
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            _, pred = output.topk(1, dim=1)
            pred = pred.squeeze(dim=1)
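            # top-1 over the class dimension is equivalent to an argmax over the logits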
            pred = pred.cpu().numpy()
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            time_iter_end = datetime.now()
            time_iter = time_iter_end - time_iter_begin
            time_iter_during = time_iter_end - self.args.time_begin
            self.printer.pprint('validation: epoch = {:3d} / {:3d}, '
                                'iter = {:4d} / {:5d}, '
                                'loss = {:.3f} / {:.3f}, '
                                'time = {} / {}'.format(
                                    epoch, self.args.epochs, i, num_img_val,
                                    loss.item(), test_loss / (i + 1),
                                    time_iter, time_iter_during))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.printer.pprint('Validation, epoch = {}, Time = {}'.format(
            epoch,
            datetime.now() - time_epoch_begin))
        self.printer.pprint('------------ Total_loss = {}'.format(test_loss))
        self.printer.pprint(
            "------------ Acc: {:.4f}, mIoU: {:.4f}, fwIoU: {:.4f}".format(
                Acc, mIoU, FWIoU))
        self.printer.pprint('------------ Acc_class = {}'.format(Acc_class))
        Object_names = '\t'.join(self.class_names)
        Object_IoU = '\t'.join(
            ['{:0.3f}'.format(IoU * 100) for IoU in self.evaluator.IoU_class])
        self.printer.pprint('------------ ' + Object_names)
        self.printer.pprint('------------ ' + Object_IoU)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
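A minimal sketch of an epoch loop that could drive the Trainer above; only fields the class already reads from args (start_epoch, epochs) are assumed, the driver itself is not part of the original example.

def run_training(trainer, args):
    # Hypothetical driver: alternate a training pass and a validation pass each epoch.
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)      # logs loss/lr and saves a snapshot every 10 epochs
        trainer.validation(epoch)    # tracks the best mIoU and saves the best checkpoint
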
Exemple #21
0
class Trainer:

    def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver):
        self.args = args
        self.saver = saver
        self.saver.save_experiment_config()
        self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.train_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "train"))
        self.train_writer = self.train_summary.create_summary()
        self.val_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "validation"))
        self.val_writer = self.val_summary.create_summary()
        self.model = model
        self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        if args.use_balanced_weights:
            weight = torch.from_numpy(class_weights.astype(np.float32))
        else:
            weight = None

        if args.optimizer == 'SGD':
            print('Using SGD')
            self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            print('Using Adam')
            self.optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        self.lr_scheduler = None
        if args.use_lr_scheduler:
            if args.lr_scheduler == 'step':
                print('Using step lr scheduler')
                self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[int(x) for x in args.step_size.split(",")], gamma=0.1)

        self.criterion = SegmentationLosses(weight=weight, ignore_index=255, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.evaluator = Evaluator(train_set.num_classes)
        self.best_pred = 0.0

    def training(self, epoch):

        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_dataloader)
        tbar = tqdm(self.train_dataloader, desc='\r')

        visualization_index = int(random.random() * len(self.train_dataloader))
        vis_img, vis_tgt, vis_out = None, None, None

        self.train_writer.add_scalar('learning_rate', get_learning_rate(self.optimizer), epoch)

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.train_writer.add_scalar('total_loss_iter', loss.item(), i + num_img_tr * epoch)

            if i == visualization_index:
                vis_img, vis_tgt, vis_out = image, target, output

        self.train_writer.add_scalar('total_loss_epoch', train_loss / self.dataset_size['train'], epoch)
        if constants.VISUALIZATION:
            self.train_summary.visualize_state(self.train_writer, self.args.dataset, vis_img, vis_tgt, vis_out, epoch)

        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        print('BestPred: %.3f' % self.best_pred)

    def validation(self, epoch, test=False):
        self.model.eval()
        self.evaluator.reset()
        
        ret_list = []
        if test:
            tbar = tqdm(self.test_dataloader, desc='\r')
        else:
            tbar = tqdm(self.val_dataloader, desc='\r')
        test_loss = 0.0

        visualization_index = int(random.random() * len(self.val_dataloader))
        vis_img, vis_tgt, vis_out = None, None, None

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            if i == visualization_index:
                vis_img, vis_tgt, vis_out = image, target, output

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = torch.argmax(output, dim=1).data.cpu().numpy()
            target = target.cpu().numpy()
            self.evaluator.add_batch(target, pred)
            
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU_20 = self.evaluator.Mean_Intersection_over_Union_20()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        if not test:
            self.val_writer.add_scalar('total_loss_epoch', test_loss / self.dataset_size['val'], epoch)
            self.val_writer.add_scalar('mIoU', mIoU, epoch)
            self.val_writer.add_scalar('mIoU_20', mIoU_20, epoch)
            self.val_writer.add_scalar('Acc', Acc, epoch)
            self.val_writer.add_scalar('Acc_class', Acc_class, epoch)
            self.val_writer.add_scalar('fwIoU', FWIoU, epoch)
            if constants.VISUALIZATION:
                self.val_summary.visualize_state(self.val_writer, self.args.dataset, vis_img, vis_tgt, vis_out, epoch)

        print("Test: " if test else "Validation:")
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, mIoU_20:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, mIoU_20, FWIoU))
        print('Loss: %.3f' % test_loss)

        if not test:
            new_pred = mIoU
            if new_pred > self.best_pred:
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                })

        return test_loss, mIoU, mIoU_20, Acc, Acc_class, FWIoU#, ret_list

    def load_best_checkpoint(self):
        checkpoint = self.saver.load_checkpoint()
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print(f'=> loaded checkpoint (epoch {checkpoint["epoch"]})')
        return checkpoint["epoch"]
Exemple #22
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        if args.distributed:
            if args.local_rank == 0:
                self.saver = Saver(args)
                self.saver.save_experiment_config()
                # Define Tensorboard Summary
                self.summary = TensorboardSummary(self.saver.experiment_dir)
                self.writer = self.summary.create_summary()
        else:
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            # Define Tensorboard Summary
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # PATH = args.path
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                     cuda=args.cuda, extension=args.ext)


        # Define Optimizer
        # optimizer = torch.optim.SGD(model.parameters(),args.lr, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # model, optimizer = amp.initialize(model,optimizer,opt_level="O1")

        # Define Criterion
        weight = None
        # criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        # self.criterion = SegmentationCELosses(weight=weight, cuda=args.cuda)
        # self.criterion = SegmentationCELosses(weight=weight, cuda=args.cuda)
        # self.criterion = FocalLoss(gamma=0, alpha=[0.2, 0.98], img_size=512*512)
        self.criterion1 = FocalLoss(gamma=5, alpha=[0.2, 0.98], img_size=512 * 512)
        self.criterion2 = disc_loss(delta_v=0.5, delta_d=3.0, param_var=1.0, param_dist=1.0,
                                    param_reg=0.001, EMBEDDING_FEATS_DIMS=21,image_shape=[512,512])

        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.val_loader), local_rank=args.local_rank)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            if args.distributed:
                self.model = DistributedDataParallel(self.model)
            # self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)


        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            filename = 'checkpoint.pth.tar'
            args.resume = os.path.join(self.saver.experiment_dir, filename)
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # if args.cuda:
            #     self.model.module.load_state_dict(checkpoint['state_dict'])
            # else:
            self.model.load_state_dict(checkpoint['state_dict'])
            # if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        max_instances = 1
        for i, sample in enumerate(tbar):
            # image, target = sample['image'], sample['label']
            image, target, ins_target = sample['image'], sample['bin_label'], sample['label']
            # _target = target.cpu().numpy()
            # if np.max(_target) > max_instances:
            #     max_instances = np.max(_target)
            #     print(max_instances)

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            # if i % 10==0:
            #     misc.imsave('/mfc/user/1623600/.temp6/train_{:s}_epoch:{}_i:{}.png'.format(str(self.args.distributed),epoch,i),np.transpose(image[0].cpu().numpy(),(1,2,0)))
            #     os.chmod('/mfc/user/1623600/.temp6/train_{:s}_epoch:{}_i:{}.png'.format(str(self.args.distributed),epoch,i),0o777)


            # self.criterion = DataParallelCriterion(self.criterion)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, ins_target)

            reg_lambda = 0.01


            loss = loss1 + 10 * loss2
            # loss = loss1
            output = output[1]  # keep the segmentation logits (second head) for visualization
            # loss.back
            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            #     scaled_loss.backward()

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if self.args.distributed:
                if self.args.local_rank == 0:
                    self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            else:
                self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)


            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.distributed:
                    if self.args.local_rank == 0:
                        self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
                else:
                    self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        if self.args.distributed:
            if self.args.local_rank == 0:
                self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        else:
            self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        if self.args.local_rank == 0:
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print('Loss: %.3f' % train_loss)

        # if self.args.distributed:
        #     if self.args.local_rank == 0:
        #         if self.args.no_val:
        #             # save checkpoint every epoch
        #             is_best = False
        #             self.saver.save_checkpoint({
        #                 'epoch': epoch + 1,
        #                 'state_dict': self.model.module.state_dict(),
        #                 'optimizer': self.optimizer.state_dict(),
        #                 'best_pred': self.best_pred,
        #             }, is_best)
        #     else:
        #         if self.args.no_val:
        #             # save checkpoint every epoch
        #             is_best = False
        #             self.saver.save_checkpoint({
        #                 'epoch': epoch + 1,
        #                 'state_dict': self.model.module.state_dict(),
        #                 'optimizer': self.optimizer.state_dict(),
        #                 'best_pred': self.best_pred,
        #             }, is_best)



    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            # image, target = sample['image'], sample['label']
            image, target = sample['image'], sample['bin_label']
            # leftover debug check of the label value range
            a = target.numpy()
            aa_max = np.max(a)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion1(output, target)
            test_loss += loss.item()
            instance_seg = output[0].data.cpu().numpy()
            instance_seg = np.squeeze(instance_seg[0])
            instance_seg = np.transpose(instance_seg, (1, 2, 0))
            output = output[1]
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            if i % 30 == 0:
                misc.imsave('/mfc/user/1623600/.temp6/{:s}_val_epoch:{}_i:{}.png'
                            .format(str(self.args.distributed),epoch,i),
                            np.transpose(image[0].cpu().numpy(),(1,2,0))+3*np.asarray(np.stack((pred[0],pred[0],pred[0]),axis=-1),dtype=np.uint8))
                os.chmod('/mfc/user/1623600/.temp6/{:s}_val_epoch:{}_i:{}.png'.format(str(self.args.distributed),epoch,i),0o777)
                temp_instance_seg = np.zeros_like(np.transpose(image[0].cpu().numpy(),(1,2,0)))
                for j in range(21):
                    if j<7:
                        temp_instance_seg[:, :, 0] += instance_seg[:, :, j]
                    elif j<14:
                        temp_instance_seg[:, :, 1] += instance_seg[:, :, j]
                    else:
                        temp_instance_seg[:, :, 2] += instance_seg[:, :, j]

                for k in range(3):
                    temp_instance_seg[:, :, k] = self.minmax_scale(temp_instance_seg[:, :, k])

                instance_seg = np.array(temp_instance_seg, np.uint8)


                misc.imsave('/mfc/user/1623600/.temp6/emb_{:s}_val_epoch:{}_i:{}.png'
                            .format(str(self.args.distributed), epoch, i),instance_seg[...,:3])
                os.chmod(
                    '/mfc/user/1623600/.temp6/emb_{:s}_val_epoch:{}_i:{}.png'.format(str(self.args.distributed), epoch, i),
                    0o777)



        if self.args.distributed:
            if self.args.local_rank == 0:
                # Fast test during the training
                Acc = self.evaluator.Pixel_Accuracy()
                Acc_class = self.evaluator.Pixel_Accuracy_Class()
                mIoU = self.evaluator.Mean_Intersection_over_Union()
                F0 = self.evaluator.F0()
                self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
                self.writer.add_scalar('val/mIoU', mIoU, epoch)
                self.writer.add_scalar('val/Acc', Acc, epoch)
                self.writer.add_scalar('val/Acc_class', Acc_class, epoch)

                if self.args.local_rank == 0:
                    print('Validation:')
                    print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
                    print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
                    print('Loss: %.3f' % test_loss)

                new_pred = F0
                if new_pred > self.best_pred:
                    is_best = True
                    self.best_pred = new_pred
                    self.saver.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)
        else:
            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            F0 = self.evaluator.F0()
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)

            if self.args.local_rank == 0:
                print('Validation:')
                print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
                print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
                print('Loss: %.3f' % test_loss)

            new_pred = F0
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def minmax_scale(self, input_arr):
        """
        Linearly rescale an array to the [0, 255] range.

        :param input_arr: numpy array to rescale
        :return: rescaled float array
        """
        min_val = np.min(input_arr)
        max_val = np.max(input_arr)
        if max_val == min_val:
            # guard against division by zero for constant inputs
            return np.zeros_like(input_arr, dtype=np.float64)
        output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

        return output_arr
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            size = 512
            train_file = os.path.join(os.getcwd(), "data", "CamVid", "train.csv")
            val_file = os.path.join(os.getcwd(), "data", "CamVid", "val.csv")
            print('=>loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(train_data,
                                                     batch_size=args.batch_size,
                                                     shuffle=True,
                                                     num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(val_data,
                                                     batch_size=args.batch_size,
                                                     shuffle=True,
                                                     num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2,4,23,3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # torch.optim.Adam takes no momentum argument
            optimizer = torch.optim.Adam(fpn.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr, momentum=0, weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        elif args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset, args.checkname)
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')))
            run_id = int(runs[-1].split('_')[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir, 'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir,
                                 'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError("=> no checkpoint found at '{}'".format(load_name))
            checkpoint = torch.load(load_name)
            args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(load_name, checkpoint['epoch']))

        self.lr_stage = [68, 93]
        self.lr_stage_ind = 0


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        # tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        if self.lr_stage_ind > 1 and epoch % (self.lr_stage[self.lr_stage_ind]) == 0:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_stage_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            inputs = Variable(image)
            labels = Variable(target)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels.long())
            loss_val = loss.item()
            loss.backward(torch.ones_like(loss))
            # loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # tbar.set_description('\rTrain loss:%.3f' % (train_loss / (iteration + 1)))

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, len(self.train_loader), loss.data, self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)

            #if iteration % (num_img_tr // 10) == 0:
            #    global_step = iteration + num_img_tr * epoch
            #    self.summary.visualize_image(self.witer, self.args.dataset, image, target, outputs, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for iter, batch in enumerate(self.val_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f ' % (test_loss / (iter + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iter * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
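adjust_learning_rate is called in training above but not defined in this snippet; a common sketch, assuming every param group is scaled by the same decay factor:

def adjust_learning_rate(optimizer, decay_gamma):
    # Assumption: multiply every param group's lr by decay_gamma (e.g. 0.1).
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay_gamma
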
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        # kwargs = {'num_workers': args.workers, 'pin_memory': True}
        kwargs = {'num_workers': 0, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.nir:
            input_channels = 4
        else:
            input_channels = 3

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        in_channels=input_channels,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
                weight[1] = 4
                weight[2] = 2
                weight[0] = 1
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                # place_holder_target = target
                # place_holder_output = output
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def pred_single_image(self, path):
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        activations = collections.defaultdict(list)

        def save_activation(name, mod, input, output):
            activations[name].append(output.cpu())

        for name, m in self.model.named_modules():
            if type(m) == nn.ReLU:
                m.register_forward_hook(partial(save_activation, name))

        input = cv2.imread(path)
        label = cv2.imread(lbl_path)
        # bkg = cv2.createBackgroundSubtractorMOG2()
        # back = bkg.apply(input)
        # cv2.imshow('back', back)
        # cv2.waitKey()
        input = cv2.resize(input, (513, 513), interpolation=cv2.INTER_CUBIC)
        image = Image.open(img_path).convert('RGB')  # width x height x 3
        # _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = np.array(Image.open(img_path), dtype=np.uint8)
        _tmp[_tmp == 255] = 1
        _tmp[_tmp == 0] = 0
        _tmp[_tmp == 128] = 2
        _tmp = Image.fromarray(_tmp)

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=mean, std=std),
            tr.ToTensor()
        ])
        sample = {'image': image, 'label': _tmp}
        sample = composed_transforms(sample)

        image, target = sample['image'], sample['label']

        image = torch.unsqueeze(image, dim=0)
        if self.args.cuda:
            image, target = image.cuda(), target.cuda()
        with torch.no_grad():
            output = self.model(image)
            # output = output.data.cpu().numpy().squeeze(0).transpose([1, 2, 0])

            # output = np.argmax(output, axis=2) * 255
            output = output.data.cpu().numpy()
            prediction = np.argmax(output, axis=1)
            prediction = np.squeeze(prediction, axis=0)
            prediction[prediction == 1] = 255
            if np.any(prediction == 2):
                prediction[prediction == 2] = 128
            if np.any(prediction == 1):
                prediction[prediction == 1] = 255
            print(np.unique(prediction))

        see = Analysis(activations, label=1, path=self.saver.experiment_dir)
        see.backtrace(output)
        # for key in keys:
        #
        #     see.visualize_tensor(see.image)
        # see.save_tensor(see.image, self.saver.experiment_dir)

        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'rgb.png'), input)
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'lbl.png'), label)
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'prediction.png'),
                    prediction)
        # pred = output.data.cpu().numpy()
        # target = target.cpu().numpy()
        # pred = np.argmax(pred, axis=1)
        # pred = np.reshape(pred, (513, 513))
        # # prediction = np.append(target, pred, axis=1)
        # prediction = pred
        #
        # rgb = np.zeros((prediction.shape[0], prediction.shape[1], 3))
        #
        # r = prediction.copy()
        # g = prediction.copy()
        # b = prediction.copy()
        #
        # g[g != 1] = 0
        # g[g == 1] = 255
        #
        # r[r != 2] = 0
        # r[r == 2] = 255
        # b = np.zeros(b.shape)
        #
        # rgb[:, :, 0] = b
        # rgb[:, :, 1] = g
        # rgb[:, :, 2] = r
        #
        # prediction = np.append(input, rgb.astype(np.uint8), axis=1)
        # result = np.append(input, prediction.astype(np.uint8), axis=1)
        # cv2.line(rgb, (513, 0), (513, 1020), (255, 255, 255), thickness=1)
        # cv2.line(rgb, (513, 0), (513, 1020), (255, 255, 255), thickness=1)
        # cv2.imwrite('/home/robot/git/pytorch-deeplab-xception/run/cropweed/deeplab-resnet/experiment_41/samples/synthetic_{}.png'.format(counter), prediction)
        # plt.imshow(see.weed_filter)
        # # cv2.waitKey()
        # plt.show()

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def testing(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)
            # Add batch sample into evaluator
            prediction = np.append(target, pred, axis=2)
            print(pred.shape)
            input = image[0, 0:3, :, :].cpu().numpy().transpose([1, 2, 0])
            # cv2.imshow('figure', prediction)
            # cv2.waitKey()

        # Fast test during the testing
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print(
            '[INFO] Network performance measures on the test dataset are as follows: \n '
            'mIOU: {} \n FWIOU: {} \n Class accuracy: {} \n Pixel Accuracy: {}'
            .format(mIoU, FWIoU, Acc_class, Acc))

        self.evaluator.per_class_accuracy()

    def explain_image(self, path, counter):
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        image = Image.open(img_path).convert('RGB')  # width x height x 3
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp[_tmp == 255] = 1
        _tmp[_tmp == 0] = 0
        _tmp[_tmp == 128] = 2
        _tmp = Image.fromarray(_tmp)

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=mean, std=std),
            tr.ToTensor()
        ])
        sample = {'image': image, 'label': _tmp}
        sample = composed_transforms(sample)

        image, target = sample['image'], sample['label']

        image = torch.unsqueeze(image, dim=0)
        # if self.args.cuda:
        #     image, target = image.cuda(), target.cuda()
        # with torch.no_grad():
        #     output = self.model(image)
        # inn_model = InnvestigateModel(self.model, lrp_exponent=1,
        #                               method="b-rule",
        #                               beta=0, epsilon=1e-6)
        #
        # inn_model.eval()
        # model_prediction, heatmap = inn_model.innvestigate(in_tensor=image)
        # model_prediction = np.argmax(model_prediction, axis=1)

        # def run_guided_backprop(net, image_tensor):
        #     return interpretation.guided_backprop(net, image_tensor, cuda=True, verbose=False, apply_softmax=False)
        #
        # def run_LRP(net, image_tensor):
        #     return inn_model.innvestigate(in_tensor=image_tensor, rel_for_class=1)
        print('hold')
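The commented-out block in pred_single_image above colorizes the two foreground classes (class 1 in green, class 2 in red); a compact standalone sketch of that mapping, kept outside the class so the method itself stays unchanged:

import numpy as np

def colorize_prediction(pred):
    # pred: HxW array of class ids {0, 1, 2}; returns an HxWx3 uint8 image in BGR
    # order (class 1 -> green channel, class 2 -> red channel), mirroring the
    # commented-out r/g/b logic in pred_single_image.
    rgb = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
    rgb[:, :, 1][pred == 1] = 255
    rgb[:, :, 2][pred == 2] = 255
    return rgb
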
Exemple #25
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = True if (APEX_AVAILABLE and args.use_amp) else False
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                raise NotImplementedError
                #if so, which trainloader to use?
                # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # TODO: Figure out if len(self.train_loader) should be devided by two ? in other module as well
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2'
                                           or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module,
                                  torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                self.model.load_state_dict(checkpoint['state_dict'])
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(),
                                    checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(),
                                    checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(
                    ), target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                if self.use_amp:
                    with amp.scale_loss(
                            arch_loss,
                            self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            #self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

            #torch.cuda.empty_cache()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
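# A minimal driver sketch (not part of the original snippets) showing how a
# Trainer like the one above is typically run. It assumes the argparse
# Namespace `args` used throughout this class, with `start_epoch`, `epochs`
# and `no_val` set; `eval_interval` is an assumed, optional attribute
# controlling how often validation runs.
def run_training(args):
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val and (epoch + 1) % getattr(args, 'eval_interval', 1) == 0:
            trainer.validation(epoch)
    trainer.writer.close()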
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(f_params+y_params, lr= args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(f_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params+y_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params+y_params, lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets', args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_task_loss = 0.0
        train_d_loss = 0.0
        train_d_inv_loss = 0.0
        self.backbone_model.train()
        self.assp_model.train()
        self.y_model.train()
        self.d_model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            if self.args.dataset == 'gtav':
                src_image,src_label = sample['image'], sample['label']
            else:
                src_image, src_label, tgt_image = sample['src_image'], sample['src_label'], sample['tgt_image']
            if self.args.cuda:
                if self.args.dataset != 'gtav':
                    src_image, src_label, tgt_image  = src_image.cuda(), src_label.cuda(), tgt_image.cuda()
                else:
                    src_image, src_label = src_image.cuda(), src_label.cuda()
            self.scheduler(self.task_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_inv_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.c_optimizer, i, epoch, self.best_pred)
            self.task_optimizer.zero_grad()
            self.d_optimizer.zero_grad()
            self.d_inv_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            # source image feature
            src_high_feature_0, src_low_feature = self.backbone_model(src_image)
            src_high_feature = self.assp_model(src_high_feature_0)
            src_output = F.interpolate(self.y_model(src_high_feature, src_low_feature), src_image.size()[2:], \
                                       mode='bilinear', align_corners=True)

            src_d_pred = self.d_model(src_high_feature)
            task_loss = self.task_loss(src_output, src_label)

            if self.args.dataset != 'gtav':
                # target image feature
                tgt_high_feature_0, tgt_low_feature = self.backbone_model(tgt_image)
                tgt_high_feature = self.assp_model(tgt_high_feature_0)
                tgt_output = F.interpolate(self.y_model(tgt_high_feature, tgt_low_feature), tgt_image.size()[2:], \
                                           mode='bilinear', align_corners=True)
                tgt_d_pred = self.d_model(tgt_high_feature)

                d_loss,d_acc = self.domain_loss(src_d_pred, tgt_d_pred)
                d_inv_loss, _ = self.domain_loss(tgt_d_pred, src_d_pred)
                loss = task_loss + d_loss + d_inv_loss
                loss.backward()
                self.task_optimizer.step()
                self.d_optimizer.step()
                self.d_inv_optimizer.step()
            else:
                d_acc = 0
                d_loss = torch.tensor(0.0)
                d_inv_loss = torch.tensor(0.0)
                loss = task_loss
                loss.backward()
                self.task_optimizer.step()

            train_task_loss += task_loss.item()
            train_d_loss += d_loss.item()
            train_d_inv_loss += d_inv_loss.item()
            train_loss += task_loss.item() + d_loss.item() + d_inv_loss.item()

            tbar.set_description('Train loss: %.3f t_loss: %.3f d_loss: %.3f , d_inv_loss: %.3f  d_acc: %.2f' \
                                 % (train_loss / (i + 1),train_task_loss / (i + 1),\
                                    train_d_loss / (i + 1), train_d_inv_loss / (i + 1), d_acc*100))

            self.writer.add_scalar('train/task_loss_iter', task_loss.item(), i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.dataset != 'gtav':
                    image = torch.cat([src_image,tgt_image],dim=0)
                    output = torch.cat([src_output,tgt_output],dim=0)
                else:
                    image = src_image
                    output = src_output
                self.summary.visualize_image(self.writer, self.args.dataset, image, src_label, output, global_step)


        self.writer.add_scalar('train/task_loss_epoch', train_task_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + src_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'backbone_model_state_dict': self.backbone_model.module.state_dict(),
                'assp_model_state_dict': self.assp_model.module.state_dict(),
                'y_model_state_dict': self.y_model.module.state_dict(),
                'd_model_state_dict': self.d_model.module.state_dict(),
                'task_optimizer': self.task_optimizer.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
                'c_optimizer': self.c_optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:], \
                                           mode='bilinear', align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU,IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'backbone_model_state_dict': self.backbone_model.module.state_dict(),
                'assp_model_state_dict': self.assp_model.module.state_dict(),
                'y_model_state_dict': self.y_model.module.state_dict(),
                'd_model_state_dict': self.d_model.module.state_dict(),
                'task_optimizer': self.task_optimizer.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
                'c_optimizer': self.c_optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
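# A self-contained toy sketch (not from the original source) of the optimizer
# pattern used by the Trainer above: the feature extractor and task head share
# one optimizer, the domain classifier gets its own, and the task and domain
# losses are summed before a single backward pass. The real class above also
# adds an inverse domain loss (source/target predictions swapped) stepped by a
# separate optimizer over the feature extractor; that part is omitted here.
import torch
import torch.nn as nn

feat = nn.Linear(8, 4)   # stands in for backbone_model + assp_model
head = nn.Linear(4, 3)   # stands in for y_model
disc = nn.Linear(4, 2)   # stands in for d_model

task_opt = torch.optim.SGD(list(feat.parameters()) + list(head.parameters()), lr=0.1)
d_opt = torch.optim.SGD(disc.parameters(), lr=0.1)

src_x, src_y = torch.randn(5, 8), torch.randint(0, 3, (5,))
tgt_x = torch.randn(5, 8)

task_opt.zero_grad()
d_opt.zero_grad()

src_f, tgt_f = feat(src_x), feat(tgt_x)
task_loss = nn.functional.cross_entropy(head(src_f), src_y)

# domain labels: 0 for source samples, 1 for target samples
d_logits = disc(torch.cat([src_f, tgt_f], dim=0))
d_labels = torch.cat([torch.zeros(5, dtype=torch.long), torch.ones(5, dtype=torch.long)])
d_loss = nn.functional.cross_entropy(d_logits, d_labels)

(task_loss + d_loss).backward()
task_opt.step()
d_opt.step()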
class Trainer(object):
    def __init__(self, config):

        self.config = config
        self.best_pred = 0.0

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()
        
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)
        
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=self.config['network']['sync_bn'],
                        freeze_bn=self.config['network']['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'],
                                    weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if self.config['training']['use_balanced_weights']:
            classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'],
                                            self.config['training']['epochs'], len(self.train_loader))


        # Using cuda
        if self.config['network']['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint

        if self.config['training']['weights_initialization']['use_pretrained_weights']:
            if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(self.config['training']['weights_initialization']['restore_from']))

            if self.config['network']['use_cuda']:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'])
            else:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'})

            self.config['training']['start_epoch'] = checkpoint['epoch']

            self.model.load_state_dict(checkpoint['state_dict'])

#            if not self.config['ft']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch']))


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'], image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        #save last checkpoint
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
#            'state_dict': self.model.module.state_dict(),
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best = False, filename='checkpoint_last.pth.tar')

        #if training on a subset reshuffle the data 
        if self.config['training']['train_on_subset']['enabled']:
            self.train_loader.dataset.shuffle_dataset()    


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
#                'state_dict': self.model.module.state_dict(),
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            },  is_best = True, filename='checkpoint_best.pth.tar')
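# A minimal sketch of the nested `config` dict this config-driven Trainer
# reads; the keys match the lookups in __init__/training/validation above,
# while the values shown are placeholders, not the original project's defaults.
config = {
    'training': {
        'lr': 0.007, 'momentum': 0.9, 'weight_decay': 5e-4, 'nesterov': False,
        'epochs': 50, 'start_epoch': 0, 'batch_size': 8,
        'loss_type': 'ce', 'lr_scheduler': 'poly',
        'use_balanced_weights': False,
        'tensorboard': {'log_dir': './runs'},
        'weights_initialization': {'use_pretrained_weights': False,
                                   'restore_from': ''},
        'train_on_subset': {'enabled': False},
    },
    'network': {'backbone': 'resnet', 'sync_bn': False, 'freeze_bn': False,
                'use_cuda': True},
    'image': {'out_stride': 16},
    'dataset': {'base_path': './data', 'dataset_name': 'pascal'},
}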
Example #28
0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary

        # Define Dataloader
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if DEBUG:
            print("get device: ",self.device)
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # Define network

        model = DeepLab3d(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # Define Criterion
        # whether to use class balanced weights

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(ROOT_PATH, args.dataset+'_classes_weights.npy')

            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))  # note: this weight is not passed to the DiceCELoss below
        else:
            weight = None

        self.criterion = DiceCELoss()
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        dice_loss_count = 0.0
        ce_loss_count = 0.0
        num_count = 0

        #self.model.train()
        self.model.eval()

        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        for i, sample in enumerate(tbar):

            image, target = sample
            if DEBUG:
                print("image, target size feed in model,", image.size(), target.size())
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output = self.model(image)
            if DEBUG:
                print(output.size())
            n, c, d, w, h = output.shape
            output2 = torch.zeros((n, c, d, w, h), dtype=torch.float32)
            if output.is_cuda:
                output2 = output2.to(self.device)
            for mk1 in range(n):
                for mk2 in range(c):  # min-max normalize each (n, c) volume to [0, 1]
                    output2[mk1, mk2, :, :, :] = (output[mk1, mk2, :, :, :] - torch.min(output[mk1, mk2, :, :, :])) / \
                                                 (torch.max(output[mk1, mk2, :, :, :]) - torch.min(output[mk1, mk2, :, :, :]))

            loss, dice_loss, ce_loss = self.criterion(output, output2, target, self.device)

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            dice_loss_count = dice_loss_count + dice_loss.item()
            ce_loss_count = ce_loss_count + ce_loss.item()
            num_count = num_count + 1

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 5) == 0:
                global_step = i + num_img_tr * epoch

        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (train_loss, dice_loss_count / num_count, ce_loss_count / num_count))  # maybe here is something wrong

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')

        test_loss = 0.0
        dice_loss = 0.0
        ce_loss = 0.0
        num_count = 0
        for i, sample in enumerate(tbar):
            image, target = sample
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)
            n, c, d, w, h = output.shape
            output2 = torch.zeros((n, c, d, w, h), dtype=torch.float32)
            if output.is_cuda:
                output2 = output2.to(self.device)
            for mk1 in range(n):
                for mk2 in range(c):  # min-max normalize each (n, c) volume to [0, 1]
                    output2[mk1, mk2, :, :, :] = (output[mk1, mk2, :, :, :] - torch.min(output[mk1, mk2, :, :, :])) / \
                                                 (torch.max(output[mk1, mk2, :, :, :]) - torch.min(output[mk1, mk2, :, :, :]))

            loss, dice, ce = self.criterion(output, output2, target, self.device)
            test_loss += loss.item()
            dice_loss += dice.item()
            ce_loss += ce.item()
            num_count += 1
            tbar.set_description('Test loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (test_loss / (i + 1), dice_loss / num_count, ce_loss / num_count))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
#            if self.args.cuda:
#                target, pred = torch.from_numpy(target).cuda(), torch.from_numpy(pred).cuda()
            if DEBUG:
                print("check gt_image shape, pred img shape ",target.shape, pred.shape)
            self.evaluator.add_batch(np.squeeze(target), np.squeeze(pred))



        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (test_loss, dice_loss, ce_loss))

        new_pred = mIoU

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
            print("ltt save ckpt!")
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)
        #2 resnet 16 False False

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            #image, target = sample['image'], sample['label']
            image, target = sample['trace'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']
            #image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
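# A sketch (not from the original source) of the argparse options these
# Trainer examples read from `args`; the names match the attributes used
# above, while the defaults shown are placeholders.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='DeepLab training')
    parser.add_argument('--backbone', default='resnet')
    parser.add_argument('--out-stride', type=int, default=16)
    parser.add_argument('--dataset', default='pascal')
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--sync-bn', action='store_true')
    parser.add_argument('--freeze-bn', action='store_true')
    parser.add_argument('--loss-type', default='ce')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--start-epoch', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--use-balanced-weights', action='store_true')
    parser.add_argument('--lr', type=float, default=0.007)
    parser.add_argument('--lr-scheduler', default='poly')
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--nesterov', action='store_true')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--gpu-ids', type=int, nargs='+', default=[0])
    parser.add_argument('--resume', default=None)
    parser.add_argument('--ft', action='store_true')
    parser.add_argument('--no-val', action='store_true')
    return parser.parse_args()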
Example #30
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Generate .npy file for dataloader
        self.img_process(args)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = getattr(modeling, args.model_name)(pretrained=args.pretrained)

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    # Crop the large image into unit_size tiles, moving by stride each step.
    # The resulting train/val splits are stored as NumPy arrays in save_dir for
    # reuse and to reduce memory usage. Change the paths to your own.
    def img_process(self, args):
        unit_size = args.base_size
        stride = unit_size  # int(unit_size/2)
        save_dir = os.path.join(
            '/data/dingyifeng/pytorch-jingwei-master/npy_process',
            str(unit_size))
        # npy_process
        if not os.path.exists(save_dir):

            Image.MAX_IMAGE_PIXELS = 100000000000
            # load train image 1
            img = Image.open(
                '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_1.png'
            )
            img = np.asarray(img)  #(50141, 47161, 4)
            anno_map = Image.open(
                '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_1_label.png'
            )
            anno_map = np.asarray(anno_map)  #(50141, 47161)

            length, width = img.shape[0], img.shape[1]
            x1, x2, y1, y2 = 0, unit_size, 0, unit_size
            Img1 = []  # list of cropped tiles
            Label1 = []  # list of corresponding label crops
            while (x1 < length):
                # check whether the crop exceeds the image height
                if x2 > length:
                    x2, x1 = length, length - unit_size

                while (y1 < width):
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    im = img[x1:x2, y1:y2, :]
                    if 255 in im[:, :, -1]:  # keep the tile only if its alpha channel has valid (255) pixels
                        Img1.append(im[:, :, 0:3])  # append the RGB tile
                        Label1.append(anno_map[x1:x2, y1:y2])  # append the matching label crop

                    if y2 == width: break

                    y1 += stride
                    y2 += stride

                if x2 == length: break

                y1, y2 = 0, unit_size
                x1 += stride
                x2 += stride
            Img1 = np.array(Img1)  #(4123, 448, 448, 3)
            Label1 = np.array(Label1)  #(4123, 448, 448)

            # load train image 2
            img = Image.open(
                '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_2.png'
            )
            img = np.asarray(img)  #(50141, 47161, 4)
            anno_map = Image.open(
                '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_2_label.png'
            )
            anno_map = np.asarray(anno_map)  #(50141, 47161)

            length, width = img.shape[0], img.shape[1]
            x1, x2, y1, y2 = 0, unit_size, 0, unit_size
            Img2 = []  # list of cropped tiles
            Label2 = []  # list of corresponding label crops
            while (x1 < length):
                # check whether the crop exceeds the image height
                if x2 > length:
                    x2, x1 = length, length - unit_size

                while (y1 < width):
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    im = img[x1:x2, y1:y2, :]
                    if 255 in im[:, :, -1]:  # keep the tile only if its alpha channel has valid (255) pixels
                        Img2.append(im[:, :, 0:3])  # append the RGB tile
                        Label2.append(anno_map[x1:x2, y1:y2])  # append the matching label crop

                    if y2 == width: break

                    y1 += stride
                    y2 += stride

                if x2 == length: break

                y1, y2 = 0, unit_size
                x1 += stride
                x2 += stride
            Img2 = np.array(Img2)  #(5072, 448, 448, 3)
            Label2 = np.array(Label2)  #(5072, 448, 448)

            Img = np.concatenate((Img1, Img2), axis=0)
            cat = np.concatenate((Label1, Label2), axis=0)

            # shuffle
            np.random.seed(1)
            assert (Img.shape[0] == cat.shape[0])
            shuffle_id = np.arange(Img.shape[0])
            np.random.shuffle(shuffle_id)
            Img = Img[shuffle_id]
            cat = cat[shuffle_id]

            os.mkdir(save_dir)
            print("=> generate {}".format(unit_size))
            # split train dataset
            images_train = Img  #[:int(Img.shape[0]*0.8)]
            categories_train = cat  #[:int(cat.shape[0]*0.8)]
            assert (len(images_train) == len(categories_train))
            np.save(os.path.join(save_dir, 'train_img.npy'), images_train)
            np.save(os.path.join(save_dir, 'train_label.npy'),
                    categories_train)
            # split val dataset
            images_val = Img[int(Img.shape[0] * 0.8):]
            categories_val = cat[int(cat.shape[0] * 0.8):]
            assert (len(images_val) == len(categories_val))
            np.save(os.path.join(save_dir, 'val_img.npy'), images_val)
            np.save(os.path.join(save_dir, 'val_label.npy'), categories_val)

            print("=> img_process finished!")
        else:
            print("{} file already exists!".format(unit_size))
        # free memory (the original looped over locals() deleting every local
        # variable, which has no effect inside a function; only gc.collect() is kept)
        import gc
        gc.collect()

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
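# A sketch (not part of the original class) that factors the sliding-window
# cropping duplicated for image_1 and image_2 in img_process above into a
# single helper; like the loops above, it keeps a tile only when its alpha
# channel contains a 255 pixel.
import numpy as np

def crop_tiles(img, anno_map, unit_size, stride):
    """img: (H, W, 4) RGBA array, anno_map: (H, W) label array."""
    tiles, labels = [], []
    length, width = img.shape[0], img.shape[1]
    x1, x2 = 0, unit_size
    while x1 < length:
        if x2 > length:                      # clamp the last row of tiles
            x2, x1 = length, length - unit_size
        y1, y2 = 0, unit_size
        while y1 < width:
            if y2 > width:                   # clamp the last column of tiles
                y2, y1 = width, width - unit_size
            im = img[x1:x2, y1:y2, :]
            if 255 in im[:, :, -1]:          # keep tiles that contain valid pixels
                tiles.append(im[:, :, 0:3])
                labels.append(anno_map[x1:x2, y1:y2])
            if y2 == width:
                break
            y1 += stride
            y2 += stride
        if x2 == length:
            break
        x1 += stride
        x2 += stride
    return np.array(tiles), np.array(labels)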