Example 1
def train(train_queue, valid_queue, model, architect, criterion, optimizer,
          lr):
    objs = utils.AvgrageMeter()  # running average of the loss values
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    for step, (input, target) in enumerate(
            train_queue):  # each step draws one batch; batch size is 64 (256 data pairs)
        model.train()
        n = input.size(0)

        input = input.to(device)
        target = target.to(device)

        # get a random minibatch from the search queue with replacement
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.to(device)
        target_search = target_search.to(device)

        architect.step(input,
                       target,
                       input_search,
                       target_search,
                       lr,
                       optimizer,
                       unrolled=args.unrolled)

        optimizer.zero_grad()
        logits = model(input)
        logits = logits.to(device)
        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        #prec = utils.Accuracy(logits, target)
        #prec1 = utils.MIoU(logits, target, dataset_classes)
        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
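
The Evaluator driven by all of these examples is not shown in any of the snippets. Most public versions (e.g. the pytorch-deeplab-xception codebase) accumulate a per-class confusion matrix and derive every metric from it. The sketch below is a minimal reconstruction along those lines, not the exact class used here: it covers only add_batch, reset, Pixel_Accuracy and Mean_Intersection_over_Union, and it expects integer numpy label maps (as in Examples 2 and 10; the variants above that pass raw tensors or logits must convert internally). Project-specific extras such as Fx_Score, Building_IoU or the tuple-returning Mean_Intersection_over_Union seen in Example 2 would be additional methods over the same matrix.

import numpy as np


class Evaluator:
    """Confusion-matrix based segmentation metrics (minimal sketch)."""

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class))

    def _generate_matrix(self, gt_image, pre_image):
        # Histogram of (ground-truth, prediction) pairs over valid pixels.
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = (self.num_class * gt_image[mask].astype(int)
                 + pre_image[mask].astype(int))
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Mean_Intersection_over_Union(self):
        # Per-class IoU = TP / (TP + FP + FN); classes absent from both
        # prediction and ground truth give NaN and are ignored by nanmean.
        tp = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1)
                 + self.confusion_matrix.sum(axis=0) - tp)
        return np.nanmean(tp / union)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))
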
Example 2
def val_seg(model,
            dataLoader,
            epoch,
            loss_fn,
            num_classes,
            logger,
            tensorLogger,
            device='cuda',
            args=None):
    model.eval()
    logger.info("Valid | [{:2d}/{}]".format(epoch + 1, args.max_epoch))
    losses = AverageMeter()
    batch_time = AverageMeter()

    evaluator = Evaluator(num_class=num_classes)
    evaluator.reset()
    with torch.no_grad():
        for i, (inputs, target) in enumerate(dataLoader):
            inputs = inputs.to(device=device)
            target = target.to(device=device)

            initTime = time.time()
            output = model(inputs)

            loss = loss_fn(output, target)

            output_np = output.detach().cpu().numpy()
            target_np = target.detach().cpu().numpy()

            # print(output_np.shape, target_np.shape)
            evaluator.add_batch(target_np, np.argmax(output_np, axis=1))
            losses.update(loss.item(), inputs.size(0))

            batch_time.update(time.time() - initTime)

            if i % 100 == 0:  # print after every 100 batches
                logger.info(
                    "Valid | {:2d} | [{:4d}/{}] Infer:{:.2f}sec |  Loss:{:.4f}  |  Miou:{:4f}  |"
                    .format(epoch + 1, i + 1, len(dataLoader), batch_time.avg,
                            losses.avg,
                            evaluator.Mean_Intersection_over_Union()[0]))
    Totalmiou = evaluator.Mean_Intersection_over_Union()[0]
    tensorLogger.add_scalar('val/loss', losses.avg, epoch + 1)
    tensorLogger.add_scalar('val/miou', Totalmiou, epoch + 1)

    return losses.avg, Totalmiou
def validation(epoch, model, args, criterion, nclass, test_tag=False):
    model.eval()

    losses = 0.0

    evaluator = Evaluator(nclass)
    evaluator.reset()
    if test_tag:
        num_img = args.data_dict['num_valid']
    else:
        num_img = args.data_dict['num_test']
    for i in range(num_img):
        if test_tag:
            inputs = torch.FloatTensor(args.data_dict['valid_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['valid_mask'][i]).cuda()
        else:
            inputs = torch.FloatTensor(args.data_dict['test_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['test_mask'][i]).cuda()

        with torch.no_grad():
            output = model(inputs)
        loss_val = criterion(output, target).item()
        print('epoch: {0}\t'
              'iter: {1}/{2}\t'
              'loss: {loss:.4f}'.format(epoch + 1,
                                        i + 1,
                                        num_img,
                                        loss=loss_val))
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        evaluator.add_batch(target, pred)

        losses += loss_val

        if test_tag:
            # save input, target and pred as NIfTI volumes
            pred_save_dir = './pred/'
            sitk.WriteImage(sitk.GetImageFromArray(inputs.cpu().numpy()),
                            pred_save_dir + 'input_{}.nii.gz'.format(i))
            sitk.WriteImage(sitk.GetImageFromArray(target),
                            pred_save_dir + 'target_{}.nii.gz'.format(i))
            sitk.WriteImage(
                sitk.GetImageFromArray(pred),
                pred_save_dir + 'pred_{}_{}.nii.gz'.format(i, epoch))

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    if test_tag:
        print('Test:')
    else:
        print('Validation:')
    print('[Epoch: %d, numImages: %5d]' % (epoch, num_img))
    print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
        Acc, Acc_class, mIoU, FWIoU))
    print('Loss: %.3f' % losses)
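
The AvgrageMeter / AverageMeter helpers that these loops use for loss and metric bookkeeping are not defined in the snippets either; they are simple running-average trackers. A minimal sketch, assuming the usual val / avg / sum / count interface that the logging code above relies on:

class AverageMeter:
    """Tracks the latest value and a running average weighted by batch size."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is typically the batch size, so avg is a per-sample average.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
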
Example 4
    def run(self, img_tensor, target, verbose=0, save_to=None):
        """Run metric on one image-saliency pair.
        Modified from: https://github.com/eclique/RISE/blob/master/evaluation.py

        Args:
            img_tensor (Tensor): normalized image tensor.
            target (Tensor): ground-truth segmentation labels.
            verbose (int): in [0, 1, 2].
                0 - return list of scores.
                1 - also plot final step.
                2 - also plot every step and print 2 top classes.
            save_to (str): directory to save every step plots to.
        Return:
            scores (np.ndarray): Array containing scores at every step.
        """
        # Get the prediction pixels by taking the max logit across the 21 classes
        n_samples = img_tensor.shape[0]
        pred, explanation = torch.max(self.model(img_tensor.cuda()), (1))
        evaluator = Evaluator(21)
        # The number of steps is pixel by pixel, TODO is to change this to a set of pixels grouped around the chosen pixel
        n_steps = (
            HW + self.step - 1
        ) // self.step  # HW is the area of the images, they made it a global variable....

        start = img_tensor.clone()  # original input
        finish = self.substrate_fn(img_tensor)  # aim to end with all 0's
        # miou = np.empty(n_steps + 1)
        miou = np.empty((n_samples, n_steps + 1))

        # Coordinates of pixels in order of decreasing saliency
        # orders the pixels from most important to least by providing the index
        salient_order = np.flip(np.argsort(explanation.reshape(
            -1, HW).detach().cpu().numpy(),
                                           axis=1),
                                axis=-1)  # the indexes of proper order
        r = np.arange(n_samples).reshape(n_samples, 1)
        for i in tqdm(range(n_steps + 1), desc='Deleting Pixels'):
            pred, explanation = torch.max(self.model(start.cuda()), (1))
            for j in range(n_samples):
                evaluator.add_batch(target[j].detach().cpu().numpy(),
                                    explanation[j].detach().cpu().numpy())
                score = evaluator.Mean_Intersection_over_Union()
                evaluator.reset()
                miou[j][i] = score
            if i < n_steps:
                # coords = salient_order[:, self.step * i:self.step * (i + 1)]
                # start.cpu().numpy().reshape(1, 3, HW)[0, :, coords] = finish.cpu().numpy().reshape(1, 3, HW)[0, :, coords]
                coords = salient_order[:, self.step * i:self.step * (i + 1)]
                start.cpu().numpy().reshape(
                    n_samples, 3, HW)[r, :,
                                      coords] = finish.cpu().numpy().reshape(
                                          n_samples, 3, HW)[r, :, coords]
        robust_loss = [self.get_gradients(miou[x]) for x in range(n_samples)]
        return robust_loss
Example 5
    def train_one_epoch(self, epoch):

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()

        epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                  self.run_config.epochs)
        iter_per_epoch = len(self.run_config.train_loader)
        end = time.time()
        for i, (datas, targets) in enumerate(self.run_config.train_loader):
            if torch.cuda.is_available():
                datas = datas.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
            else:
                raise ValueError('do not support cpu version')
            data_time.update(time.time() - end)
            logits = self.model(datas)
            loss = self.criterion(logits, targets)
            evaluator = Evaluator(self.run_config.nb_classes)
            evaluator.add_batch(targets, logits)
            acc = evaluator.Pixel_Accuracy()
            miou = evaluator.Mean_Intersection_over_Union()
            fscore = evaluator.Fx_Score()
            # metrics update
            losses.update(loss.data.item(), datas.size(0))
            accs.update(acc.item(), datas.size(0))
            mious.update(miou.item(), datas.size(0))
            fscores.update(fscore.item(), datas.size(0))
            # update parameters
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()  # only update network_weight_parameters
            # elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # within one epoch, per iteration print train_log
            if (i + 1) % self.run_config.train_print_freq == 0 or (
                    i + 1) == len(self.run_config.train_loader):
                #print(i+1, self.run_config.train_print_freq)
                Wstr = '|*TRAIN*|' + time_string(
                ) + '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1,
                                                      iter_per_epoch)
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]' \
                    .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr,
                                mode='retrain')
        return losses.avg, accs.avg, mious.avg, fscores.avg
Example 6
File: utils.py Project: wtsitp/DFQ
def forward_all(net_inference, dataloader, visualize=False, opt=None):
    evaluator = Evaluator(21)
    evaluator.reset()
    with torch.no_grad():
        for ii, sample in enumerate(dataloader):
            image, label = sample['image'].cuda(), sample['label'].cuda()

            activations = net_inference(image)

            image = image.cpu().numpy()
            label = label.cpu().numpy().astype(np.uint8)

            # some models return a dict of intermediate outputs; take the last entry
            logits = activations if isinstance(activations, torch.Tensor) \
                else list(activations.values())[-1]
            pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8)

            evaluator.add_batch(label, pred)

            # print(label.shape, pred.shape)
            if visualize:
                for jj in range(sample["image"].size()[0]):
                    segmap_label = decode_segmap(label[jj], dataset='pascal')
                    segmap_pred = decode_segmap(pred[jj], dataset='pascal')

                    img_tmp = np.transpose(image[jj], axes=[1, 2, 0])
                    img_tmp *= (0.229, 0.224, 0.225)
                    img_tmp += (0.485, 0.456, 0.406)
                    img_tmp *= 255.0
                    img_tmp = img_tmp.astype(np.uint8)

                    cv2.imshow('image', img_tmp[:, :, [2, 1, 0]])
                    cv2.imshow('gt', segmap_label)
                    cv2.imshow('pred', segmap_pred)
                    cv2.waitKey(0)

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    print("Acc: {}".format(Acc))
    print("Acc_class: {}".format(Acc_class))
    print("mIoU: {}".format(mIoU))
    print("FWIoU: {}".format(FWIoU))
    if opt is not None:
        with open("seg_result.txt", 'a+') as ww:
            ww.write(
                "{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n"
                .format(opt.dataset, opt.quantize, opt.relu, opt.equalize,
                        opt.absorption, opt.correction, opt.clip_weight,
                        opt.distill_range))
            ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(
                Acc, Acc_class, mIoU, FWIoU))
Example 7
def train(train_queue, model, criterion, optimizer):
    objs = utils.AvgrageMeter()
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()
    model.train()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    for step, (input, target) in enumerate(train_queue):
        input = input.to(device)
        target = target.to(device)
        n = input.size(0)

        optimizer.zero_grad()
        logits = model(input)
        if isinstance(logits, tuple):
            # models trained with an auxiliary head return (logits, logits_aux)
            logits, logits_aux = logits
        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)

        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example 8
class Eval(object):
    def __init__(self, args):
        self.args = args
        self.evaluator = Evaluator(args.nclass)
        self.loader = Loader(args)

    def evaluation(self):
        self.evaluator.reset()
        tbar = tqdm(self.loader)
        for i, sample in enumerate(tbar):
            names = sample['name']
            preds = sample['pred']
            labels = sample['label']
            self.evaluator.add_batch(labels, preds)

        miou = self.evaluator.Mean_Intersection_over_Union()
        fwiou = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("mIoU:", miou)
        print("fwIoU:", fwiou)
Example 9
def infer(test_queue, model, criterion):
    objs = util.AvgrageMeter()
    accs = util.AvgrageMeter()
    MIoUs = util.AvgrageMeter()
    fscores = util.AvgrageMeter()
    model.eval()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    save_path = args.model_path[:-10] + 'predict'
    print(save_path)
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    for step, (input, target, data_list) in enumerate(test_queue):
        input = input.to(device)
        target = target.to(device)
        n = input.size(0)

        logits = model(input)
        util.save_pred_WHU(logits, save_path, data_list)

        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)

        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('test %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example 10
def test(model_path):
    args = makeargs()
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
    print('Loading model...')
    model = DeepLab(num_classes=8, backbone='drn', output_stride=args.output_stride,
                    sync_bn=args.sync_bn, freeze_bn=args.freeze_bn)
    model.eval()
    checkpoint = torch.load(model_path)
    model = model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
    print('Done')
    criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
    evaluator = Evaluator(nclass)
    evaluator.reset()

    print('Model infering')
    test_dir = 'test_example1'
    test_loss = 0.0
    tbar = tqdm(test_loader, desc='\r')
    for i, sample in enumerate(tbar):
        image, target = sample['image'], sample['label']
        image, target = image.cuda(), target.cuda()

        with torch.no_grad():  #
            output = model(image)
        loss = criterion(output, target)
        test_loss += loss.item()
        tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        evaluator.add_batch(target, pred)


    print(image.shape)
    Acc = evaluator.Pixel_Accuracy()
    mIoU = evaluator.Mean_Intersection_over_Union()
    print('testing:')
    print("Acc:{}, mIoU:{},".format(Acc, mIoU))
    print('Loss: %.3f' % test_loss)
Example 11
def infer(valid_queue, model, criterion):
    objs = utils.AvgrageMeter()
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()
    model.eval()

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    for step, (input, target) in enumerate(valid_queue):

        input = input.to(device)
        target = target.to(device)

        logits = model(input)
        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)

        #prec = utils.Accuracy(logits, target)
        #prec1 = utils.MIoU(logits, target, dataset_classes)
        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        n = input.size(0)

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('valid %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example 12
def main():
    args = parse_args()

    if args.dataset == 'CamVid':
        num_class = 32
    elif args.dataset == 'Cityscapes':
        num_class = 19

    if args.net == 'resnet101':
        blocks = [2, 4, 23, 3]
        model = FPN(blocks, num_class, back_bone=args.net)

    if args.checkname is None:
        args.checkname = 'fpn-' + str(args.net)

    evaluator = Evaluator(num_class)

    # Trained model path and name
    experiment_dir = args.experiment_dir
    load_name = os.path.join(experiment_dir, 'checkpoint.pth.tar')

    # Load trained model
    if not os.path.isfile(load_name):
        raise RuntimeError("=> no checkpoint found at '{}'".format(load_name))
    print('====>loading trained model from ' + load_name)
    checkpoint = torch.load(load_name)
    checkepoch = checkpoint['epoch']
    # the same state dict is loaded regardless of device
    model.load_state_dict(checkpoint['state_dict'])

    # Load image and save in test_imgs
    test_imgs = []
    test_label = []
    if args.dataset == "CamVid":
        root_dir = Path.db_root_dir('CamVid')
        test_file = os.path.join(root_dir, "val.csv")
        test_data = CamVidDataset(csv_file=test_file, phase='val')
        val_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    elif args.dataset == "Cityscapes":
        kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
        #_, test_loader, _, _ = make_data_loader(args, **kwargs)
        _, val_loader, test_loader, _ = make_data_loader(args, **kwargs)
    else:
        raise RuntimeError("dataset {} not found.".format(args.dataset))

    # test
    results = []
    for iter, batch in enumerate(val_loader):
        if args.dataset == 'CamVid':
            image, target = batch['X'], batch['l']
        elif args.dataset == 'Cityscapes':
            image, target = batch['image'], batch['label']
        else:
            raise NotImplementedError

        if args.cuda:
            image, target, model = image.cuda(), target.cuda(), model.cuda()
        with torch.no_grad():
            output = model(image)
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        target = target.cpu().numpy()
        evaluator.add_batch(target, pred)

        # show result
        pred_rgb = decode_seg_map_sequence(pred, args.dataset, args.plot)
        results.append(pred_rgb)

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

    print('Mean evaluate result on dataset {}'.format(args.dataset))
    print('Acc:{:.3f}\tAcc_class:{:.3f}\nmIoU:{:.3f}\tFWIoU:{:.3f}'.format(Acc, Acc_class, mIoU, FWIoU))
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            size = 512
            train_file = os.path.join(os.getcwd(), "data", "CamVid", "train.csv")
            val_file = os.path.join(os.getcwd(), "data", "CamVid", "val.csv")
            print('=>loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(train_data,
                                                     batch_size=args.batch_size,
                                                     shuffle=True,
                                                     num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(val_data,
                                                     batch_size=args.batch_size,
                                                     shuffle=True,
                                                     num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2,4,23,3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            optimizer = torch.optim.Adam(fpn.parameters(), lr=args.lr, weight_decay=args.weight_decay)  # Adam takes no momentum argument
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr, momentum=0, weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        elif args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset, args.checkname)
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')))
            run_id = int(runs[-1].split('_')[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir, 'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir,
                                 'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError("=> no checkpoint found at '{}'".format(load_name))
            checkpoint = torch.load(load_name)
            args.start_epoch = checkpoint['epoch']
            # identical load path regardless of device
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(load_name, checkpoint['epoch']))

        self.lr_stage = [68, 93]
        self.lr_staget_ind = 0


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        # tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        if self.lr_staget_ind > 1 and epoch % (self.lr_stage[self.lr_staget_ind]) == 0:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_staget_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            inputs = Variable(image)
            labels = Variable(target)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels.long())
            loss_val = loss.item()
            loss.backward(torch.ones_like(loss))
            # loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # tbar.set_description('\rTrain loss:%.3f' % (train_loss / (iteration + 1)))

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, len(self.train_loader), loss.data, self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)

            #if iteration % (num_img_tr // 10) == 0:
            #    global_step = iteration + num_img_tr * epoch
            #    self.summary.visualize_image(self.witer, self.args.dataset, image, target, outputs, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for iter, batch in enumerate(self.val_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f ' % (test_loss / (iter + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iter * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
class Test:
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.target = config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='val',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='val',
                                     img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())

        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def save_output(module, input, output):
        global activation, i
        # save output
        print('I came here')
        channels = output.permute(1, 0, 2, 3)
        c = channels.shape[0]
        features = channels.reshape(c, -1)
        if len(activation) == i:
            activation.append(features)
        else:
            activation[i] = torch.cat([activation[i], features], dim=1)
        i += 1
        return

    def get_performance(self, dataloader, trainloader, city):

        if 1:
            pix = torch.load("pix" + self.config.dataset + "_" + city + ".pt")
            for k in range(0, len(pix)):
                fig = plt.figure()
                plt.hist(pix[k])
                plt.xlabel('Activation values')
                plt.ylabel("Count")
                #plt.legend()
                fig.savefig('./train_log/figs/pix_' + self.config.dataset +
                            '_' + city + 'act' + str(k) + '.png')
            return [], [], []

        # change mean and var of bn to adapt to the target domain
        if self.bn and city != self.config.dataset:
            self.checkpoint = torch.load('./train_log/' + self.config.dataset +
                                         '_da_' + city + '.pth')
            self.model.load_state_dict(self.checkpoint)
            if self.cuda:
                self.model = self.model.cuda()

        if 0:  #self.bn and city != self.config.dataset:
            print('BN Adaptation on' + city)
            self.model.train()
            tbar = tqdm(dataloader, desc='\r')
            for sample in trainloader:
                image, target = sample['image'], sample['label']
                if self.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)

        batch = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')
        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        global randCh
        #randCh={};
        global first
        first = 1
        global ii
        ii = 0
        global ncity
        ncity = city
        randCh = torch.load('randCh.pt')
        #if city != self.config.dataset:
        #randCh = torch.load('randCh.pt')
        #else:
        #   randCh={};
        layr = 0
        aa = 0
        if city == self.config.dataset:
            for hh in self.model.modules():
                if isinstance(hh, nn.ReLU6):  #Conv2d):
                    layr += 1
                    if layr % 5 == 0:
                        #if first==1 and city == self.config.dataset:
                        #randCh[aa] =  np.random.randint(hh.out_channels)
                        #print(layr)
                        hh.register_forward_hook(save_output2)
                        aa += 1

    # evaluate on the test dataset
        pix = {}
        pix1 = {}
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

                #manipulate activations here
                for k in activation.keys():
                    if first == 1:
                        pix[k] = []
                        pix1[k] = []
                    for row in range(0, activation[k].shape[1]):
                        actkrow = activation[k][:, row, :-1].reshape(
                            -1).cpu().numpy()
                        pix[k] = np.hstack((pix[k], actkrow))
                        actkrow1 = activation[k][:, row,
                                                 1:].reshape(-1).cpu().numpy()
                        pix1[k] = np.hstack((pix1[k], actkrow1))
                    for bb in range(0, activation[k].size(0)):
                        cv2.imwrite(
                            os.path.join(
                                save_path, 'act' + str(k) + 'im' + str(i) +
                                'b' + str(bb) + city + '.jpg'),
                            activation[k][bb, :].cpu().numpy() * 255)

                first += 1
                ii = 0

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batch > 0:
                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        corrVal = {}
        for k in activation.keys():
            corrVal[k] = np.corrcoef(pix[k], pix1[k])[0, 1]
            #_ = plt.hist(pix[k]); plt.show()
        torch.save(pix, "pix" + self.config.dataset + "_" + city + ".pt")
        print(corrVal)

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            name = 'train_log/test_bn.json'
        else:
            name = 'train_log/test.json'

        with open(name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        for i, img in enumerate(imgs):
            img = img[:, :, ::-1]  # change to BGR
            #from IPython import embed
            #embed()
            if not if_original:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Pred.jpg'),
                    img)
            else:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Original.jpg'),
                    img)

    def color_images(self, pred, target):
        imgs = []
        for p, t in zip(pred, target):
            tmp = p * 2 + t
            np.squeeze(tmp)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg:negative, building:postive
            #from IPython import embed
            #embed()
            img[np.where(tmp == 0)] = [0, 0, 0]  # Black RGB, for true negative
            img[np.where(tmp == 1)] = [255, 0,
                                       0]  # Red RGB, for false negative
            img[np.where(tmp == 2)] = [0, 255,
                                       0]  # Green RGB, for false positive
            img[np.where(tmp == 3)] = [255, 255,
                                       0]  #Yellow RGB, for true positive
            imgs.append(img)
        return imgs
Example 15
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset)[0], args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if os.path.isfile(args.resume):
                checkpoint = torch.load(args.resume,map_location=torch.device('cpu'))
                args.start_epoch = checkpoint['epoch']
                if args.cuda:
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])
                if not args.ft:
                    self.optimizer.load_state_dict(checkpoint['optimizer'])
                # self.best_pred = checkpoint['best_pred']-0.3
                print("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target, weight = sample['image'], sample['label'], sample['weight']
            if self.args.cuda:
                image, target, weight = image.cuda(), target.cuda(), weight.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = 0
            for index in range(output.shape[0]):
                temp1 = output[index].unsqueeze(0)
                temp2 = target[index].unsqueeze(0)
                loss = loss + weight[index, 0, 0] * self.criterion(temp1, temp2)
            loss.backward() 
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        # self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)



    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target= sample['image'], sample['label']#, sample['weight']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)


            # for channels in range(target.shape[0]):
            #     imagex = image[channels].cpu().numpy()
            #     imagex = np.transpose(imagex,(1,2,0))
            #     pre = pred[channels]
            #     targ = target[channels]

            #     plt.subplot(131)
            #     plt.imshow(imagex)

            #     plt.subplot(132)
            #     image1 = imagex.copy()
            #     for i in [0,1] :
            #         g = image1[:,:,i]
            #         g[pre>0.5] = 255
            #         image1[:,:,i] = g
            #     for i in [2]:
            #         g = image1[:,:,i]
            #         g[pre>0.5] = 0
            #         image1[:,:,i] = g
            #     plt.imshow(image1)

            #     plt.subplot(133)
            #     image2 = imagex.copy()
            #     for i in [0,1] :
            #         g = image2[:,:,i]
            #         g[targ>0.5] = 255
            #         image2[:,:,i] = g
            #     for i in [2]:
            #         g = image2[:,:,i]
            #         g[targ>0.5] = 0
            #         image2[:,:,i] = g
            #     plt.imshow(image2)

            #     plt.show()


            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        xy_mIoU = self.evaluator.xy_Mean_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print("min_mIoU{}".format(xy_mIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = xy_mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)
        #2 resnet 16 False False

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            #image, target = sample['image'], sample['label']
            image, target = sample['trace'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']
            #image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
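A minimal driver loop for the Trainer above (a sketch; args.epochs is assumed to exist alongside the args attributes already used by the class):

trainer = Trainer(args)
for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
    trainer.training(epoch)
    if not trainer.args.no_val:
        trainer.validation(epoch)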
Esempio n. 17
0
class Trainer(object):
    def __init__(self, config):

        self.config = config
        self.best_pred = 0.0

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()
        
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)
        
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=self.config['network']['sync_bn'],
                        freeze_bn=self.config['network']['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'],
                                    weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if self.config['training']['use_balanced_weights']:
            classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'],
                                            self.config['training']['epochs'], len(self.train_loader))


        # Using cuda
        if self.config['network']['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint

        if self.config['training']['weights_initialization']['use_pretrained_weights']:
            if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(self.config['training']['weights_initialization']['restore_from']))

            if self.config['network']['use_cuda']:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'])
            else:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'})

            self.config['training']['start_epoch'] = checkpoint['epoch']

            self.model.load_state_dict(checkpoint['state_dict'])

#            if not self.config['ft']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch']))


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'], image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        #save last checkpoint
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
#            'state_dict': self.model.module.state_dict(),
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best = False, filename='checkpoint_last.pth.tar')

        # if training on a subset, reshuffle the data
        if self.config['training']['train_on_subset']['enabled']:
            self.train_loader.dataset.shuffle_dataset()    


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
#                'state_dict': self.model.module.state_dict(),
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            },  is_best = True, filename='checkpoint_best.pth.tar')
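A sketch of how the config-driven Trainer above might be launched; the YAML loading and the 'epochs'/'start_epoch' keys are assumptions inferred from the dictionary accesses in __init__:

import yaml

with open('config.yaml') as f:
    config = yaml.safe_load(f)

trainer = Trainer(config)
start = config['training'].get('start_epoch', 0)
for epoch in range(start, config['training']['epochs']):
    trainer.training(epoch)
    trainer.validation(epoch)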
Esempio n. 18
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(cuda=args.cuda)
        self.model, self.optimizer = model, optimizer
        self.contexts = TemporalContexts(history_len=5)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning or in validation/test mode
        if args.ft or args.mode == "val" or args.mode == "test":
            args.start_epoch = 0
            self.best_pred = 0.0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image, region_prop)
            loss = self.criterion.CrossEntropyLoss(
                output,
                target,
                weight=torch.from_numpy(
                    calculate_weights_batch(sample,
                                            self.nclass).astype(np.float32)))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            pred = output.clone().data.cpu()
            pred_softmax = F.softmax(pred, dim=1).numpy()
            pred = np.argmax(pred.numpy(), axis=1)

            # Plot prediction every 20th iter
            if i % (num_img_tr // 20) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.vis_grid(self.writer,
                                      self.args.dataset,
                                      image.data.cpu().numpy()[0],
                                      target.data.cpu().numpy()[0],
                                      pred[0],
                                      region_prop.data.cpu().numpy()[0],
                                      pred_softmax[0],
                                      global_step,
                                      split="Train")

        self.writer.add_scalar('train/total_loss_epoch',
                               train_loss / num_img_tr, epoch)
        print('Loss: {}'.format(train_loss / num_img_tr))

        if self.args.no_val or self.args.save_all:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                is_best,
                filename='checkpoint_' + str(epoch + 1) + '_.pth.tar')

    def validation(self, epoch):

        if self.args.mode == "train" or self.args.mode == "val":
            loader = self.val_loader
        elif self.args.mode == "test":
            loader = self.test_loader

        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(loader, desc='\r')

        test_loss = 0.0
        idr_thresholds = [0.20, 0.30, 0.40, 0.50, 0.60, 0.65]

        num_itr = len(loader)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            # orig_region_prop = region_prop.clone()
            # region_prop = self.contexts.temporal_prop(image.numpy(),region_prop.numpy())

            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()
            with torch.no_grad():
                output = self.model(image, region_prop)

            # loss = self.criterion.CrossEntropyLoss(output,target,weight=torch.from_numpy(calculate_weights_batch(sample,self.nclass).astype(np.float32)))
            # test_loss += loss.item()
            # tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            output = output.detach().data.cpu()
            pred_softmax = F.softmax(output, dim=1).numpy()
            pred = np.argmax(pred_softmax, axis=1)
            target = target.cpu().numpy()
            image = image.cpu().numpy()
            region_prop = region_prop.cpu().numpy()
            # orig_region_prop = orig_region_prop.numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # Append buffer with original context(before temporal propagation)
            # self.contexts.append_buffer(image[0],orig_region_prop[0],pred[0])

            global_step = i + num_itr * epoch
            self.summary.vis_grid(self.writer,
                                  self.args.dataset,
                                  image[0],
                                  target[0],
                                  pred[0],
                                  region_prop[0],
                                  pred_softmax[0],
                                  global_step,
                                  split="Validation")

        # Fast test during the training
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        recall, precision = self.evaluator.pdr_metric(class_id=2)
        idr_avg = np.array([
            self.evaluator.get_idr(class_value=2, threshold=value)
            for value in idr_thresholds
        ])
        false_idr = self.evaluator.get_false_idr(class_value=2)
        instance_iou = self.evaluator.get_instance_iou(threshold=0.20,
                                                       class_value=2)

        # self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Recall/per_epoch', recall, epoch)
        self.writer.add_scalar('IDR/per_epoch(0.20)', idr_avg[0], epoch)
        self.writer.add_scalar('IDR/avg_epoch', np.mean(idr_avg), epoch)
        self.writer.add_scalar('False_IDR/epoch', false_idr, epoch)
        self.writer.add_scalar('Instance_IOU/epoch', instance_iou, epoch)
        self.writer.add_histogram(
            'Prediction_hist',
            self.evaluator.pred_labels[self.evaluator.gt_labels == 2], epoch)

        print('Validation:')
        # print('Loss: %.3f' % test_loss)
        # print('Recall/PDR:{}'.format(recall))
        print('IDR:{}'.format(idr_avg[0]))
        print('False Positive Rate: {}'.format(false_idr))
        print('Instance_IOU: {}'.format(instance_iou))

        if self.args.mode == "train":
            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)

        else:
            pass
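A possible entry point for the mode-aware Trainer above (a sketch; only attributes already referenced in the class, such as args.mode, args.start_epoch, args.epochs and args.no_val, are assumed):

trainer = Trainer(args)
if args.mode in ("val", "test"):
    trainer.validation(0)  # single evaluation pass over the selected split
else:
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)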
Esempio n. 19
0
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum, adv_loss_sum, d_loss_sum, ins_loss_sum = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        self.model.train()
        if self.config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
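        # pull target-domain batches in lock-step with the source batches, restarting the iterator when it is exhausted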
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            #if self.visdom:
            #    self.vis.line(X=torch.tensor([itr]), Y=torch.tensor([self.optimizer.param_groups[0]['lr']]),
            #              win='lr', opts=dict(title='lr', xlabel='iter', ylabel='lr'),
            #              update='append' if itr>0 else None)
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image, B_target, B_image_pair = target_sample[
                'image'], target_sample['label'], target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = B_image.cuda(
                ), B_target.cuda(), B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            #B_output_pair, B_feat_pair, B_low_feat_pair = self.model(B_image_pair)
            #B_output_pair, B_feat_pair, B_low_feat_pair = flip(B_output_pair, dim=-1), flip(B_feat_pair, dim=-1), flip(B_low_feat_pair, dim=-1)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # Train seg network
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Unsupervised loss
            #ins_loss = 0.01 * self.instance_loss(B_output, B_output_pair)
            #main_loss += ins_loss

            # Train adversarial loss
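            # push target outputs towards the source label so the segmentation network fools D
            # (prob_2_entropy presumably turns softmax maps into pixel-wise self-information)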
            D_out = self.D(prob_2_entropy(F.softmax(B_output, dim=1)))
            adv_loss = bce_loss(D_out, self.source_label)

            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # Train discriminator
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source
            D_source = self.D(prob_2_entropy(F.softmax(A_output_detach, dim=1)))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target
            D_target = self.D(prob_2_entropy(F.softmax(B_output_detach, dim=1)))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            #ins_loss_sum += ins_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()

            #train_loss += seg_loss.item() + self.config.lambda_adv * adv_loss.item()
            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(),
                                           itr)
            #self.summary.writer.add_scalar('Train/InsLoss', ins_loss.item(), itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss',
                                           d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            # Show the results of the last iteration
            #if i == len(self.train_loader)-1:
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        #print('Seg Loss: %.3f' % seg_loss_sum)
        #print('Ins Loss: %.3f' % ins_loss_sum)
        #print('BN Loss: %.3f' % bn_loss_sum)
        #print('Adv Loss: %.3f' % adv_loss_sum)
        #print('Discriminator Loss: %.3f' % d_loss_sum)

        #if self.visdom:
        #self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([seg_loss_sum]), win='train_loss', name='Seg_loss',
        #              opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #              update='append' if epoch > 0 else None)
        #self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([ins_loss_sum]), win='train_loss', name='Ins_loss',
        #              opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #              update='append' if epoch > 0 else None)
        #self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([bn_loss_sum]), win='train_loss', name='BN_loss',
        #              opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #              update='append' if epoch > 0 else None)
        #self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([adv_loss_sum]), win='train_loss', name='Adv_loss',
        #              opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #              update='append' if epoch > 0 else None)
        #self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([d_loss_sum]), win='train_loss', name='Dis_loss',
        #              opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
        #              update='append' if epoch > 0 else None)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            #feat_mean, low_feat_mean, feat_var, low_feat_var = 0, 0, 0, 0
            #adv_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output, low_feat, feat = self.model(image)

                #low_feat = low_feat.cpu().numpy()
                #feat = feat.cpu().numpy()

                #if isinstance(feat, np.ndarray):
                #    feat_mean += feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean += low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var += feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var += low_feat.var(axis=0).var(axis=1).var(axis=1)
                #else:
                #    feat_mean = feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean = low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var = feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var = low_feat.var(axis=0).var(axis=1).var(axis=1)

                #d_output = self.D(prob_2_entropy(F.softmax(output)))
                #adv_loss += bce_loss(d_output, self.source_label).item()
                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()

                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target_, pred)
            if if_source:
                print("Add Validation-Source images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset,
                                             image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target,
                                             image, target, output, epoch, 5)
            #feat_mean /= (i+1)
            #low_feat_mean /= (i+1)
            #feat_var /= (i+1)
            #low_feat_var /= (i+1)
            #adv_loss /= (i+1)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)
            # Draw Visdom
            #if if_source:
            #    names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
            #else:
            #    names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']

            #if self.visdom:
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([test_loss]), win='val_loss', name=names[0],
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([adv_loss]), win='val_loss', name='adv_loss',
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([Acc]), win='metrics', name=names[1],
            #                  opts=dict(title='metrics', xlabel='epoch', ylabel='performance'),
            #                  update='append' if epoch > 0 else None)
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([IoU]), win='metrics', name=names[2],
            #                  update='append')
            #    self.vis.line(X=torch.tensor([epoch]), Y=torch.tensor([mIoU]), win='metrics', name=names[3],
            #                  update='append')

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        new_pred_source = s_iou
        new_pred_target = t_iou

        if new_pred_source > self.best_pred_source or new_pred_target > self.best_pred_target:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
        print('Saving state, epoch:', epoch)
        torch.save(
            self.model.module.state_dict(),
            self.args.save_folder + 'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {
            's_Acc': s_acc,
            's_IoU': s_iou,
            's_mIoU': s_miou,
            't_Acc': t_acc,
            't_IoU': t_iou,
            't_mIoU': t_miou
        }
        with open(
                os.path.join(self.args.save_folder, 'eval',
                             'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
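The adversarial terms above call a bce_loss(prediction, label) helper with a scalar domain label (self.source_label = 0, self.target_label = 1). A common implementation of such a helper, shown here only as an assumption about what it does, broadcasts the scalar label over the discriminator output:

import torch
import torch.nn.functional as F

def bce_loss(pred, label_value):
    # expand the scalar domain label (0 = source, 1 = target) to the shape of the discriminator output
    target = torch.full_like(pred, float(label_value))
    return F.binary_cross_entropy_with_logits(pred, target)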
Esempio n. 20
0
class Test:
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.target = config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root,
                                            source_dist=dist[config.dataset])
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root,
                                     source_dist=dist[city])
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root,
                                      source_dist=dist[city])
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=False)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def get_performance(self, dataloader, trainloader, city):
        # change mean and var of bn to adapt to the target domain
        if self.bn and city != self.config.dataset:
            print('BN Adaptation on ' + city)
            self.model.train()
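            # forward passes in train() mode only refresh the BatchNorm running statistics for the target city; no parameters are updated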
            for sample in trainloader:
                image, target = sample['image'], sample['label']
                if self.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)
                    if len(output) > 1:
                        output = output[0]

        batch = self.save_batch
        if batch < 0:
            batch = len(dataloader)
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        # evaluate on the test dataset
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                if len(output) > 1:
                    output = output[0]
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batch > 0:
                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            name = 'train_log/test_bn.json'
        else:
            name = 'train_log/test.json'

        with open(name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        for i, img in enumerate(imgs):
            img = img[:, :, ::-1]  # change to BGR
            #from IPython import embed
            #embed()
            if if_original:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Original.jpg'),
                    img)
            else:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Pred.jpg'), img)

    def color_images(self, pred, target):
        imgs = []
        for p, t in zip(pred, target):
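            # encode each pixel as pred * 2 + target: 0 = true negative, 1 = false negative, 2 = false positive, 3 = true positive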
            tmp = p * 2 + t
            tmp = np.squeeze(tmp)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg: negative, building: positive
            #from IPython import embed
            #embed()
            img[np.where(tmp == 0)] = [0, 0, 0]  # Black RGB, for true negative
            img[np.where(tmp == 1)] = [255, 0,
                                       0]  # Red RGB, for false negative
            img[np.where(tmp == 2)] = [0, 255,
                                       0]  # Green RGB, for false positive
            img[np.where(tmp == 3)] = [255, 255,
                                       0]  #Yellow RGB, for true positive
            imgs.append(img)
        return imgs
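The Test class above can be driven with a few lines; the checkpoint path and save options below are placeholders:

tester = Test(model_path='train_log/epoch49.pth',
              config=config,
              bn=True,
              save_path='test_vis',
              save_batch=2,
              cuda=torch.cuda.is_available())
tester.test()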
Esempio n. 21
0
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * 10
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('train_log')
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            if self.visdom:
                self.vis.line(
                    X=torch.tensor([itr]),
                    Y=torch.tensor([self.optimizer.param_groups[0]['lr']]),
                    win='lr',
                    opts=dict(title='lr', xlabel='iter', ylabel='lr'),
                    update='append' if itr > 0 else None)
            A_image, A_target = sample['image'], sample['label']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source, 0)

            A_output, A_feat, A_low_feat = self.model(A_image)

            self.optimizer.zero_grad()

            # Supervised loss
            seg_loss = self.criterion(A_output, A_target)
            loss = seg_loss
            loss.backward()

            self.optimizer.step()

            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/Loss', loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Seg Loss: %.3f' % train_loss)

        if self.visdom:
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([train_loss]),
                          win='train_loss',
                          name='Seg_loss',
                          opts=dict(title='loss',
                                    xlabel='epoch',
                                    ylabel='loss'),
                          update='append' if epoch > 0 else None)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output, low_feat, feat = self.model(image)

                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()

                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)

            self.summary.writer.add_scalar('Val/Loss', test_loss / (i + 1),
                                           epoch)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)

            # Draw Visdom
            if self.visdom:
                self.vis.line(X=torch.tensor([epoch]),
                              Y=torch.tensor([test_loss]),
                              win='val_loss',
                              name=names[0],
                              update='append')
                self.vis.line(X=torch.tensor([epoch]),
                              Y=torch.tensor([Acc]),
                              win='metrics',
                              name=names[1],
                              opts=dict(title='metrics',
                                        xlabel='epoch',
                                        ylabel='performance'),
                              update='append' if epoch > 0 else None)
                self.vis.line(X=torch.tensor([epoch]),
                              Y=torch.tensor([IoU]),
                              win='metrics',
                              name=names[2],
                              update='append')
                self.vis.line(X=torch.tensor([epoch]),
                              Y=torch.tensor([mIoU]),
                              win='metrics',
                              name=names[3],
                              update='append')

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)

        new_pred_source = s_iou

        if new_pred_source > self.best_pred_source:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            print('Saving state, epoch:', epoch)
            torch.save(
                self.model.module.state_dict(), self.args.save_folder +
                'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {'s_Acc': s_acc, 's_IoU': s_iou, 's_mIoU': s_miou}
        with open(
                os.path.join(self.args.save_folder, 'eval',
                             'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
Esempio n. 22
0
evaluator = Evaluator(args.num_class)
for idx, (img, tag) in enumerate(val_data):
    predicts = ss_experiment.run_predicted_iter(img, tag)
    for i in range(args.batch_size):
        image = img[i]
        target = tag[i]
        predict = predicts[i]
        image = image.numpy()
        target = target.numpy().astype(np.uint8)
        image = image.transpose((1, 2, 0))
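        # undo the ImageNet normalization (multiply by std, add mean) and rescale to 0-255 for display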
        image *= (0.229, 0.224, 0.225)
        image += (0.485, 0.456, 0.406)
        image *= 255.0
        image = image.astype(np.uint8)
        evaluator.add_batch(target, predict)
        miou = evaluator.Mean_Intersection_over_Union()
        acc = evaluator.Pixel_Accuracy()
        target = prediction.decode_segmap(target)
        predict = prediction.decode_segmap(predict)
        fig = plt.figure()
        fig.suptitle('Image size: {} Miou: {:.4f} Pixelacc {:.4f}'.format(
            args.crop_size, miou, acc))
        plt.tight_layout()
        plt.subplot(131)
        plt.imshow(image)
        plt.axis('off')
        plt.subplot(132)
        plt.imshow(target)
        plt.axis('off')
        plt.subplot(133)
        plt.imshow(predict)
Esempio n. 23
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = True if (APEX_AVAILABLE and args.use_amp) else False
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                raise NotImplementedError
                #if so, which trainloader to use?
                # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
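        # model.weight_parameters() covers only the network weights; the architecture (alpha) parameters get their own Adam optimizer below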

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # TODO: Figure out if len(self.train_loader) should be devided by two ? in other module as well
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2'
                                           or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module,
                                  torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                # the checkpoint keys are prefixed with 'module.' by DataParallel, so strip the prefix first
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(),
                                    checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(),
                                    checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(
                    ), target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                if self.use_amp:
                    with amp.scale_loss(
                            arch_loss,
                            self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            #self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

            #torch.cuda.empty_cache()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
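All of these examples share the same Evaluator pattern: predictions are argmax-ed into label maps, accumulated into a confusion matrix with add_batch(target, pred), and pixel accuracy / mIoU are read off that matrix. A minimal sketch of such an evaluator (hypothetical; the repositories ship their own utils.metrics.Evaluator) could look like this:

import numpy as np

class ConfusionMatrixEvaluator:
    """Minimal confusion-matrix based evaluator; illustrative only."""

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class), dtype=np.int64)

    def reset(self):
        self.confusion_matrix[:] = 0

    def add_batch(self, gt_image, pre_image):
        # gt_image and pre_image are integer label maps with identical shape
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        self.confusion_matrix += count.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Mean_Intersection_over_Union(self):
        intersection = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0) - intersection)
        iou = intersection / np.maximum(union, 1)
        return np.nanmean(iou)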
Esempio n. 24
0
class Trainer(object):
    def __init__(self,
                 batch_size=32,
                 optimizer_name="Adam",
                 lr=1e-3,
                 weight_decay=1e-5,
                 epochs=200,
                 model_name="model01",
                 gpu_ids=None,
                 resume=None,
                 tqdm=None,
                 is_develop=False):
        """
        args:
            batch_size = (int) batch_size of training and validation
            lr = (float) learning rate of optimization
            weight_decay = (float) weight decay of optimization
            epochs = (int) The number of epochs of training
            model_name = (string) The name of training model. Will be folder name.
            gpu_ids = (List) List of gpu_ids. (e.g. gpu_ids = [0, 1]). Use CPU, if it is None. 
            resume = (Dict) Dict of some settings. (resume = {"checkpoint_path":PATH_of_checkpoint, "fine_tuning":True or False}). 
                     Learn from scratch, if it is None.
            tqdm = (tqdm Object) progress bar object. Set your tqdm please.
                   Don't view progress bar, if it is None.
        """
        # Set params
        self.batch_size = batch_size
        self.epochs = epochs
        self.start_epoch = 0
        self.use_cuda = (gpu_ids is not None) and torch.cuda.is_available()
        self.tqdm = tqdm
        self.use_tqdm = tqdm is not None
        # Define Utils. (No need to Change.)
        """
        These are Project Modules.
        You may not have to change these.
        
        Saver: Save model weight. / <utils.saver.Saver()>
        TensorboardSummary: Write tensorboard file. / <utils.summaries.TensorboardSummary()>
        Evaluator: Calculate some metrics (e.g. Accuracy). / <utils.metrics.Evaluator()>
        """
        ## ***Define Saver***
        self.saver = Saver(model_name, lr, epochs)
        self.saver.save_experiment_config()

        ## ***Define Tensorboard Summary***
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # ------------------------- #
        # Define Training components. (You have to Change!)
        """
        These are important settings for training.
        You have to change these.
        
        make_data_loader: This creates some <Dataloader>s. / <dataloader.__init__>
        Modeling: You have to define your Model. / <modeling.modeling.Modeling()>
        Evaluator: You have to define Evaluator. / <utils.metrics.Evaluator()>
        Optimizer: You have to define Optimizer. / <utils.optimizer.Optimizer()>
        Loss: You have to define Loss function. / <utils.loss.Loss()>
        """
        ## ***Define Dataloader***
        self.train_loader, self.val_loader, self.test_loader, self.num_classes = make_data_loader(
            batch_size, is_develop=is_develop)

        ## ***Define Your Model***
        self.model = Modeling(self.num_classes)

        ## ***Define Evaluator***
        self.evaluator = Evaluator(self.num_classes)

        ## ***Define Optimizer***
        self.optimizer = Optimizer(self.model.parameters(),
                                   optimizer_name=optimizer_name,
                                   lr=lr,
                                   weight_decay=weight_decay)

        ## ***Define Loss***
        self.criterion = SegmentationLosses(
            weight=torch.tensor([1.0, 1594.0]).cuda()).build_loss('ce')
        # self.criterion = SegmentationLosses().build_loss('focal')
        #  self.criterion = BCEDiceLoss()
        # ------------------------- #
        # Some settings
        """
        You don't have to touch bellow code.
        
        Using cuda: Enable to use cuda if you want.
        Resuming checkpoint: You can resume training if you want.
        """
        ## ***Using cuda***
        if self.use_cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=gpu_ids).cuda()

        ## ***Resuming checkpoint***
        """You can ignore bellow code."""
        self.best_pred = 0.0
        if resume is not None:
            if not os.path.isfile(resume["checkpoint_path"]):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    resume["checkpoint_path"]))
            checkpoint = torch.load(resume["checkpoint_path"])
            self.start_epoch = checkpoint['epoch']
            if self.use_cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if resume["fine_tuning"]:
                # resume params of optimizer, if run fine tuning.
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.start_epoch = 0
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume["checkpoint_path"], checkpoint['epoch']))

    def _run_epoch(self,
                   epoch,
                   mode="train",
                   leave_progress=True,
                   use_optuna=False):
        """
        run training or validation 1 epoch.
        You don't have to change almost of this method.
        
        args:
            epoch = (int) How many epochs this time.
            mode = {"train" or "val"}
            leave_progress = {True or False} Can choose whether leave progress bar or not.
            use_optuna = {True or False} Can choose whether use optuna or not.
        
        Change point (if you need):
        - Evaluation: You can change metrics of monitoring.
        - writer.add_scalar: You can change metrics to be saved in tensorboard.
        """
        # ------------------------- #
        leave_progress = leave_progress and not use_optuna
        # Initializing
        epoch_loss = 0.0
        ## Set model mode & tqdm (progress bar; it wrap dataloader)
        assert (mode == "train") or (
            mode == "val"
        ), "argument 'mode' can be 'train' or 'val.' Not {}.".format(mode)
        if mode == "train":
            data_loader = self.tqdm(
                self.train_loader,
                leave=leave_progress) if self.use_tqdm else self.train_loader
            self.model.train()
            num_dataset = len(self.train_loader)
        elif mode == "val":
            data_loader = self.tqdm(
                self.val_loader,
                leave=leave_progress) if self.use_tqdm else self.val_loader
            self.model.eval()
            num_dataset = len(self.val_loader)
        ## Reset confusion matrix of evaluator
        self.evaluator.reset()

        # ------------------------- #
        # Run 1 epoch
        for i, sample in enumerate(data_loader):
            ## ***Get Input data***
            inputs, target = sample["input"], sample["label"]
            if self.use_cuda:
                inputs, target = inputs.cuda(), target.cuda()

            ## ***Calculate Loss <Train>***
            if mode == "train":
                self.optimizer.zero_grad()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            ## ***Calculate Loss <Validation>***
            elif mode == "val":
                with torch.no_grad():
                    output = self.model(inputs)
                loss = self.criterion(output, target)
            epoch_loss += loss.item()
            ## ***Report results***
            if self.use_tqdm:
                data_loader.set_description('{} loss: {:.3f}'.format(
                    mode, epoch_loss / (i + 1)))
            ## ***Add batch results into evaluator***
            target = target.cpu().numpy()
            output = torch.argmax(output, axis=1).data.cpu().numpy()
            self.evaluator.add_batch(target, output)

        ## **********Evaluate Score**********
        """You can add new metrics! <utils.metrics.Evaluator()>"""
        # Acc = self.evaluator.Accuracy()
        miou = self.evaluator.Mean_Intersection_over_Union()

        if not use_optuna:
            ## ***Save eval into Tensorboard***
            self.writer.add_scalar('{}/loss_epoch'.format(mode),
                                   epoch_loss / (i + 1), epoch)
            # self.writer.add_scalar('{}/Acc'.format(mode), Acc, epoch)
            self.writer.add_scalar('{}/miou'.format(mode), miou, epoch)
            print('Total {} loss: {:.3f}'.format(mode,
                                                 epoch_loss / num_dataset))
            print("{0} mIoU:{1:.2f}".format(mode, miou))

        # Return score to watch. (update checkpoint or optuna's objective)
        return miou

    def run(self, leave_progress=True, use_optuna=False):
        """
        Run all epochs of training and validation.
        """
        for epoch in tqdm(range(self.start_epoch, self.epochs)):
            print(pycolor.GREEN + "[Epoch: {}]".format(epoch) + pycolor.END)

            ## ***Train***
            print(pycolor.YELLOW + "Training:" + pycolor.END)
            self._run_epoch(epoch,
                            mode="train",
                            leave_progress=leave_progress,
                            use_optuna=use_optuna)
            ## ***Validation***
            print(pycolor.YELLOW + "Validation:" + pycolor.END)
            score = self._run_epoch(epoch,
                                    mode="val",
                                    leave_progress=leave_progress,
                                    use_optuna=use_optuna)
            print("---------------------")
            if score > self.best_pred:
                print("model improve best score from {:.4f} to {:.4f}.".format(
                    self.best_pred, score))
                self.best_pred = score
                self.saver.save_checkpoint({
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    self.model.state_dict(),
                    'optimizer':
                    self.optimizer.state_dict(),
                    'best_pred':
                    self.best_pred,
                })
        self.writer.close()
        return self.best_pred
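As described in the docstrings above, this template Trainer is constructed once with its hyper-parameters and then driven through run(). A hypothetical entry point (the concrete argument values and the tqdm import are assumptions, not part of the original template) might look like:

from tqdm import tqdm

if __name__ == "__main__":
    trainer = Trainer(batch_size=16,
                      optimizer_name="Adam",
                      lr=1e-3,
                      weight_decay=1e-5,
                      epochs=50,
                      model_name="model01",
                      gpu_ids=[0],   # or None to run on the CPU
                      resume=None,   # or {"checkpoint_path": "checkpoint.pth.tar", "fine_tuning": False}
                      tqdm=tqdm)     # pass tqdm to enable progress bars
    best_miou = trainer.run(leave_progress=True)
    print("best validation mIoU:", best_miou)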
Esempio n. 25
0
class trainNew(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        #                        output_stride=args.out_stride,
        #                        sync_bn=args.sync_bn,
        #                        freeze_bn=args.freeze_bn)
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: look into these
        # TODO: ALSO look into different param groups as done int deeplab below
        #        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        #
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  #TODO: use min_lr ?

        # TODO: Figure out if len(self.train_loader) should be divided by two? (in other modules as well)
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                # the checkpoint keys are prefixed with 'module.' by DataParallel, so strip the prefix first
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            encoder_output, low_level_feature = self.model(image)
            output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                encoder_output, low_level_feature = self.model(image)
                output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
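Both this example and the previous ones call the scheduler once per iteration as self.scheduler(self.optimizer, i, epoch, self.best_pred), letting it rewrite the optimizer's learning rate in place. A rough sketch of a 'poly' variant of such a callable scheduler (an illustration of the idea, not the exact LR_Scheduler implementation used above) is:

class PolyLRScheduler:
    """Illustrative poly schedule: lr = base_lr * (1 - progress) ** power."""

    def __init__(self, base_lr, num_epochs, iters_per_epoch, power=0.9, min_lr=0.0):
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch
        self.power = power
        self.min_lr = min_lr

    def __call__(self, optimizer, i, epoch, best_pred=None):
        cur_iter = epoch * self.iters_per_epoch + i
        lr = self.base_lr * (1 - cur_iter / self.total_iters) ** self.power
        lr = max(lr, self.min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr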
Esempio n. 26
0
    def warm_up(self, warmup_epochs):
        if warmup_epochs <= 0:
            self.logger.log('=> warmup close', mode='warm')
            #print('\twarmup close')
            return
        # set optimizer and scheduler in warm_up phase
        lr_max = self.arch_search_config.warmup_lr
        data_loader = self.run_manager.run_config.train_loader
        scheduler_params = self.run_manager.run_config.optimizer_config[
            'scheduler_params']
        optimizer_params = self.run_manager.run_config.optimizer_config[
            'optimizer_params']
        momentum, nesterov, weight_decay = optimizer_params[
            'momentum'], optimizer_params['nesterov'], optimizer_params[
                'weight_decay']
        eta_min = scheduler_params['eta_min']
        optimizer_warmup = torch.optim.SGD(self.net.weight_parameters(),
                                           lr_max,
                                           momentum,
                                           weight_decay=weight_decay,
                                           nesterov=nesterov)
        # set initial_learning_rate in weight_optimizer
        #for param_groups in self.run_manager.optimizer.param_groups:
        #    param_groups['lr'] = lr_max
        lr_scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_warmup, warmup_epochs, eta_min)
        iter_per_epoch = len(data_loader)
        total_iteration = warmup_epochs * iter_per_epoch

        self.logger.log('=> warmup begin', mode='warm')

        epoch_time = AverageMeter()
        end_epoch = time.time()
        for epoch in range(self.warmup_epoch, warmup_epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Warmup Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='warm')

            lr_scheduler_warmup.step(epoch)
            warmup_lr = lr_scheduler_warmup.get_lr()
            self.net.train()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1, warmup_epochs)
            time_left = epoch_time.average * (warmup_epochs - epoch)
            common_log = '[Warmup the {:}] Left={:} LR={:}'.format(
                epoch_str,
                str(timedelta(seconds=time_left)) if epoch != 0 else None,
                warmup_lr)
            self.logger.log(common_log, mode='warm')
            end = time.time()

            for i, (datas, targets) in enumerate(data_loader):

                #if i == 29: break

                if torch.cuda.is_available():
                    datas = datas.to(self.run_manager.device,
                                     non_blocking=True)
                    targets = targets.to(self.run_manager.device,
                                         non_blocking=True)
                else:
                    raise ValueError('do not support cpu version')
                data_time.update(time.time() - end)

                # get single_path in each iteration
                _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                )
                _, aspp_index = self.net.get_aspp_hardwts_index()
                single_path = self.net.sample_single_path(
                    self.run_manager.run_config.nb_layers, aspp_index,
                    network_index)
                logits = self.net.single_path_forward(datas, single_path)

                # update without entropy reg in warm_up phase
                loss = self.run_manager.criterion(logits, targets)

                # measure metrics and update
                evaluator = Evaluator(self.run_manager.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                acc = evaluator.Pixel_Accuracy()
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                losses.update(loss.data.item(), datas.size(0))
                accs.update(acc, datas.size(0))
                mious.update(miou, datas.size(0))
                fscores.update(fscore, datas.size(0))

                self.net.zero_grad()
                loss.backward()
                self.run_manager.optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()
                if (
                        i + 1
                ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                    Wstr = '|*WARM-UP*|' + time_string(
                    ) + '[{:}][iter{:03d}/{:03d}]'.format(
                        epoch_str, i + 1, iter_per_epoch)
                    Tstr = '|Time     | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                        batch_time=batch_time, data_time=data_time)
                    Bstr = '|Base     | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'\
                        .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                    self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'warm')

            #torch.cuda.empty_cache()
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()
            '''
            # TODO: whether to perform validation after each epoch in the warmup phase?
            valid_loss, valid_acc, valid_miou, valid_fscore = self.validate()
            valid_log = 'Warmup Valid\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tF\t{5:6.4f}'\
                .format(epoch+1, warmup_epochs, valid_loss, valid_acc, valid_miou, valid_fscore)
                        #'\tflops\t{6:}M\tparams{7:}M'\
            valid_log += 'Train\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tFscore\t{5:6.4f}'
            self.run_manager.write_log(valid_log, 'valid')
            '''

            # continue warmup phase
            self.warmup = epoch + 1 < warmup_epochs
            self.warmup_epoch = self.warmup_epoch + 1
            #self.start_epoch = self.warmup_epoch
            # To save checkpoint in warmup phase at specific frequency.
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch + 1) == warmup_epochs:
                state_dict = self.net.state_dict()
                # rm architecture parameters because, in warm_up phase, arch_parameters are not updated.
                #for key in list(state_dict.keys()):
                #    if 'cell_arch_parameters' in key or 'network_arch_parameters' in key or 'aspp_arch_parameters' in key:
                #        state_dict.pop(key)
                checkpoint = {
                    'state_dict': state_dict,
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    lr_scheduler_warmup.state_dict(),
                    'warmup': self.warmup,
                    'warmup_epoch': epoch + 1,
                }
                filename = self.logger.path(mode='warm', is_best=False)
                save_path = save_checkpoint(checkpoint,
                                            filename,
                                            self.logger,
                                            mode='warm')
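The warm-up above anneals the learning rate from warmup_lr down to eta_min with a cosine schedule over the warm-up epochs. A tiny standalone sketch (the dummy parameter and the numeric values are assumptions) of how CosineAnnealingLR produces that schedule:

import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, just to build a real optimizer
lr_max, warmup_epochs, eta_min = 0.025, 5, 0.001
optimizer = torch.optim.SGD([param], lr=lr_max, momentum=0.9, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=warmup_epochs, eta_min=eta_min)

for epoch in range(warmup_epochs):
    print('warmup epoch {:d}: lr = {:.5f}'.format(epoch + 1, optimizer.param_groups[0]['lr']))
    # ... one epoch of weight-only training would run here ...
    scheduler.step()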
Esempio n. 27
0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.train_dir = './data_list/train_lite.csv'
        self.train_list = pd.read_csv(self.train_dir)
        self.val_dir = './data_list/val_lite.csv'
        self.val_list = pd.read_csv(self.val_dir)
        self.train_length = len(self.train_list)
        self.val_length = len(self.val_list)
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Approach 2
        self.train_gen, self.val_gen, self.test_gen, self.nclass = make_data_loader2(args)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)

        # Define Criterion
        # self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion1 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='ce')
        self.criterion2= SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='dice')

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, self.train_length)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # self.model.module.load_state_dict(checkpoint['state_dict'])
                self.model.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        prev_time = time.time()
        self.model.train()
        self.evaluator.reset()

        num_img_tr = self.train_length / self.args.batch_size

        for iteration in range(int(num_img_tr)):
            samples = next(self.train_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, iteration, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)


            # print log (default log_iters = 4)
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: train loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            # Show 10 * 3 inference results each epoch
            if iteration % max(1, int(num_img_tr) // 10) == 0:
                global_step = iteration + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print("input image shape/iter:", image.shape)

        # train evaluate
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc_tr:{}, Acc_class_tr:{}, IoU_tr:{}, mIoU_tr:{}, fwIoU_tr: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        val_loss = 0.0
        prev_time = time.time()
        num_img_val = self.val_length / self.args.batch_size
        print("Validation:","epoch ", epoch)
        print(num_img_val)
        for iteration in range(int(num_img_val)):
            samples = next(self.val_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():  #
                output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            val_loss += loss.item()
            self.writer.add_scalar('val/total_loss_iter', loss.item(), iteration + num_img_val * epoch)

            # print log (default log_iters = 4)
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: validation loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()


            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print(image.shape)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', val_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print("Acc_val:{}, Acc_class_val:{}, IoU:val:{}, mIoU_val:{}, fwIoU_val: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))
        print('Loss: %.3f' % val_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
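The dice loss in this example expects one-hot targets, produced by a make_one_hot helper that is not shown here. A plausible sketch of that helper (an assumption inferred from the call make_one_hot(target.long(), num_classes=self.nclass); the original version may differ, e.g. in how it treats an ignore index) is:

import torch

def make_one_hot(labels, num_classes):
    """Convert an integer label map of shape (N, H, W) into a one-hot tensor of shape (N, C, H, W)."""
    labels = labels.unsqueeze(1)                       # (N, 1, H, W)
    one_hot = torch.zeros(labels.size(0), num_classes,
                          labels.size(2), labels.size(3),
                          dtype=torch.float32, device=labels.device)
    return one_hot.scatter_(1, labels, 1.0)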
Esempio n. 28
0
    def train(self, fix_net_weights=False):

        # valid_batch_size comes from the run config; drop_last is ignored here.
        data_loader = self.run_manager.run_config.train_loader
        iter_per_epoch = len(data_loader)
        total_iteration = iter_per_epoch * self.run_manager.run_config.epochs

        if fix_net_weights:  # used to debug
            data_loader = [(0, 0)] * iter_per_epoch
            print('Train Phase close for debug')

        # arch_parameter update frequency and times in each iteration.
        #update_schedule = self.arch_search_config.get_update_schedule(iter_per_epoch)

        # pay attention here, total_epochs include warmup epochs
        epoch_time = AverageMeter()
        end_epoch = time.time()
        # TODO : use start_epochs
        for epoch in range(self.start_epoch,
                           self.run_manager.run_config.epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Train Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='search')

            self.run_manager.scheduler.step(epoch)
            train_lr = self.run_manager.scheduler.get_lr()
            arch_lr = self.arch_optimizer.param_groups[0]['lr']
            self.net.set_tau(self.arch_search_config.tau_max -
                             (self.arch_search_config.tau_max -
                              self.arch_search_config.tau_min) * (epoch) /
                             (self.run_manager.run_config.epochs))
            tau = self.net.get_tau()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            #valid_data_time = AverageMeter()
            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            self.net.train()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            time_left = epoch_time.average * (
                self.run_manager.run_config.epochs - epoch)
            common_log = '[*Train-Search* the {:}] Left={:} WLR={:} ALR={:} tau={:}'\
                .format(epoch_str, str(timedelta(seconds=time_left)) if epoch != 0 else None, train_lr, arch_lr, tau)
            self.logger.log(common_log, 'search')

            end = time.time()
            for i, (datas, targets) in enumerate(data_loader):
                #if i == 29: break
                if not fix_net_weights:
                    if torch.cuda.is_available():
                        datas = datas.to(self.run_manager.device,
                                         non_blocking=True)
                        targets = targets.to(self.run_manager.device,
                                             non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')
                    data_time.update(time.time() - end)

                    # get single_path in each iteration
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(datas, single_path)

                    # loss
                    loss = self.run_manager.criterion(logits, targets)
                    # metrics and update
                    evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    acc = evaluator.Pixel_Accuracy()
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    losses.update(loss.data.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()

                    self.run_manager.optimizer.step()

                    #end_valid = time.time()
                    valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                    if torch.cuda.is_available():
                        valid_datas = valid_datas.to(self.run_manager.device,
                                                     non_blocking=True)
                        valid_targets = valid_targets.to(
                            self.run_manager.device, non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')

                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(
                        valid_datas, single_path)

                    loss = self.run_manager.criterion(logits, valid_targets)

                    # metrics and update
                    valid_evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    valid_evaluator.add_batch(valid_targets, logits)
                    acc = valid_evaluator.Pixel_Accuracy()
                    miou = valid_evaluator.Mean_Intersection_over_Union()
                    fscore = valid_evaluator.Fx_Score()
                    valid_losses.update(loss.data.item(),
                                        valid_datas.size(0))
                    valid_accs.update(acc.item(), valid_datas.size(0))
                    valid_mious.update(miou.item(), valid_datas.size(0))
                    valid_fscores.update(fscore.item(), valid_datas.size(0))

                    self.net.zero_grad()
                    loss.backward()
                    #if (i+1) % 5 == 0:
                    #    print('network_arch_parameters.grad in search phase:\n',
                    #                         self.net.network_arch_parameters.grad)
                    self.arch_optimizer.step()

                    # batch_time of one iter of train and valid.
                    batch_time.update(time.time() - end)
                    end = time.time()
                    if (
                            i + 1
                    ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                        Wstr = '|*Search*|' + time_string(
                        ) + '[{:}][iter{:03d}/{:03d}]'.format(
                            epoch_str, i + 1, iter_per_epoch)
                        Tstr = '|Time    | {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                            batch_time=batch_time, data_time=data_time)
                        Bstr = '|Base    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=losses, acc=accs, miou=mious, fscore=fscores)
                        Astr = '|Arch    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=valid_losses,
                            acc=valid_accs,
                            miou=valid_mious,
                            fscore=valid_fscores)
                        self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr +
                                        '\n' + Astr,
                                        mode='search')

            _, network_index = self.net.get_network_arch_hardwts_with_constraint(
            )  # set self.hardwts again
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)
            cell_arch_entropy, network_arch_entropy, total_entropy = self.net.calculate_entropy(
                single_path)

            # update visdom
            if self.vis is not None:
                self.vis.visdom_update(epoch, 'loss',
                                       [losses.average, valid_losses.average])
                self.vis.visdom_update(epoch, 'accuracy',
                                       [accs.average, valid_accs.average])
                self.vis.visdom_update(epoch, 'miou',
                                       [mious.average, valid_mious.average])
                self.vis.visdom_update(
                    epoch, 'f1score', [fscores.average, valid_fscores.average])

                self.vis.visdom_update(epoch, 'cell_entropy',
                                       [cell_arch_entropy])
                self.vis.visdom_update(epoch, 'network_entropy',
                                       [network_arch_entropy])
                self.vis.visdom_update(epoch, 'entropy', [total_entropy])

            #torch.cuda.empty_cache()
            # update epoch_time
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            epoch_str = '{:03d}/{:03d}'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            log = '[{:}] train :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n' \
                  '[{:}] valid :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n'.format(
                epoch_str, losses.average, accs.average, mious.average, fscores.average,
                epoch_str, valid_losses.average, valid_accs.average, valid_mious.average, valid_fscores.average
            )
            self.logger.log(log, mode='search')

            self.logger.log(
                '<<<---------->>> Super Network decoding <<<---------->>> ',
                mode='search')
            actual_path, cell_genotypes = self.net.network_cell_arch_decode()
            #print(cell_genotypes)
            new_genotypes = []
            for _index, genotype in cell_genotypes:
                xlist = []
                print(_index, genotype)
                for edge_genotype in genotype:
                    for (node_str, select_index) in edge_genotype:
                        xlist.append((node_str, self.run_manager.run_config.
                                      conv_candidates[select_index]))
                new_genotypes.append((_index, xlist))
            log_str = 'The {:} decode network:\n' \
                      'actual_path = {:}\n' \
                      'genotype:'.format(epoch_str, actual_path)
            for _index, genotype in new_genotypes:
                log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
            self.logger.log(log_str, mode='network_space', display=False)

            # TODO: perform save the best network ckpt
            # 1. save network_arch_parameters and cell_arch_parameters
            # 2. save weight_parameters
            # 3. weight_optimizer.state_dict
            # 4. arch_optimizer.state_dict
            # 5. training process
            # 6. monitor_metric and the best_value
            # get best_monitor in valid phase.
            val_monitor_metric = get_monitor_metric(
                self.run_manager.monitor_metric, valid_losses.average,
                valid_accs.average, valid_mious.average, valid_fscores.average)
            is_best = self.run_manager.best_monitor < val_monitor_metric
            self.run_manager.best_monitor = max(self.run_manager.best_monitor,
                                                val_monitor_metric)
            # 1. if is_best : save_current_ckpt
            # 2. if can be divided : save_current_ckpt

            #self.run_manager.save_model(epoch, {
            #    'arch_optimizer': self.arch_optimizer.state_dict(),
            #}, is_best=True, checkpoint_file_name=None)
            # TODO: have modification on checkpoint_save semantics
            if (epoch + 1) % self.run_manager.run_config.save_ckpt_freq == 0 \
                    or (epoch + 1) == self.run_manager.run_config.epochs or is_best:
                checkpoint = {
                    'state_dict': self.net.state_dict(),
                    'weight_optimizer': self.run_manager.optimizer.state_dict(),
                    'weight_scheduler': self.run_manager.scheduler.state_dict(),
                    'arch_optimizer': self.arch_optimizer.state_dict(),
                    'best_monitor': (self.run_manager.monitor_metric,
                                     self.run_manager.best_monitor),
                    'warmup': False,
                    'start_epochs': epoch + 1,
                }
                checkpoint_arch = {
                    'actual_path': actual_path,
                    'cell_genotypes': cell_genotypes,
                }
                filename = self.logger.path(mode='search', is_best=is_best)
                filename_arch = self.logger.path(mode='arch', is_best=is_best)
                save_checkpoint(checkpoint,
                                filename,
                                self.logger,
                                mode='search')
                save_checkpoint(checkpoint_arch,
                                filename_arch,
                                self.logger,
                                mode='arch')
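
The tail of this example follows a common "monitor metric -> is_best -> save checkpoint" pattern. Below is a minimal standalone sketch of that pattern; the helpers get_monitor_metric and maybe_save_best mirror the calls above in name only, and their bodies are assumptions for illustration, not the repository's actual implementations.

# Minimal sketch (assumed, not the repository's code): pick the monitored value,
# track the best one seen so far, and save a checkpoint only when it improves.
import torch


def get_monitor_metric(name, loss, acc, miou, fscore):
    # assumed mapping; loss is negated so "larger is better" holds for every metric
    return {'loss': -loss, 'acc': acc, 'miou': miou, 'fscore': fscore}[name]


def maybe_save_best(model, optimizer, epoch, value, best_value, path):
    # save a checkpoint when the monitored value improves; return the new best
    is_best = value > best_value
    if is_best:
        checkpoint = {
            'state_dict': model.state_dict(),
            'weight_optimizer': optimizer.state_dict(),
            'start_epochs': epoch + 1,
        }
        torch.save(checkpoint, path)
    return max(best_value, value), is_best

In the search loop above, this would be called once per epoch with the validation averages, analogous to the is_best / best_monitor bookkeeping shown there.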
Esempio n. 29
0
class Test:
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.target = config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test_img_root = '/home/home1/swarnakr/main/DomainAdaptation/satellite/' + city + '/' + 'test'
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=test_img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train_img_root = '/home/home1/swarnakr/main/DomainAdaptation/satellite/' + city + '/' + 'train'
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=train_img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def get_performance(self, dataloader, trainloader, city):
        # change mean and var of bn to adapt to the target domain
        if self.bn and city != self.config.dataset:
            print('BN Adaptation on ' + city)
            self.model.train()

            # print layer params
            if 0:
                layr = 0
                for h in self.model.modules():
                    if isinstance(h, nn.Conv2d):
                        k1 = h.kernel_size[0]
                        k2 = h.kernel_size[1]
                        ch = h.out_channels
                        print(
                            'L={}   & ${} \\times {} \\times {}$ \\\\ \\hline'.
                            format(layr, k1, k2, ch))
                        layr += 1

            for sample in trainloader:
                image, target = sample['image'], sample['label']
                if self.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)
                #pdb.set_trace()
        batch = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        # evaluate on the test dataset
        bn = dict()
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            # save BN params once, on the first batch
            if i == 0:
                layr = 0
                for h in self.model.modules():
                    if isinstance(h, nn.Conv2d):
                        # indexing by layr is fine: each BN layer follows one Conv2d layer
                        bn[(layr, 'weight')] = torch.squeeze(h.weight.detach().cpu())
                    if isinstance(h, nn.BatchNorm2d):
                        bn[(layr, 'mean')] = h.running_mean
                        bn[(layr, 'var')] = h.running_var
                        layr += 1

                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                torch.save(bn, os.path.join(save_path, 'bnAll.pth'))

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batch > 0:
                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        # evaluate on the source domain first
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            #if city != 'Vegas':
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        res[config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            name = 'train_log/test_bn.json'
        else:
            name = 'train_log/test.json'

        with open(name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        for i, img in enumerate(imgs):
            img = img[:, :, ::-1]  # change to BGR
            if if_original:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Original.jpg'),
                    img)
            else:
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Pred.jpg'), img)

    def color_images(self, pred, target):
        imgs = []
        for p, t in zip(pred, target):
            tmp = p * 2 + t
            tmp = np.squeeze(tmp)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg: negative, building: positive
            img[np.where(tmp == 0)] = [0, 0, 0]      # black, true negative
            img[np.where(tmp == 1)] = [255, 0, 0]    # red, false negative
            img[np.where(tmp == 2)] = [0, 255, 0]    # green, false positive
            img[np.where(tmp == 3)] = [255, 255, 0]  # yellow, true positive
            imgs.append(img)
        return imgs
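
The p * 2 + t trick in color_images packs prediction and target into a single per-pixel code (0 = true negative, 1 = false negative, 2 = false positive, 3 = true positive). A tiny worked example with made-up 2x2 arrays:

# Illustration of the p * 2 + t encoding used by color_images
# (the arrays below are made up for the example).
import numpy as np

p = np.array([[0, 1], [1, 0]])  # predicted labels (0 = background, 1 = building)
t = np.array([[0, 0], [1, 1]])  # ground-truth labels
code = p * 2 + t
# code == [[0, 2], [3, 1]]
#   0 -> true negative, 2 -> false positive, 3 -> true positive, 1 -> false negative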
Esempio n. 30
0
class Test:
    def __init__(self, model_path, config, cuda=False):
        self.target = config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.target_set = []
        self.target_loader = []
        self.target_trainset = []
        self.target_trainloader = []
        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root)
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_set.append(test)
            self.target_trainset.append(train)

            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def get_performance(self, dataloader, trainloader, if_source):
        # change mean and var of bn to adapt to the target domain
        if not if_source:
            self.model.train()
            for sample in trainloader:
                image, target = sample['image'], sample['label']
                if self.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)

        # evaluate the model
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        A, I, Im = self.get_performance(self.source_loader, None, True)
        tA, tI, tIm = [], [], []
        for dl, tl in zip(self.target_loader, self.target_trainloader):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, False)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(config.dataset, A, I, Im))
        res[config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}
        with open('train_log/test_bn.json', 'w') as f:
            json.dump(res, f)
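
Both Test classes rely on the same Evaluator interface: reset(), add_batch(target, pred) on numpy label maps, and metric getters such as Mean_Intersection_over_Union(). The Evaluator itself is not shown in these examples; the sketch below is a common confusion-matrix-based implementation of that interface and is an assumption, not the exact class used here.

# Minimal confusion-matrix Evaluator sketch (assumed implementation; the
# examples above only show how the interface is called).
import numpy as np


class Evaluator:
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class), dtype=np.int64)

    def reset(self):
        self.confusion_matrix[...] = 0

    def add_batch(self, target, pred):
        # accumulate counts of (target, pred) label pairs, ignoring labels
        # outside [0, num_class)
        mask = (target >= 0) & (target < self.num_class)
        idx = self.num_class * target[mask].astype(int) + pred[mask].astype(int)
        counts = np.bincount(idx, minlength=self.num_class ** 2)
        self.confusion_matrix += counts.reshape(self.num_class, self.num_class)

    def Mean_Intersection_over_Union(self):
        cm = self.confusion_matrix
        intersection = np.diag(cm)
        union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
        iou = intersection / np.maximum(union, 1)
        return iou.mean()

Per-class scores such as Building_IoU or Building_Acc would be read off the same confusion matrix (e.g. the IoU of the building class alone) rather than averaged over classes.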