class Test:
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.target = list(config.all_dataset)  # copy so config.all_dataset is not mutated
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())

        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
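
        # Hedged note: load_state_dict above assumes the checkpoint is a raw
        # state_dict saved via torch.save(model.state_dict(), ...). If it came
        # from a DataParallel-wrapped model, the 'module.' prefixes would need
        # stripping first, e.g.:
        #   state = {k.replace('module.', '', 1): v for k, v in self.checkpoint.items()}
        #   self.model.load_state_dict(state)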

    @staticmethod
    def save_output(module, input, output):
        # forward hook: flatten each channel's activations and accumulate
        # them across batches in the global `activation` list
        global activation, i
        channels = output.permute(1, 0, 2, 3)  # (C, N, H, W)
        c = channels.shape[0]
        features = channels.reshape(c, -1)
        if len(activation) == i:
            activation.append(features)
        else:
            activation[i] = torch.cat([activation[i], features], dim=1)
        i += 1

    def get_performance(self, dataloader, trainloader, city):
        # change mean and var of bn to adapt to the target domain
        #pdb.set_trace()
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        if self.bn and city != self.config.dataset:
            print('BN Adaptation on ' + city)
            self.model.train()
            layr = 0
            for h in self.model.modules():
                if isinstance(h, nn.ReLU6):
                    layr += 1
                    if layr == 1:
                        h.register_forward_hook(save_output2)
                    if layr > 1:
                        break
            for sample in trainloader:
                image, target = sample['image'], sample['label']
                if self.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                self.save_act(activation, save_path, False)
                self.save_act(image.cpu().numpy() * 255, save_path, True)

        batch = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        layr = 0
        for h in self.model.modules():
            if isinstance(h, nn.ReLU6):
                layr += 1
                if layr == 2:
                    h.register_forward_hook(save_output2)
                if layr > 2:
                    break

        # evaluate on the test dataset
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            if not os.path.exists(save_path):
                os.mkdir(save_path)

            self.save_act(activation, save_path, False)
            self.save_act(image.cpu().numpy() * 255, save_path, True)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batch > 0:
                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU
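
    # Hedged note on the hooks registered in get_performance: calling it
    # repeatedly (once per city) stacks a fresh forward hook on the same
    # nn.ReLU6 each time, since the hooks are never removed. Keeping the
    # RemovableHandle that register_forward_hook returns and calling
    # handle.remove() after the evaluation loop would avoid duplicate captures.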

    def test(self):
        # the source-domain results are printed below, so they must be computed
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            name = 'train_log/test_bn.json'
        else:
            name = 'train_log/test.json'

        with open(name, 'w') as f:
            json.dump(res, f)

    def save_act(self, imgs, save_path, ifImage):
        if ifImage:
            for i, img in enumerate(imgs):
                img = img.transpose(1, 2, 0)
                img = img[:, :, ::-1]
                cv2.imwrite(os.path.join(save_path, 'im' + str(i) + '.jpg'),
                            img)

        else:
            for i, img in enumerate(imgs):
                for j, act in enumerate(img):
                    cv2.imwrite(
                        os.path.join(save_path,
                                     'im' + str(i) + 'act' + str(j) + '.jpg'),
                        act.cpu().numpy() * 255)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        for i, img in enumerate(imgs):
            img = img[:, :, ::-1]  # change to BGR for cv2
            if not if_original:
                # colored prediction/target comparison
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Pred.jpg'),
                    img)
            else:
                # raw input image
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Original.jpg'),
                    img)

    def color_images(self, pred, target):
        imgs = []
        for p, t in zip(pred, target):
            tmp = np.squeeze(p * 2 + t)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg: negative, building: positive
            img[np.where(tmp == 0)] = [0, 0, 0]      # black, true negative
            img[np.where(tmp == 1)] = [255, 0, 0]    # red, false negative
            img[np.where(tmp == 2)] = [0, 255, 0]    # green, false positive
            img[np.where(tmp == 3)] = [255, 255, 0]  # yellow, true positive
            imgs.append(img)
        return imgs
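

# Hedged usage sketch for the Test harness above (not from the original
# source). SimpleNamespace stands in for the project's config object; the
# attribute values below (city names, paths) are assumptions.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        all_dataset=['Vegas', 'Paris', 'Shanghai', 'Khartoum'],  # assumed city list
        dataset='Vegas',              # source domain
        img_root='data/spacenet',     # assumed image root
        backbone='resnet', out_stride=16, sync_bn=False, freeze_bn=False)
    tester = Test(model_path='train_log/best.pth',  # assumed checkpoint path
                  config=cfg, bn=True, save_path='outputs', save_batch=2,
                  cuda=torch.cuda.is_available())
    tester.test()  # writes per-city metrics to train_log/test_bn.json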
def test(model, loader, device):
    model.eval()
    all_labels = []
    all_logits = []
    all_predictions = []
    all_losses = []
    all_seg_logits_interp = []
    all_seg_preds_interp = []
    all_dices = []
    all_ious = []
    evaluator = Evaluator(ex.current_run.config['model']['num_classes'])
    image_evaluator = Evaluator(ex.current_run.config['model']['num_classes'])

    pbar = tqdm(loader, ncols=80, desc='Test')

    pooling = ex.current_run.config['model']['pooling']
    if pooling in requires_gradients:
        grad_policy = torch.set_grad_enabled(True)
    else:
        grad_policy = torch.no_grad()

    is_ae = isinstance(model.backbone, ResNet_AE)

    with grad_policy:
        for i, (image, segmentation, label) in enumerate(pbar):
            image, label = image.to(device), label.to(device)

            if pooling in requires_gradients or pooling == 'ablation':
                model.pooling.eval_cams = True

            if is_ae:
                z, x_reconst = model.backbone(image)
                logits = model.pooling(z)
            else:
                logits = model(image)

            pred = model.pooling.predictions(logits=logits).item()

            loss = model.pooling.loss(logits=logits, labels=label)

            if ex.current_run.config['dataset']['name'] == 'caltech_birds':
                segmentation_classes = (segmentation.squeeze() > 0.5)
            else:
                segmentation_classes = (segmentation.squeeze() != 0)

            seg_logits = model.pooling.cam.detach().cpu()
            seg_logits_interp = F.interpolate(seg_logits,
                                              size=segmentation_classes.shape,
                                              mode='bilinear',
                                              align_corners=True).squeeze(0)

            label = label.item()
            all_labels.append(label)
            all_logits.append(logits.cpu())
            all_predictions.append(pred)
            all_losses.append(loss.item())

            if ex.current_run.config['dataset']['name'] == 'glas':
                if ex.current_run.config['model']['pooling'] == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = (seg_logits_interp.argmax(0) == label).cpu()
            else:
                if ex.current_run.config['model']['pooling'] == 'deepmil':
                    seg_preds_interp = (seg_logits_interp.squeeze(0) >
                                        (1 / seg_logits.numel())).cpu()
                elif ex.current_run.config['model']['pooling'] == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = seg_logits_interp.argmax(0).cpu()

            # Save CAMs visualization
            save_dir = 'cams/{}/{}'.format(
                ex.current_run.config['model']['arch'] +
                str(ex.current_run.config['balance']),
                ex.current_run.config['model']['pooling'])
            os.makedirs(save_dir, exist_ok=True)
            file_path = os.path.join(save_dir, 'cam_{}.png'.format(i))
            seg_logits_interp_norm = seg_logits_interp / seg_logits_interp.max()
            saliency_map_0, overlay_0 = visualize_cam(
                seg_logits_interp_norm[0], image)
            saliency_map_1, overlay_1 = visualize_cam(
                seg_logits_interp_norm[1], image)
            overlay = [overlay_0, overlay_1][label]
            save_visualization(image.squeeze().cpu(),
                               segmentation_classes.numpy(), saliency_map_0,
                               saliency_map_1, overlay,
                               seg_preds_interp.numpy() * 255, label,
                               file_path)

            if is_ae:
                x_reconst = x_reconst.detach()
                save_dir = 'reconst/{}/{}'.format(
                    ex.current_run.config['model']['arch'],
                    ex.current_run.config['model']['pooling'])
                os.makedirs(save_dir, exist_ok=True)
                file_path = os.path.join(save_dir, 'reconst_{}.png'.format(i))
                save_reconst(
                    image.squeeze(0).cpu(),
                    x_reconst.squeeze(0).cpu(), file_path)

            all_seg_logits_interp.append(seg_logits_interp.numpy())
            all_seg_preds_interp.append(
                seg_preds_interp.numpy().astype('bool'))

            evaluator.add_batch(segmentation_classes, seg_preds_interp)
            image_evaluator.add_batch(segmentation_classes, seg_preds_interp)
            all_dices.append(image_evaluator.dice()[1].item())
            all_ious.append(
                image_evaluator.intersection_over_union()[1].item())
            image_evaluator.reset()

        if pooling in requires_gradients or pooling == 'ablation':
            model.pooling.eval_cams = False
        all_logits = torch.cat(all_logits, 0)
        all_logits = all_logits.detach()
        all_probabilities = model.pooling.probabilities(all_logits)

    os.makedirs('test', exist_ok=True)  # ensure the output dir exists
    with open('test/gradcampp_seg_preds.pkl', 'wb') as f:
        pkl.dump(all_seg_preds_interp, f)

    results_dir = 'out/{}/{}'.format(
        ex.current_run.config['model']['arch'] +
        str(ex.current_run.config['balance']),
        ex.current_run.config['model']['pooling'])
    save_results(results_dir, loader.dataset.samples, np.array(all_labels),
                 np.array(all_predictions), all_seg_logits_interp,
                 all_seg_preds_interp, np.array(all_dices))

    metrics = metric_report(np.array(all_labels), all_probabilities.numpy(),
                            np.array(all_predictions))
    metrics['images_path'] = loader.dataset.samples
    metrics['labels'] = np.array(all_labels)
    metrics['logits'] = all_logits.numpy()
    metrics['probabilities'] = all_probabilities.numpy()
    metrics['predictions'] = np.array(all_predictions)
    metrics['losses'] = np.array(all_losses)

    metrics['dice_per_image'] = np.array(all_dices)
    metrics['mean_dice'] = metrics['dice_per_image'].mean()
    metrics['dice'] = evaluator.dice()[1].item()
    metrics['iou_per_image'] = np.array(all_ious)
    metrics['mean_iou'] = metrics['iou_per_image'].mean()
    metrics['iou'] = evaluator.intersection_over_union()[1].item()
    metrics['conf_mat'] = evaluator.cm.numpy()

    if (ex.current_run.config['dataset']['split'] == 0
            and ex.current_run.config['dataset']['fold'] == 0):
        metrics['seg_preds'] = all_seg_preds_interp

    return metrics
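

# Hedged invocation sketch (not from the original source): test() above pulls
# its hyper-parameters from a sacred run, so it is normally called from the
# experiment's main function. build_model and build_loader are hypothetical
# placeholders for the project's own factories.
#
#   @ex.automain
#   def main(_config):
#       device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#       model = build_model(_config['model']).to(device)
#       loader = build_loader(_config['dataset'])
#       metrics = test(model, loader, device)
#       print(metrics['dice'], metrics['iou'])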
Example #3
    def validate(self, epoch=None, is_test=False, use_train_mode=False):
        # 1. viterbi-decode the super network to get the actual_path
        # 2. decode the genotypes of the cells lying on the actual_path
        # 3. construct the best network from the actual_path and cell genotypes
        # 4. use the best network to perform the test phase

        if is_test:
            data_loader = self.run_config.test_loader
            epoch_str = None
            self.logger.log('\n' + '-' * 30 + 'TESTING PHASE' + '-' * 30 +
                            '\n',
                            mode='valid')
        else:
            data_loader = self.run_config.valid_loader
            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                      self.run_config.epochs)
            self.logger.log('\n' + '-' * 30 +
                            'Valid epoch: {:}'.format(epoch_str) + '-' * 30 +
                            '\n',
                            mode='valid')

        model = self.model

        if use_train_mode: model.train()
        else: model.eval()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()
        accs = AverageMeter()
        end0 = time.time()

        with torch.no_grad():
            for i, (datas, targets) in enumerate(data_loader):
                if torch.cuda.is_available():
                    datas = datas.to(self.device, non_blocking=True)
                    targets = targets.to(self.device, non_blocking=True)
                else:
                    raise ValueError('CPU execution is not supported')
                data_time.update(time.time() - end0)
                # validation of the derived model. normal forward pass.
                logits = self.model(datas)
                loss = self.criterion(logits, targets)
                # metrics calculate and update
                evaluator = Evaluator(self.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                acc = evaluator.Pixel_Accuracy()
                losses.update(loss.data.item(), datas.size(0))
                mious.update(miou.item(), datas.size(0))
                fscores.update(fscore.item(), datas.size(0))
                accs.update(acc.item(), datas.size(0))
                # duration
                batch_time.update(time.time() - end0)
                end0 = time.time()

            if is_test:
                Wstr = '|*TEST*|' + time_string()
                Tstr = '|Time  | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base  | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'test')
            else:
                Wstr = '|*VALID*|' + time_string() + epoch_str
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'valid')

        return losses.avg, accs.avg, mious.avg, fscores.avg
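

# Minimal AverageMeter sketch (hedged: the project ships its own
# implementation; this one only mirrors the .val/.avg/.update(value, n)
# interface that validate() relies on).
class _AverageMeterSketch:
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        # keep the latest value plus a count-weighted running average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count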
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        # If an initializer checkpoint is given, build a 21-class (PASCAL VOC)
        # head first; it is replaced with an nclass head after loading.
        if args.init is not None:
            model = DeepLab(num_classes=21,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
        else:
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr
        }]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)

        optimizer = torch.optim.Adam(train_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     amsgrad=True)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                                    args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        #initializing network
        if args.init is not None:
            if not os.path.isfile(args.init):
                raise RuntimeError(
                    "=> no initializer checkpoint found at '{}'".format(
                        args.init))
            checkpoint = torch.load(args.init)
            #args.start_epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']
            # del state_dict["decoder.last_conv.8.weight"]
            # del state_dict["decoder.last_conv.8.bias"]
            if args.cuda:
                self.model.module.load_state_dict(state_dict, strict=False)
            else:
                self.model.load_state_dict(state_dict, strict=False)
            # if not args.ft:
            #     self.optimizer.load_state_dict(checkpoint['optimizer'])
            # self.best_pred = checkpoint['best_pred']
            # replace the pretrained 21-class head with one matching nclass
            last_layer = nn.Conv2d(256, self.nclass, kernel_size=1, stride=1)
            if args.cuda:
                self.model.module.decoder.last_layer = last_layer.cuda()
            else:
                self.model.decoder.last_layer = last_layer
            print("=> loaded initializer '{}' (epoch {})".format(
                args.init, checkpoint['epoch']))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            ##state_dict = checkpoint['state_dict']
            ## del state_dict["decoder.last_conv.8.weight"]
            ## del state_dict["decoder.last_conv.8.bias"]
            if args.cuda:
                #self.model.module.load_state_dict(state_dict, strict=False)
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                #self.model.load_state_dict(state_dict, strict=False)
                self.model.load_state_dict(checkpoint['state_dict'])
            # if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

            #self.model.module.decoder.last_layer = nn.Conv2d(256, self.nclass, kernel_size=1, stride=1)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_loss1 = 0.0
        train_loss2 = 0.0
        train_loss3 = 0.0
        self.model.train()


        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        # import pdb; pdb.set_trace()
        # label_w = label_stats(self.train_loader, nimg=70) #** -args.norm_loss if args.norm_loss != 0 else None
        # import pdb; pdb.set_trace()

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            index, path = sample['index'], sample['path']
            b_mask, enlarged_b_mask = sample['b_mask'], sample['enlarged_b_mask']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            #not using learning rate scheduler and apply a fixed learning rate
            #self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            #import pdb; pdb.set_trace()
            #loss = self.criterion(output, target, b_mask, enlarged_b_mask)
            loss1, loss2, loss3, loss = self.criterion(output, target, b_mask,
                                                       enlarged_b_mask)
            # criterion = nn.BCELoss()
            # loss = criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            #import pdb; pdb.set_trace()
            tbar.set_description('Train loss_total: %.3f' % (train_loss /
                                                             (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss1: %.3f' % (train_loss1 / (i + 1)))
            self.writer.add_scalar('train/total_loss1_iter', loss1.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss2: %.3f' % (train_loss2 / (i + 1)))
            self.writer.add_scalar('train/total_loss2_iter', loss2.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss3: %.3f' % (train_loss3 / (i + 1)))
            self.writer.add_scalar('train/total_loss3_iter', loss3.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # max() guards loaders with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output, b_mask,
                                             enlarged_b_mask, global_step)
            # Save the model after each 500 iterations
            if i % 500 == 0:  #for the whole dataset
                #if i % 5 == 0:    # for debugging
                is_best = False
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)

            # perform the validation after each 300 iterations
            if i % 300 == 0:
                self.validation(i)
                #self.validation(i + num_img_tr * epoch)

            ## garbage collection pass
            #del image, target, index, path, b_mask, enlarged_b_mask, output
            #gc.collect()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.writer.add_scalar('train/total_loss1_epoch', train_loss1, epoch)
        self.writer.add_scalar('train/total_loss2_epoch', train_loss2, epoch)
        self.writer.add_scalar('train/total_loss3_epoch', train_loss3, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            #print('saving checkpoint')
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        test_loss1 = 0.0
        test_loss2 = 0.0
        test_loss3 = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            index, path = sample['index'], sample['path']
            b_mask, enlarged_b_mask = sample['b_mask'], sample['enlarged_b_mask']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            #loss = self.criterion(output, target)
            #loss = self.criterion(output, target, b_mask, enlarged_b_mask)
            loss1, loss2, loss3, loss = self.criterion(output, target, b_mask,
                                                       enlarged_b_mask)
            test_loss += loss.item()
            test_loss1 += loss1.item()
            test_loss2 += loss2.item()
            test_loss3 += loss3.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            target = np.argmax(target, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/total_loss1_epoch', test_loss1, epoch)
        self.writer.add_scalar('val/total_loss2_epoch', test_loss2, epoch)
        self.writer.add_scalar('val/total_loss3_epoch', test_loss3, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
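

# Hedged driver sketch (not from the original source): the Trainer above is
# normally driven by an argparse namespace whose attributes mirror those read
# in __init__ (backbone, out_stride, lr, cuda, resume, init, ft, no_val,
# workers, epochs, start_epoch, ...).
#
#   trainer = Trainer(args)
#   for epoch in range(args.start_epoch, args.epochs):
#       trainer.training(epoch)
#       if not args.no_val:
#           trainer.validation(epoch)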
Example #5
class Test:
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 sample_number,
                 trial=100,
                 cuda=False):
        self.bn = bn
        self.target = list(config.all_dataset)  # copy so config.all_dataset is not mutated
        self.target.remove(config.dataset)
        self.sample_number = sample_number
        # load source domain
        #self.source_set = spacenet.Spacenet(city=config.dataset, split='test', img_root=config.img_root, needs to be changed)
        #self.source_loader = DataLoader(self.source_set, batch_size=16, shuffle=False, num_workers=2)
        self.source_loader = None
        self.save_path = save_path
        self.save_batch = save_batch
        self.trial = trial
        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='val',
                                     img_root=config.img_root,
                                     gt_root=config.gt_root,
                                     mean_std=config.mean_std,
                                     if_augment=config.if_augment,
                                     repeat_count=config.repeat_count)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))

            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root,
                                      gt_root=config.gt_root,
                                      mean_std=config.mean_std,
                                      if_augment=config.if_augment,
                                      repeat_count=config.repeat_count,
                                      sample_number=sample_number)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def get_performance(self, dataloader, trainloader, city, adabn_layer):
        # change mean and var of bn to adapt to the target domain
        dirname = os.path.join(self.save_path, city + '_bn')
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        batch = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        # evaluate on the test dataset
        for i, sample in enumerate(tbar):
            image, target, path = sample['image'], sample['label'], sample[
                'path']
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batch > 0:
                if not os.path.exists(self.save_path):
                    os.mkdir(self.save_path)
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self, name, converged_model):
        source_test = False
        if source_test:
            # placeholder adabn_layer=1; this branch is disabled above
            A, I, Im = self.get_performance(self.source_loader, None,
                                            self.config.dataset, 1)

        i_all = []
        l = 1  # AdaBN layer index; earlier sweeps iterated over range(1, 61)
        tA, tI, tIm = [], [], []

        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city, l)

            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)
        i_all.append(tI[0])
        with open(name, 'wb') as f:
            pickle.dump((i_all, converged_model, self.sample_number), f)

        res = {}
        if source_test:
            print("Test for source domain:")
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A,
                                                       I, Im))
            res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            json_name = 'train_log/test_bn.json'
        else:
            json_name = 'train_log/test.json'

        with open(json_name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        for i, img in enumerate(imgs):
            if not if_original:
                # colored prediction/target comparison
                cv2.imwrite(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Pred.jpg'),
                    img)
            else:
                # raw input image
                img = img.astype(np.uint8)
                img = Image.fromarray(img)
                img.save(
                    os.path.join(save_path,
                                 str(batch_index) + str(i) + '_Original.jpg'))

    def color_images(self, pred, target):
        imgs = []
        for p, t in zip(pred, target):
            tmp = np.squeeze(p * 2 + t)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg: negative, building: positive
            img[np.where(tmp == 0)] = [0, 0, 0]      # black, true negative
            img[np.where(tmp == 1)] = [255, 0, 0]    # red, false negative
            img[np.where(tmp == 2)] = [0, 255, 0]    # green, false positive
            img[np.where(tmp == 3)] = [255, 255, 0]  # yellow, true positive
            imgs.append(img)
        return imgs
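

# Worked example of the color_images encoding above (illustrative only):
# tmp = pred * 2 + target maps every pixel to one of four cases,
#   0 = true negative (black), 1 = false negative (red),
#   2 = false positive (green), 3 = true positive (yellow).
#
#   p = np.array([[0, 1], [1, 0]])   # predictions
#   t = np.array([[0, 0], [1, 1]])   # targets
#   print(p * 2 + t)                 # [[0 2] [3 1]] -> TN, FP, TP, FN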
Example #6
class Infer(object):
    def __init__(self, args):
        self.args = args
        self.nclass = 4
        self.save_fold = 'brain_re/brain_cedice'
        mkdir(self.save_fold)
        self.name = self.save_fold.split('/')[-1].split('_')[-1]
        #===for brain==========================
        # self.nclass = 4
        # self.save_fold = 'brain_re'
        #======================================
        net = segModel(self.args,self.nclass)
        net.build_model()
        model = net.model
        # load params
        resume = args.resume
        print('==> Load model...')
        if resume is not None:
            checkpoint = torch.load(resume)
            model.load_state_dict(checkpoint['state_dict'])
        self.model = model.cuda()
        print('==> Loading loss func...')
        self.criterion = SegmentationLosses(cuda=args.cuda).build_loss(
            mode=args.loss_type)

        #define evaluator
        self.evaluator = Evaluator(self.nclass)

        #get data path
        root_path = Path.db_root_dir(self.args.dataset)
        if self.args.dataset == 'drive':
            folder = 'test'
            self.test_img = os.path.join(root_path, folder, 'images')
            self.test_label = os.path.join(root_path, folder, '1st_manual')
            self.test_mask = os.path.join(root_path, folder, 'mask')
        elif self.args.dataset == 'brain':
            path = root_path + '/Bra-pickle'
            valid_path = '../data/Brain/test.csv'
            self.valid_set = get_dataset(path, valid_path)
        print('loading test data...')

        #define data
        self.test_loader = None

    def eval(self):
        gt_name = os.listdir(self.test_label)
        img_list = [os.path.join(self.test_label, image) for image in gt_name]
        mask_listdir = [
            os.path.join(self.test_mask,
                         image.split('.')[0].split('_')[0] + '_test_mask.gif')
            for image in gt_name
        ]
        pred_list = get_result_list(gt_name, self.save_fold)
        for i in range(len(img_list)):
            target, preds, _mask = img_list[i], pred_list[i], mask_listdir[i]
            self.evaluator.add_batch(target, preds, mask=_mask)
        # the evaluator aggregates over all images, so no further averaging is needed
        test_Acc = self.evaluator.Pixel_Accuracy()
        test_acc_class = self.evaluator.Pixel_Accuracy_Class()
        test_mIou = self.evaluator.Mean_Intersection_over_Union()
        test_fwiou = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        pre, recall, auc = self.evaluator.show_Roc()
        print('Test:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU:{}, precision:{}, Recall:{}, Auc:{}".format(
            test_Acc, test_acc_class, test_mIou, test_fwiou, pre, recall, auc))


    def predict_a_patch(self):
        self.model.eval()
        imgs = os.listdir(self.test_img)
        labels = []
        for i in imgs:
            label_name = (i.split('.')[0]).split('_')[0] + '_manual1.gif'
            labels.append(label_name)
        img_list = [os.path.join(self.test_img, image) for image in imgs]
        label_list = [os.path.join(self.test_label, lab) for lab in labels]

        #some params
        patch_h = self.args.ph
        patch_w = self.args.pw
        stride_h = self.args.sh
        stride_w = self.args.sw
        #crop imgs to patches
        images_patch, labels_patch, Height, Width, self.gray_original = extract_patches_test(
            img_list, label_list, patch_h, patch_w, stride_h, stride_w)  # list[patches]
        data = []
        for i, j in zip(images_patch, labels_patch):
            data.append((i, j))

        # start testing; each batch holds one image
        tbar = tqdm(data)
        for idx, sample in enumerate(tbar):
            image, target = sample[0], sample[1]
            image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                result = self._predict_a_patch(image)
            preds = result
            full_preds = merge_overlap(preds, Height, Width, stride_h, stride_w)  # Tensor -> [1,1,H,W]
            full_preds = full_preds[0, 1, :, :]
            full_img = tfs.ToPILImage()((full_preds * 255).type(torch.uint8))
            full_image = (full_preds >= 0.5) * 1  # binarize at 0.5
            mergeImage = merge(self.gray_original[idx], full_image)
            # save result images
            name_probs = imgs[idx].split('.')[0].split('_')[0] + '_test_prob.bmp'
            name_merge = imgs[idx].split('.')[0].split('_')[0] + '_merge.bmp'
            save(mergeImage, os.path.join(self.save_fold, name_merge))
            save(full_img, os.path.join(self.save_fold, name_probs))

    def _predict_a_patch(self, patchs):
        number_of_patch = patchs.shape[0]
        results = torch.zeros(number_of_patch, self.nclass,
                              self.args.ph, self.args.pw)
        results = results.cuda()
        patchs = patchs.float()

        steps = number_of_patch // self.args.batch_size
        end_index = 0  # stays 0 when there are fewer patches than one batch
        for i in range(steps):
            start_index = i * self.args.batch_size
            end_index = start_index + self.args.batch_size
            output = self.model(patchs[start_index:end_index])
            output = torch.sigmoid(output)
            results[start_index:end_index] = output
        if end_index < number_of_patch:  # remainder that did not fill a batch
            results[end_index:] = torch.sigmoid(self.model(patchs[end_index:]))
        return results
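
    # Minimal batched-inference sketch of the loop above (illustrative only;
    # model and x are generic stand-ins). Stepping with range(0, N, batch)
    # covers the trailing remainder without a separate slice:
    #
    #   def batched_forward(model, x, batch_size):
    #       outs = []
    #       with torch.no_grad():
    #           for start in range(0, x.shape[0], batch_size):
    #               outs.append(torch.sigmoid(model(x[start:start + batch_size])))
    #       return torch.cat(outs, 0)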

    def test(self):
        self.model.eval()
        print(self.model)
        self.evaluator.reset()
        self.test_loader = DataLoader(self.valid_set,
                                      batch_size=self.args.test_batch_size,
                                      shuffle=False)
        tbar = tqdm(self.test_loader, desc='\r')  # TODO: rewrite
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            #show(image[0].permute(1,2,0).numpy(),target[0].numpy())
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                _, pred = output.max(1)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            target = target.cpu().numpy()
            pred = pred.cpu().numpy()

            # show the first 100 samples
            if 0 <= i <= 100:
                iii = image[0].cpu().numpy()
                showimg = np.transpose(iii, (1, 2, 0))
                plt.figure()
                plt.imshow(showimg, cmap='gray')
                plt.show()
                fname = self.save_fold + '/' + self.name + '_' + str(i) + '.png'
                show(image[0].permute(1, 2, 0).cpu().numpy(), target[0], pred[0], fname)

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        Dice_coff = self.evaluator.DiceCoff()
        P, R, perclass = self.evaluator.compute_el()
        print('Test:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU:{}, dice:{}".format(
            Acc, Acc_class, mIoU, FWIoU, Dice_coff))
        print('precision:{}, recall:{}'.format(P, R))
        print('per-class precision:{}, recall:{}'.format(perclass[0], perclass[1]))
        print('Loss: %.3f' % test_loss)
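

# Hedged usage sketch (not from the original source): args is expected to be
# an argparse namespace carrying the options read above (dataset, resume,
# loss_type, cuda, ph, pw, sh, sw, batch_size, test_batch_size, ...).
#
#   infer = Infer(args)
#   if args.dataset == 'brain':
#       infer.test()
#   elif args.dataset == 'drive':
#       infer.predict_a_patch()
#       infer.eval()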
Example #7
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': config.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.batchloss = BatchLoss()

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                # map_location belongs to torch.load, not load_state_dict
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        seg_loss_sum, bn_loss_sum, entropy_loss_sum = 0.0, 0.0, 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.vis.line(X=torch.tensor([itr]),
                          Y=torch.tensor(
                              [self.optimizer.param_groups[0]['lr']]),
                          win='lr',
                          opts=dict(title='lr', xlabel='iter', ylabel='lr'),
                          update='append' if itr > 0 else None)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image, B_target = target_sample['image'], target_sample['label']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target = B_image.cuda(), B_target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target)
            # Supervised loss
            self.optimizer.zero_grad()
            #print(A_image.size())
            A_output = self.model(A_image)
            A_bn_mean, A_bn_var = self.model.module.get_bn_parameter()
            seg_loss = self.criterion(A_output, A_target)
            seg_loss.backward()
            self.optimizer.step()

            # Unsupervised bn loss
            self.optimizer.zero_grad()
            B_output = self.model(B_image)
            B_bn_mean, B_bn_var = self.model.module.get_bn_parameter()
            mean_loss = self.batchloss.loss(A_bn_mean, B_bn_mean)
            var_loss = self.batchloss.loss(A_bn_var, B_bn_var)
            bn_loss = mean_loss + var_loss
            # the BN statistics are buffers, so this sum may not require grad
            # by default; the flag is forced on so backward() can run on it
            bn_loss.requires_grad = True
            bn_loss.backward()
            self.optimizer.step()

            # Unsupervised entropy minimization loss
            self.optimizer.zero_grad()
            entropy_mini_loss = self.entropy_mini_loss.loss(B_output)
            entropy_mini_loss.backward()
            self.optimizer.step()

            seg_loss_sum += seg_loss.item()
            bn_loss_sum += bn_loss.item()
            entropy_loss_sum += entropy_mini_loss.item()

            tbar.set_description('Seg loss: %.3f' % (seg_loss_sum / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % seg_loss_sum)

        self.vis.line(X=torch.tensor([epoch]),
                      Y=torch.tensor([seg_loss_sum]),
                      win='train_loss',
                      name='Seg_loss',
                      opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
                      update='append' if epoch > 0 else None)
        self.vis.line(X=torch.tensor([epoch]),
                      Y=torch.tensor([bn_loss_sum]),
                      win='train_loss',
                      name='BN_loss',
                      opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
                      update='append' if epoch > 0 else None)
        self.vis.line(X=torch.tensor([epoch]),
                      Y=torch.tensor([entropy_loss_sum]),
                      win='train_loss',
                      name='Entropy_loss',
                      opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
                      update='append' if epoch > 0 else None)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output = self.model(image)

                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)

            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            # Draw Visdom
            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']

            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([test_loss]),
                          win='val_loss',
                          name=names[0],
                          update='append')
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([Acc]),
                          win='metrics',
                          name=names[1],
                          opts=dict(title='metrics',
                                    xlabel='epoch',
                                    ylabel='performance'),
                          update='append' if epoch > 0 else None)
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([IoU]),
                          win='metrics',
                          name=names[2],
                          update='append')
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([mIoU]),
                          win='metrics',
                          name=names[3],
                          update='append')

            return Acc, IoU, mIoU

        self.model.eval()
        self.evaluator.reset()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        new_pred_source = s_iou
        new_pred_target = t_iou
        if new_pred_source > self.best_pred_source or new_pred_target > self.best_pred_target:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
            print('Saving state, epoch:', epoch)
            torch.save(
                self.model.module.state_dict(), self.args.save_folder +
                'models/' + 'epoch' + str(epoch) + '.pth')
            loss_file = {
                's_Acc': s_acc,
                's_IoU': s_iou,
                's_mIoU': s_miou,
                't_Acc': t_acc,
                't_IoU': t_iou,
                't_mIoU': t_miou
            }
            with open(
                    os.path.join(self.args.save_folder, 'eval',
                                 'epoch' + str(epoch) + '.json'), 'w') as f:
                json.dump(loss_file, f)
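The trainer above calls into BatchLoss and MinimizeEntropyLoss helpers that are defined elsewhere in the repository. As a hedged sketch only (assumed semantics, not the original implementations): BatchLoss.loss could measure the distance between source- and target-batch BN statistics, and MinimizeEntropyLoss.loss the mean pixel-wise prediction entropy on the target batch:

import torch
import torch.nn.functional as F

class BatchLoss:
    """Assumed sketch: MSE between per-layer BN statistics of two domains."""
    def loss(self, stats_a, stats_b):
        # stats_a / stats_b: sequences of per-layer mean (or variance) tensors
        return sum(F.mse_loss(a, b) for a, b in zip(stats_a, stats_b))

class MinimizeEntropyLoss:
    """Assumed sketch: mean pixel-wise Shannon entropy of the softmax output."""
    def loss(self, logits):
        p = F.softmax(logits, dim=1)                  # (N, C, H, W)
        entropy = -(p * torch.log(p + 1e-8)).sum(1)   # (N, H, W)
        return entropy.mean()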
Example #8
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum = 0.0, 0.0, 0.0, 0.0
        adv_loss_sum, d_loss_sum, ins_loss_sum = 0.0, 0.0, 0.0
        self.model.train()
        if self.config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image = target_sample['image']
            B_target = target_sample['label']
            B_image_pair = target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image = B_image.cuda()
                B_target = B_target.cuda()
                B_image_pair = B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            B_output_pair, B_feat_pair, B_low_feat_pair = self.model(
                B_image_pair)
            B_output_pair = flip(B_output_pair, dim=-1)
            B_feat_pair = flip(B_feat_pair, dim=-1)
            B_low_feat_pair = flip(B_low_feat_pair, dim=-1)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # Train seg network
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Unsupervised loss
            ins_loss = 0.01 * self.instance_loss(B_output, B_output_pair)
            main_loss += ins_loss

            # Train adversarial loss
            D_out = self.D(prob_2_entropy(F.softmax(B_output, dim=1)))
            adv_loss = bce_loss(D_out, self.source_label)

            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # Train discriminator
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source
            D_source = self.D(prob_2_entropy(F.softmax(A_output_detach, dim=1)))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target
            D_target = self.D(prob_2_entropy(F.softmax(B_output_detach, dim=1)))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            ins_loss_sum += ins_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()

            #train_loss += seg_loss.item() + self.config.lambda_adv * adv_loss.item()
            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/InsLoss', ins_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss',
                                           d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            # Show the results of the last iteration
            #if i == len(self.train_loader)-1:
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            #feat_mean, low_feat_mean, feat_var, low_feat_var = 0, 0, 0, 0
            #adv_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output, low_feat, feat = self.model(image)

                #low_feat = low_feat.cpu().numpy()
                #feat = feat.cpu().numpy()

                #if isinstance(feat, np.ndarray):
                #    feat_mean += feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean += low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var += feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var += low_feat.var(axis=0).var(axis=1).var(axis=1)
                #else:
                #    feat_mean = feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    low_feat_mean = low_feat.mean(axis=0).mean(axis=1).mean(axis=1)
                #    feat_var = feat.var(axis=0).var(axis=1).var(axis=1)
                #    low_feat_var = low_feat.var(axis=0).var(axis=1).var(axis=1)

                #d_output = self.D(prob_2_entropy(F.softmax(output)))
                #adv_loss += bce_loss(d_output, self.source_label).item()
                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()

                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target_, pred)
            if if_source:
                print("Add Validation-Source images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset,
                                             image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target,
                                             image, target, output, epoch, 5)
            #feat_mean /= (i+1)
            #low_feat_mean /= (i+1)
            #feat_var /= (i+1)
            #low_feat_var /= (i+1)
            #adv_loss /= (i+1)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        new_pred_source = s_iou
        new_pred_target = t_iou

        if new_pred_source > self.best_pred_source or new_pred_target > self.best_pred_target:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
        print('Saving state, epoch:', epoch)
        torch.save(
            self.model.module.state_dict(),
            self.args.save_folder + 'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {
            's_Acc': s_acc,
            's_IoU': s_iou,
            's_mIoU': s_miou,
            't_Acc': t_acc,
            't_IoU': t_iou,
            't_mIoU': t_miou
        }
        with open(
                os.path.join(self.args.save_folder, 'eval',
                             'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
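The adversarial branch above relies on two helpers, prob_2_entropy and bce_loss, that are not shown in this listing. A minimal sketch under the usual ADVENT-style formulation (an assumption, not the repository's exact code):

import math
import torch
import torch.nn as nn

def prob_2_entropy(prob):
    # Turn a softmax probability map (N, C, H, W) into a weighted
    # self-information map: -p * log2(p), normalized by log2(C).
    n, c, h, w = prob.size()
    return -prob * torch.log2(prob + 1e-30) / math.log2(c)

def bce_loss(d_out, label):
    # Binary cross-entropy of the discriminator logits against a constant
    # domain label (source_label = 0, target_label = 1 in the trainer above).
    target = torch.full_like(d_out, float(label))
    return nn.BCEWithLogitsLoss()(d_out, target)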
Example #9
class operater(object):
    def __init__(self, args, model, trn_loader, val_loader, chk_loader, optimizer):

        self.args = args
        self.model = model
        self.train_loader = trn_loader
        self.val_loader = val_loader
        self.chk_loader = chk_loader
        self.optimizer = optimizer
        # Define Evaluator
        self.evaluator = Evaluator(args.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                          args.epochs, len(trn_loader))
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[3, 6, 9], gamma=0.5)
        self.wait_epoches = 10
        self.best_pred = 0
        self.init_weight = 0.98
        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    def training(self,epoch,args):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        #w1 = 0.2 + 0.5 * (self.init_weight - 0.2) * (1 + np.cos(epoch * np.pi / args.epochs))
        print('Learning rate:', self.optimizer.param_groups[0]['lr'])
        for i, (x1, x2, y, index) in enumerate(tbar):
            x1 = Variable(x1)
            x2 = Variable(x2)
            #y_cls = Seg2cls(args, y)  # image-level labels, shape (N, 1, 1, C)
            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            #self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(x1,x2)

            output = F.softmax(output, dim=1)

            # loss_ce=CELossLayer(self.args,output,y)
            # #print('ce loss', loss_ce)
            #
            # loss_focal = FocalLossLayer(self.args,output, y)
            #print('focal loss', loss_focal)

            loss_lovasz = LovaszLossLayer(output, y)
            #print('lovasz loss', loss_lovasz)

            # self.writer.add_scalar('train/ce_loss_iter', loss_focal.item(), i + num_img_tr * epoch)
            # self.writer.add_scalar('train/focal_loss_iter', loss_focal.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/lovasz_loss_iter', loss_lovasz.item(), i + num_img_tr * epoch)

            #loss = w1 * loss_ce + (0.5 - 0.5 * w1) * loss_focal + (0.5 - 0.5 * w1) * loss_lovasz

            loss = loss_lovasz

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            #Show 10 * 3 inference results each epoch
            if i % 10 == 0:
                global_step = i + num_img_tr * epoch
                if self.args.oly_s1 and not self.args.oly_s2:
                    self.summary.visualize_image(self.writer, self.args.dataset,
                                                 x1[:, [0], :, :], y, output, global_step)
                elif not self.args.oly_s1:
                    if self.args.rgb:
                        self.summary.visualize_image(self.writer, self.args.dataset,
                                                     x2, y, output, global_step)
                    else:
                        self.summary.visualize_image(self.writer, self.args.dataset,
                                                     x2[:, [2, 1, 0], :, :], y, output, global_step)
                else:
                    raise NotImplementedError
        self.scheduler.step()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + y.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }, is_best)

    def validation(self,epoch, args):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        for i, (x1, x2, y, index) in enumerate(tbar):

            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            with torch.no_grad():
                output = self.model(x1, x2)

            pred = output.data.cpu().numpy()
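            # Zero out the scores of classes 2 and 7 before the argmax,
            # presumably to suppress these classes at validation time.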
            pred[:, [2, 7], :, :] = 0
            target = y[:,0,:,:].cpu().numpy()  # batch_size * 256 * 256
            pred = np.argmax(pred, axis=1)  # batch_size * 256 * 256
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        OA = self.evaluator.Pixel_Accuracy()
        AA = self.evaluator.val_Pixel_Accuracy_Class()
        self.writer.add_scalar('val/OA', OA, epoch)
        self.writer.add_scalar('val/AA', AA, epoch)

        print('AVERAGE ACCURACY:', AA)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + y.data.shape[0]))

        new_pred = AA
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
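Unlike the earlier trainers, operater steps torch's built-in MultiStepLR once per epoch instead of the custom LR_Scheduler. A small self-contained illustration of the schedule it produces (the rate halves after epochs 3, 6, and 9):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[3, 6, 9], gamma=0.5)

for epoch in range(12):
    print(epoch, opt.param_groups[0]['lr'])  # 0.1 for epochs 0-2, 0.05 for 3-5, 0.025 for 6-8, then 0.0125
    opt.step()    # a real epoch's parameter updates would run here
    sched.step()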
Example #10
class ExperimentBuilder(nn.Module):
    def __init__(self, network_model, num_class, experiment_name, num_epochs,
                 train_data, val_data, test_data, learn_rate, momentum,
                 weight_decay, use_gpu, continue_from_epoch=-1):
        super(ExperimentBuilder, self).__init__()

        self.experiment_name = experiment_name
        self.model = network_model
        #self.model.reset_parameters()
        self.num_class = num_class
        self.learn_rate = learn_rate
        self.momentum = momentum
        self.weight_decay = weight_decay

        if(self.experiment_name == "test"):
            print('Testmode')
        
        if torch.cuda.device_count() > 1 and use_gpu:
            self.device = torch.cuda.current_device()
            self.model.to(self.device)
            self.model = nn.DataParallel(module=self.model)
            print('Use Multi GPU', self.device)
            print('GPU number', torch.cuda.device_count())
        elif torch.cuda.device_count() == 1 and use_gpu:
            self.device = torch.cuda.current_device()
            self.model.to(self.device)  
            print('Use GPU', self.device)
        else:
            self.device = torch.device('cpu')  
            print('Use CPU', self.device)

        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data 
        self.num_epochs = num_epochs

        train_params = [{'params': network_model.get_backbone_params(), 'lr': self.learn_rate},
                        {'params': network_model.get_classifier_params(), 'lr': self.learn_rate * 1}]
        self.optimizer = torch.optim.SGD(train_params, momentum=self.momentum,
                                         weight_decay=self.weight_decay)
        #self.criterion = FocalLoss(ignore_index=255, size_average=True).to(self.device)
        self.criterion = CrossEntropyLoss(reduction='mean', ignore_index=255).to(self.device)
        self.scheduler = PolyLR(self.optimizer, max_iters=self.num_epochs*len(self.train_data), power=0.9)
        self.evaluator = Evaluator(self.num_class)
     
        total_num_params = 0
        for param in network_model.parameters():
            total_num_params += np.prod(param.shape)
        print('System learnable parameters', total_num_params)
        num_conv_layers = 0
        for name, value in self.named_parameters():
            if all(item in name for item in ['conv', 'weight']):
                num_conv_layers += 1
        print('number of conv layers', num_conv_layers)
        

        self.experiment_folder = os.path.abspath(experiment_name)
        self.experiment_logs = os.path.abspath(os.path.join(self.experiment_folder, "result_outputs"))
        self.experiment_saved_models = os.path.abspath(os.path.join(self.experiment_folder, "saved_models"))
        print(self.experiment_folder, self.experiment_logs)
        self.best_val_model_idx = 0
        self.best_val_model_acc = 0.

        if not os.path.exists(self.experiment_folder): 
            os.mkdir(self.experiment_folder) 

        if not os.path.exists(self.experiment_logs):
            os.mkdir(self.experiment_logs)  

        if not os.path.exists(self.experiment_saved_models):
            os.mkdir(self.experiment_saved_models) 
        
        if continue_from_epoch == -2:
            try:
                self.best_val_model_idx, self.best_val_model_acc, self.state = self.load_model(
                    model_save_dir=self.experiment_saved_models, model_save_name="train_model",
                    model_idx='latest')  
                self.starting_epoch = self.state['current_epoch_idx'] + 1
                #self.scheduler.step()
                self.scheduler.last_epoch = self.state['last_epoch']
                print("restart from epoch ",self.state['current_epoch_idx'])
                print("backbone learning rate: ", self.optimizer.param_groups[0]['lr'])
                print("classifier learning rate: ", self.optimizer.param_groups[1]['lr'])
                print("iterations: ", self.scheduler.last_epoch)
                print("base_lr:", self.scheduler.base_lrs)
                
            except Exception:
                print("Model objects cannot be found, initializing a new model and starting from scratch")
                self.starting_epoch = 0
                self.state = dict()

        elif continue_from_epoch != -1:  
            self.best_val_model_idx, self.best_val_model_acc, self.state = self.load_model(
                model_save_dir=self.experiment_saved_models, model_save_name="train_model",
                model_idx=continue_from_epoch)  
            self.starting_epoch = self.state['current_epoch_idx'] + 1
            self.scheduler.step()
            self.scheduler.last_epoch = self.state['last_epoch']
            print("restart from epoch ",self.state['current_epoch_idx'])
            print("backbone learning rate: ", self.optimizer.param_groups[0]['lr'])
            print("classifier learning rate: ", self.optimizer.param_groups[1]['lr'])
            print("iterations: ", self.scheduler.last_epoch)
            print("base_lr:", self.scheduler.base_lrs)
        else:
            self.starting_epoch = 0
            self.state = dict()

    def run_train_iter(self, image, target):
        
        self.train()
        self.evaluator.reset()
        image = image.to(self.device)
        target = target.to(self.device)
        
        self.optimizer.zero_grad()
        output = self.model.forward(image)
        loss = self.criterion(output, target.long())
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        predicted = output.data.cpu().numpy()
        target = target.cpu().numpy()
        predicted = np.argmax(predicted, axis=1)

        self.evaluator.add_batch(target, predicted)
        miou = self.evaluator.Mean_Intersection_over_Union()
        acc = self.evaluator.Pixel_Accuracy()
        return loss.data.detach().cpu().numpy(), miou, acc
    
    def run_evaluation_iter(self, image, target):
        
        self.eval()
        self.evaluator.reset()
        image = image.to(self.device)
        target = target.to(self.device)
        
        output = self.model.forward(image)
        loss = self.criterion(output, target.long())

        predicted = output.data.cpu().numpy()
        target = target.cpu().numpy()
        predicted = np.argmax(predicted, axis=1)

        self.evaluator.add_batch(target, predicted)
        miou = self.evaluator.Mean_Intersection_over_Union()
        acc = self.evaluator.Pixel_Accuracy()
        return loss.data.detach().cpu().numpy(), miou, acc
   
    def run_predicted_iter(self, image, target):

        self.eval()
        self.evaluator.reset()
        image = image.to(self.device)
        target = target.to(self.device)

        output = self.model.forward(image)
        loss = self.criterion(output, target.long())

        predicted = output.data.cpu().numpy()
        target = target.cpu().numpy()
        predicted = np.argmax(predicted, axis=1)

        self.evaluator.add_batch(target, predicted)
        miou = self.evaluator.Mean_Intersection_over_Union()
        acc = self.evaluator.Pixel_Accuracy()
        return predicted

    def save_model(self, model_save_dir, model_save_name, model_idx, state):
        
        state['network'] = self.model.state_dict()
        state['optimizer'] = self.optimizer.state_dict()
        #state['scheduler'] = self.scheduler.state_dict()
        state['last_epoch'] = self.scheduler.last_epoch
        torch.save(state, f=os.path.join(model_save_dir, "{}_{}".format(model_save_name, str(model_idx))))

    def run_training_epoch(self, current_epoch_losses):
        with tqdm.tqdm(total=len(self.train_data), file=sys.stdout) as pbar_train:  
            for idx, (image, target) in enumerate(self.train_data): 
                loss, miou, acc = self.run_train_iter(image, target)  
                current_epoch_losses["train_loss"].append(loss)  
                current_epoch_losses["train_miou"].append(miou)
                current_epoch_losses["train_acc"].append(acc)  
                pbar_train.update(1)
                if torch.cuda.device_count() >= 1:
                    m = torch.cuda.get_device_properties(0).total_memory / 1e9
                    c = torch.cuda.max_memory_reserved(0) / 1e9
                    a = torch.cuda.max_memory_allocated(0) / 1e9
                    pbar_train.set_description("Training: loss: {:.4f}, miou: {:.4f}, Pacc: {:.4f}, memory: {:.2f}GB, cached: {:.2f}GB, allocated: {:.2f}GB".format(loss, miou, acc, m, c, a))
                else:
                    pbar_train.set_description("Training: loss: {:.4f}, miou: {:.4f}, Pacc: {:.4f}".format(loss, miou, acc))
                    

        return current_epoch_losses
    
    def run_validation_epoch(self, current_epoch_losses):
        with tqdm.tqdm(total=len(self.val_data), file=sys.stdout) as pbar_val:  
            for idx, (image, target) in enumerate(self.val_data): 
                loss, miou, acc = self.run_evaluation_iter(image, target)  
                current_epoch_losses["val_loss"].append(loss) 
                current_epoch_losses["val_miou"].append(miou)
                current_epoch_losses["val_acc"].append(acc)  
                pbar_val.update(1)
                pbar_val.set_description("Validating: loss: {:.4f}, miou: {:.4f}, Pacc: {:.4f}".format(loss, miou, acc))

        return current_epoch_losses
    
    def run_testing_epoch(self, current_epoch_losses):
        with tqdm.tqdm(total=len(self.test_data), file=sys.stdout) as pbar_test:  
            for idx, (image, target) in enumerate(self.test_data): 
                loss, miou, acc = self.run_evaluation_iter(image, target)  
                current_epoch_losses["test_loss"].append(loss) 
                current_epoch_losses["test_miou"].append(miou)  
                current_epoch_losses["test_acc"].append(acc)
                pbar_test.update(1)
                pbar_test.set_description("Testing: loss: {:.4f}, miou: {:.4f}, Pacc: {:.4f}".format(loss, miou, acc))

        return current_epoch_losses

    def load_model(self, model_save_dir, model_save_name, model_idx):
        
        state = torch.load(f=os.path.join(model_save_dir, "{}_{}".format(model_save_name, str(model_idx))))
        self.model.load_state_dict(state_dict=state['network'])
        self.optimizer.load_state_dict(state_dict=state['optimizer'])
        #self.scheduler.load_state_dict(state_dict=state['scheduler'])
        self.scheduler.last_epoch = state['last_epoch']
        return state['best_val_model_idx'], state['best_val_model_acc'], state
    
    def run_experiment(self):
        total_losses = {"train_miou": [], "train_acc": [], "train_loss": [], "val_miou": [], "val_acc": [],
                        "val_loss": [], "curr_epoch": []}  
        for i, epoch_idx in enumerate(range(self.starting_epoch, self.num_epochs)):
            epoch_start_time = time.time()
            current_epoch_losses = {"train_miou": [], "train_acc": [], "train_loss": [],"val_miou": [], "val_acc": [], "val_loss": []}

            current_epoch_losses = self.run_training_epoch(current_epoch_losses)
            #print(self.optimizer.param_groups[0]['lr'])
            current_epoch_losses = self.run_validation_epoch(current_epoch_losses)

            val_mean_miou = np.mean(current_epoch_losses['val_miou'])
            if val_mean_miou > self.best_val_model_acc:  
                self.best_val_model_acc = val_mean_miou  
                self.best_val_model_idx = epoch_idx  

            for key, value in current_epoch_losses.items():
                total_losses[key].append(np.mean(value))

            total_losses['curr_epoch'].append(epoch_idx)
            save_statistics(experiment_log_dir=self.experiment_logs, filename='summary.csv',
                            stats_dict=total_losses, current_epoch=i,
                            continue_from_mode=True if (self.starting_epoch != 0 or i > 0) else False) 

            out_string = "_".join(
                ["{}_{:.4f}".format(key, np.mean(value)) for key, value in current_epoch_losses.items()])
            epoch_elapsed_time = time.time() - epoch_start_time  
            epoch_elapsed_time = "{:.4f}".format(epoch_elapsed_time)
            print("Epoch {}:".format(epoch_idx),"Iteration {}:".format(self.scheduler.last_epoch), out_string, "epoch time", epoch_elapsed_time, "seconds")
            self.state['current_epoch_idx'] = epoch_idx
            self.state['best_val_model_acc'] = self.best_val_model_acc
            self.state['best_val_model_idx'] = self.best_val_model_idx
            if(self.experiment_name != "test"):
                #if(epoch_idx==0 or (epoch_idx+1)%10==0):
                #    self.save_model(model_save_dir=self.experiment_saved_models,
                #            model_save_name="train_model", model_idx=epoch_idx, state=self.state)
                self.save_model(model_save_dir=self.experiment_saved_models,
                            model_save_name="train_model", model_idx='latest', state=self.state)
            
        if(self.experiment_name != "test"):
            print("Generating test set evaluation metrics")
            self.load_model(model_save_dir=self.experiment_saved_models, model_idx='latest',
                            model_save_name="train_model")
            current_epoch_losses = {"test_miou": [], "test_acc": [], "test_loss": []}  

            current_epoch_losses = self.run_testing_epoch(current_epoch_losses=current_epoch_losses)

            test_losses = {key: [np.mean(value)] for key, value in
                       current_epoch_losses.items()}  

            save_statistics(experiment_log_dir=self.experiment_logs, filename='test_summary.csv',
                        stats_dict=test_losses, current_epoch=0, continue_from_mode=False)
        else:
            test_losses = 0
        return total_losses, test_losses 
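PolyLR here is a project-local scheduler stepped once per iteration. Assuming it follows the standard DeepLab polynomial decay (an assumption, not the project's exact code), an equivalent sketch is:

from torch.optim.lr_scheduler import _LRScheduler

class PolyLR(_LRScheduler):
    # Assumed sketch: lr = base_lr * (1 - last_epoch / max_iters) ** power,
    # where last_epoch counts scheduler.step() calls (one per iteration).
    def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1):
        self.max_iters = max_iters
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = (1 - self.last_epoch / self.max_iters) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]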
Example #11
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(config)
        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)


        train_params = [{'params': self.model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': self.model.get_10x_lr_params(), 'lr': config.lr * config.lr_ratio}]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params, momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.consistency = ConsistencyLoss(cuda=args.cuda)
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss, seg_loss_sum, consistency_loss_sum = 0.0, 0.0, 0.0
        self.model.train()
        if self.config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.summary.writer.add_scalar('Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']
            # Get one batch from target domain
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image, B_target, B_image_pair = target_sample['image'], target_sample['label'], target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = B_image.cuda(), B_target.cuda(), B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source, 0., self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            B_output_pair, B_feat_pair, B_low_feat_pair = self.model(B_image_pair)

            self.optimizer.zero_grad()

            # Train seg network
            # Supervised loss
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Consistency loss
            consistency_loss = 0.01 * self.consistency(B_output, B_output_pair)
            main_loss += consistency_loss

            main_loss.backward()

            self.optimizer.step()

            seg_loss_sum += seg_loss.item()
            consistency_loss_sum += consistency_loss.item()
            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(), itr)
            self.summary.writer.add_scalar('Train/ConsistencyLoss', consistency_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            # Show the results of the last iteration
            #if i == len(self.train_loader)-1:
        print("Add Train images at epoch"+str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset, A_image, A_target, A_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        def get_metrics(tbar, if_source=False):
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output, low_feat, feat = self.model(image)


                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()

                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target_, pred)
            if if_source:
                print("Add Validation-Source images at epoch"+str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset, image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch"+str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target, image, target, output, epoch, 5)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)

        new_pred_source = s_iou

        if new_pred_source > self.best_pred_source:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
        print('Saving state, epoch:', epoch)
        torch.save(self.model.module.state_dict(), self.args.save_folder + 'models/'
                    + 'epoch' + str(epoch) + '.pth')
        loss_file = {'s_Acc': s_acc, 's_IoU': s_iou, 's_mIoU': s_miou}
        with open(os.path.join(self.args.save_folder, 'eval', 'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
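The ConsistencyLoss used above is constructed in __init__ but never shown. A plausible minimal sketch (an assumption: a mean-squared error between the softmax maps of a target image and its paired view):

import torch.nn.functional as F

class ConsistencyLoss:
    # Assumed sketch: penalize disagreement between the predictions for a
    # target image and its augmented pair.
    def __init__(self, cuda=False):
        self.cuda = cuda

    def __call__(self, output, output_pair):
        return F.mse_loss(F.softmax(output, dim=1),
                          F.softmax(output_pair, dim=1))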
Example #12
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = MyDeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        #optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)

        # adam
        optimizer = torch.optim.Adam(self.model.parameters(), betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=0, amsgrad=False)

        weight = [1, 10, 10, 10, 10, 10, 10, 10]
        weight = torch.tensor(weight, dtype=torch.float) 
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda, num_classes=self.nclass).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
        
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        '''
        # Get the names of each layer in the current model
        layer_name = list(self.model.state_dict().keys())
        #print(self.model.state_dict()[layer_name[3]])
        # Load a generic pretrained model
        pretrained = './pretrained_model/deeplab-mobilenet.pth.tar'
        pre_ckpt = torch.load(pretrained)
        key_name = list(checkpoint['state_dict'].keys()) # get the layer names of the pretrained model
        pre_ckpt['state_dict'][key_name[-2]] = checkpoint['state_dict'][key_name[-2]] # class counts differ, so the last two layers are assigned separately
        pre_ckpt['state_dict'][key_name[-1]] = checkpoint['state_dict'][key_name[-1]]
        self.model.module.load_state_dict(pre_ckpt['state_dict'])     # , strict=False)
        #print(self.model.state_dict()[layer_name[3]])
        print("Pretrained model loaded OK")
        '''
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            #import pdb
            #pdb.set_trace()
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            #if (i+1) % 50 == 0:
            #    print('Train loss: %.3f' % (loss.item() / (i + 1)))
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        filename = 'checkpoint_{}_{:.4f}.pth.tar'.format(epoch, train_loss)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, filename=filename)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            #if (i+1) %20 == 0:
            #    print('Test loss: %.3f' % (loss / (i + 1)))
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Example #13
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader1,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass,
                            12,
                            self.criterion,
                            crop_size=self.args.crop_size,
                            lambda_latency=self.args.lambda_latency)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        self.evaluator_device = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader1))
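        # DARTS-style bilevel setup: the Architect updates the architecture
        # parameters (alphas) on held-out search batches, separately from the
        # weight optimizer above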
        self.architect = Architect(self.model, args)

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the same call works on CPU and GPU, since the checkpoint was
            # saved from the (possibly DataParallel-wrapped) model
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def fetch_arch(self):
        d = dict()
        d['alphas_cell'] = self.model.arch_parameters()[0]
        d['alphas_network'] = self.model.arch_parameters()[1]
        d['alphas_distributed'] = self.model.arch_parameters()[2]
        return d

    def training(self, epoch):
        train_la, train_loss = 0.0, 0.0
        self.model.train()
        tbar = tqdm(self.train_loader1)
        # a second bar is kept only to display the per-component loss breakdown
        tbar2 = tqdm(self.train_loader1)
        num_img_tr = len(self.train_loader1)
        for i, sample in enumerate(tbar):

            image, target = sample['image'], sample['label']
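            # draw an independent batch for the architecture-search step;
            # next(iter(...)) rebuilds the iterator every time, so search batches may repeat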
            search = next(iter(self.train_loader2))
            image_search, target_search = search['image'], search['label']
            # print ('------------------------begin-----------------------')
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                #image_search, target_search = image_search.cuda (), target_search.cuda ()
                # print ('cuda finish')
            #if epoch>=20:
            #self.architect.step (image_search, target_search)
        #   if i%20==0:
        #       print(self.model.arch_parameters()[2])
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output, device_output, loss, la, c_loss, d_loss = self.model._loss(
                image, target)
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(
                    self.writer, self.args.dataset, image, target,
                    output * 0.5 + device_output * 0.5, global_step)
            loss.backward()
            self.optimizer.step()
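            # warm up the network weights for 20 epochs before updating the alphas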
            if epoch >= 20:
                image_search, target_search = image_search.cuda(), target_search.cuda()
                self.architect.step(image_search, target_search)
            train_la += la
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f   Train latency: %.3f' %
                                 (train_loss / (i + 1), train_la / (i + 1)))
            tbar2.set_description(
                'cloud loss: %.3f   device loss: %.3f   latency loss: %.3f' %
                (c_loss, d_loss, la))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/cloud_loss_iter', c_loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/device_loss_iter', d_loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/latency_loss_iter', la.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.writer.add_scalar('train/latency_loss_epoch', train_la, epoch)
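        # log every architecture parameter (cell, network and distributed alphas)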
        arch_board = self.model.arch_parameters()[0]
        for i in range(len(arch_board)):
            for j in range(len(arch_board[i])):
                self.writer.add_scalar('cell/' + str(i) + '/' + str(j),
                                       arch_board[i][j], epoch)

        arch_board = self.model.arch_parameters()[1]
        for i in range(len(arch_board)):
            for j in range(len(arch_board[i])):
                for k in range(len(arch_board[i][j])):
                    self.writer.add_scalar(
                        'network/' + str(i) + str(j) + str(k),
                        arch_board[i][j][k], epoch)

        arch_board = self.model.arch_parameters()[2]
        for i in range(len(arch_board)):
            self.writer.add_scalar('distributed/' + str(i), arch_board[i],
                                   epoch)

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch_para': self.fetch_arch(),
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        self.evaluator_device.reset()

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output, device_output, loss, _, _, _ = self.model._loss(
                    image, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred_device = device_output.data.cpu().numpy()

            pred = np.argmax(pred, axis=1)
            pred_device = np.argmax(pred_device, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            self.evaluator_device.add_batch(target, pred_device)
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU_device = self.evaluator_device.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU_cloud', mIoU, epoch)
        self.writer.add_scalar('val/mIoU_device', mIoU_device, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Example #14
class trainNew(object):
    def __init__(self, args):
        self.args = args

        """ Define Saver """
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        """ Define Tensorboard Summary """
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = self.args.use_amp
        self.opt_level = self.args.opt_level

        """ Define Dataloader """
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
         
        if args.network == 'searched_dense':
            """ 40_5e_lr_38_31.91  """
            # cell_path_1 = os.path.join(args.saved_arch_path, '40_5e_38_lr', 'genotype_1.npy')
            # cell_path_2 = os.path.join(args.saved_arch_path, '40_5e_38_lr','genotype_2.npy')
            # cell_arch_1 = np.load(cell_path_1)
            # cell_arch_2 = np.load(cell_path_2)
            # network_arch = [1, 2, 3, 2, 3, 2, 2, 1, 2, 1, 1, 2]

            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
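            # each entry presumably selects the feature-map scale used at one of
            # the 12 layers of the searched backbone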
            network_arch = [0, 1, 2, 3, 2, 2, 2, 2, 1, 2, 3, 2]
            low_level_layer = 0

            model = Model_2(network_arch,
                            cell_arch,
                            self.nclass,
                            args,
                            low_level_layer)

        elif args.network == 'searched_baseline':
            cell_path_1 = os.path.join(args.saved_arch_path, 'searched_baseline', 'genotype_1.npy')
            cell_path_2 = os.path.join(args.saved_arch_path, 'searched_baseline','genotype_2.npy')
            cell_arch_1 = np.load(cell_path_1)
            cell_arch_2 = np.load(cell_path_2)
            network_arch = [0, 1, 2, 2, 3, 2, 2, 1, 2, 1, 1, 2]
            low_level_layer = 1
            model = Model_2_baseline(network_arch,
                                     cell_arch_1,  # only genotype_1/2 are loaded in this branch; the first is passed here
                                     self.nclass,
                                     args,
                                     low_level_layer)

        elif args.network.startswith('autodeeplab'):
            network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            low_level_layer = 2

            if args.network == 'autodeeplab-dense':
                model = Model_2(network_arch,
                                cell_arch,
                                self.nclass,
                                args,
                                low_level_layer)

            elif args.network == 'autodeeplab-baseline':
                model = Model_2_baseline(network_arch,
                                         cell_arch,
                                         self.nclass,
                                         args,
                                         low_level_layer)


        """ Define Optimizer """
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        """ Define Criterion """
        """ whether to use class balanced weights """
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
        self.model, self.optimizer = model, optimizer

        """ Define Evaluator """
        self.evaluator_1 = Evaluator(self.nclass)
        self.evaluator_2 = Evaluator(self.nclass)

        """ Define lr scheduler """
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        if args.cuda:
            self.model = self.model.cuda()

        """ mixed precision """
        if self.use_amp and args.cuda:
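            # O2/O3 require BatchNorm to stay in fp32; other opt levels let amp decide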
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            """ fix for current pytorch version with opt_level 'O1' """
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) or isinstance(module, SynchronizedBatchNorm2d):
                        """ Hack to fix BN fprop without affine transformation """
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")


        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('cannot currently run nn.DataParallel with optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')


        """ Resuming checkpoint """
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            """ if the weights are wrapped in module object we have to clean it """
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # strip the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if (torch.cuda.device_count() > 1):
                    copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
                else:
                    copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        """ Clear start epoch if fine-tuning """
        if args.ft:
            args.start_epoch = 0


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output_1, output_2 = self.model(image)
            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            loss = loss_1 + loss_2
            
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            if i % 50 == 0:
                tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator_1.reset()
        self.evaluator_2.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output_1, output_2 = self.model(image)
            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            loss = loss_1 + loss_2
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            target_show = target  # keep the tensor for visualization below
            pred_1 = torch.argmax(output_1, dim=1).cpu().numpy()
            pred_2 = torch.argmax(output_2, dim=1).cpu().numpy()
            target_np = target.cpu().numpy()

            """ Add batch sample into evaluator """
            self.evaluator_1.add_batch(target_np, pred_1)
            self.evaluator_2.add_batch(target_np, pred_2)
            if epoch//100 == i:
                global_step = epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target_show, output_2, global_step)

        mIoU_1 = self.evaluator_1.Mean_Intersection_over_Union()
        mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union()

        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/classifier_1/mIoU', mIoU_1, epoch)
        self.writer.add_scalar('val/classifier_2/mIoU', mIoU_2, epoch)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        print("classifier_1_mIoU:{}, classifier_2_mIoU: {}".format(mIoU_1, mIoU_2))
        print('Loss: %.3f' % test_loss)

        new_pred = (mIoU_1 + mIoU_2)/2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Example #15
## Run on the test set for accuracy
# print('Testing model:')
evaluator.reset()

for i, sample in enumerate(val_loader):
    batch, target = sample['image'], sample['label']
    if args.cuda:
        batch = batch.cuda()
    with torch.no_grad():
        output = model(batch)
    pred = output.data.cpu().numpy()
    target = target.cpu().numpy()
    pred = np.argmax(pred, axis=1)
    # Add batch sample into evaluator
    evaluator.add_batch(target, pred)

Acc = evaluator.Pixel_Accuracy() * 100
Acc_class = evaluator.Pixel_Accuracy_Class() * 100
mIoU = evaluator.Mean_Intersection_over_Union() * 100
FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
print("Acc:{0:.3f}, Acc_class:{1:.3f}, mIoU:{2:.3f}, fwIoU: {3:.3f}".format(
    Acc, Acc_class, mIoU, FWIoU))

# model_acc[curr_iter] = Acc
# model_acc_class[curr_iter] = Acc_class
# model_miou[curr_iter] = mIoU

title = 'Inference Time of Pruned Model'
# assumes total_infer_times and model_acc were collected per pruning iteration
# (see the commented-out bookkeeping above)
plt.plot(total_infer_times, '-b', label='infer_times')
plt.plot(model_acc, '-r', label='accuracy')
plt.title(title)
plt.legend()
plt.show()
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            size = 512
            train_file = os.path.join(os.getcwd(), "data", "CamVid",
                                      "train.csv")
            val_file = os.path.join(os.getcwd(), "data", "CamVid", "val.csv")
            print('=>loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file,
                                     phase='val',
                                     flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(
                args, **kwargs)
        elif args.dataset == 'NYUDv2':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.num_class = make_data_loader(
                args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2, 4, 23, 3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # Adam takes betas rather than a momentum argument
            optimizer = torch.optim.Adam(fpn.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(),
                                        lr=args.lr,
                                        momentum=0,
                                        weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        elif args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(
                weight=weight, cuda=args.cuda).build_loss(mode='ce')
        elif args.dataset == 'NYUDv2':
            weight = None
            self.criterion = SegmentationLosses(
                weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset,
                                      args.checkname)
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')))
            run_id = int(runs[-1].split('_')[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir,
                                          'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir, 'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(load_name))
            checkpoint = torch.load(load_name)
            args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(
                load_name, checkpoint['epoch']))

        self.lr_stage = [68, 93]
        self.lr_stage_ind = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        # tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        if self.lr_stage_ind < len(self.lr_stage) and epoch == self.lr_stage[
                self.lr_stage_ind]:
            # step the learning rate down once each decay epoch in lr_stage is reached
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_stage_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            elif self.args.dataset == 'NYUDv2':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            inputs = Variable(image)
            labels = Variable(target)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels.long())
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # tbar.set_description('\rTrain loss:%.3f' % (train_loss / (iteration + 1)))

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(
                    epoch, iteration, len(self.train_loader), loss.data,
                    self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   iteration + num_img_tr * epoch)

            #if iteration % (num_img_tr // 10) == 0:
            #    global_step = iteration + num_img_tr * epoch
            #    self.summary.visualize_image(self.witer, self.args.dataset, image, target, outputs, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for iter, batch in enumerate(self.val_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            elif self.args.dataset == 'NYUDv2':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f ' % (test_loss / (iter + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, iter * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Example #17
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Dataloader
        if args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2,4,23,3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            optimizer = torch.optim.Adam(fpn.parameters(), lr=args.lr, weight_decay=args.weight_decay)  # Adam takes betas, not momentum
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr, momentum=0, weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()


        # Resuming checkpoint
        self.best_pred = 0.0
        self.lr_stage = [68, 93]
        self.lr_stage_ind = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        if self.lr_stage_ind < len(self.lr_stage) and epoch == self.lr_stage[self.lr_stage_ind]:
            # step the learning rate down once each decay epoch in lr_stage is reached
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_stage_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            inputs = Variable(image)
            labels = Variable(target)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels.long())
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, len(self.train_loader), loss.data, self.lr))
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        for iter, batch in enumerate(self.val_loader):
            if self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            print('Test Loss:%.3f' % (test_loss/(iter+1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iter * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Example #18
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define network
        model = DeepLab(num_classes=args.num_classes,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, self.valid_loader = make_data_loader(args, **kwargs)
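        # optional lookup arrays that remap prediction / ground-truth label ids
        # into a shared evaluation taxonomy (applied during validation)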
        self.pred_remap = args.pred_remap
        self.gt_remap = args.gt_remap

        # Define Evaluator
        self.evaluator = Evaluator(args.eval_num_classes)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

    def validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.valid_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)

            if self.gt_remap is not None:
                target = self.gt_remap[target.astype(int)]
            if self.pred_remap is not None:
                pred = self.pred_remap[pred.astype(int)]
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
Example #19
def train_seg(model,
              dataLoader,
              epoch,
              optimizer,
              loss_fn,
              num_classes,
              logger,
              tensorLogger,
              device='cuda',
              args=None):
    model.train()
    logger.info("Train | [{:2d}/{}] | Lr: {}  |".format(
        epoch + 1, args.max_epoch, optimizer.param_groups[0]["lr"]))
    tensorLogger.add_scalar("Common/lr", optimizer.param_groups[0]["lr"],
                            epoch)
    losses = AverageMeter()
    batch_time = AverageMeter()
    Miou = AverageMeter()

    evaluator = Evaluator(num_class=num_classes)
    evaluator.reset()

    lossList = []
    miouList = []
    for i, (inputs, target) in enumerate(dataLoader):
        inputs = inputs.to(device=device)
        target = target.to(device=device)

        initTime = time.time()
        output = model(inputs)

        loss = loss_fn(output, target)

        output_np = output.detach().cpu().numpy()
        target_np = target.detach().cpu().numpy()

        # print(output_np.shape, target_np.shape)
        evaluator.add_batch(target_np, np.argmax(output_np, axis=1))
        losses.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - initTime)
        if i % 20 == 0:
            miou, iou = evaluator.Mean_Intersection_over_Union()
            Miou.update(miou, 20)
            tensorLogger.add_scalar('train/loss', losses.avg,
                                    epoch * len(dataLoader) + i)
            tensorLogger.add_scalar('train/miou', miou,
                                    epoch * len(dataLoader) + i)
            lossList.append(losses.avg)
            miouList.append(miou)

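        # the evaluator is reset every 100 batches below, so the logged mIoU is a
        # rolling estimate over the most recent window, not the whole epoch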
        if i % 100 == 0:  # print after every 100 batches
            logger.info(
                "Train | {:2d} | [{:4d}/{}] Infer:{:.2f}sec |  Loss:{:.4f}  |  Miou:{:.4f}  |"
                .format(epoch + 1, i + 1, len(dataLoader), batch_time.avg,
                        losses.avg, miou))
            evaluator.reset()

    return lossList, miouList
Example #20
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        # Todo: add option for other networks
        model = LaneDeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        """
        model.cuda()
        summary(model, input_size=(3, 720, 1280))
        exit()
        """

        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            if args.ft:
                args.start_epoch = 0
            else:
                args.start_epoch = checkpoint["epoch"]

            if args.cuda:
                # self.model.module.load_state_dict(checkpoint['state_dict'])
                pretrained_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                # self.model.load_state_dict(checkpoint['state_dict'])
                pretrained_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        elif args.decoder is not None:
            if not os.path.isfile(args.decoder):
                raise RuntimeError(
                    "=> no checkpoint for decoder found at '{}'".format(
                        args.decoder))
            checkpoint = torch.load(args.decoder)
            # loading the decoder alone always means fine-tuning, so restart from epoch 0
            args.start_epoch = 0
            if args.cuda:
                decoder_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in decoder_dict.items():
                    if "aspp" not in k:
                        continue
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                raise NotImplementedError("Please USE CUDA!!!")

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            lanes = sample["lanes"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                lanes = lanes.cuda().unsqueeze(1)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, (target, lanes))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description("Train loss: %.3f" % (train_loss / (i + 1)))
            # self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            """
            if i % (num_img_tr // 10) == 0 and False:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
            """

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Loss: %.3f" % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch, inference=False):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            lanes = sample["lanes"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                lanes = lanes.cuda().unsqueeze(1)
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, (target, lanes))
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
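            # the last output channel holds the lane-line logits, so drop it
            # before taking the segmentation argmax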
            pred = output[:, :-1, :, :].data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: %.3f" % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )
Example #21
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        model = LaneDeepLab(args=args, num_classes=self.nclass)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset), args.dataset + "_classes_weights.npy"
            )
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass
                )
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(
            mode=args.loss_type
        )
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self):
        self.model.eval()
        self.args.examine = False
        tbar = tqdm(self.test_loader, desc="\r")
        if self.args.color:
            __image = True
        else:
            __image = False
        for i, sample in enumerate(tbar):
            images = sample["image"]
            names = sample["name"]
            if self.args.cuda:
                images = images.cuda()
            with torch.no_grad():
                output = self.model(images)
            preds = output.data.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            if __image:
                images = images.cpu().numpy()
            if not self.args.color:
                self.vs.predict_id(preds, names, self.args.save_dir)
            else:
                self.vs.predict_color(preds, images, names, self.args.save_dir)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        if self.args.color or self.args.examine:
            __image = True
        else:
            __image = False
        for i, sample in enumerate(tbar):
            images, targets = sample["image"], sample["label"]
            names = sample["name"]
            lanes = sample["lanes"]
            lanes = lanes.unsqueeze(1)  # add a channel dim on both CPU and GPU paths
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()
                lanes = lanes.cuda()
            with torch.no_grad():
                output = self.model(images)
            # debug:
            # print("output.shape: ", output.shape)
            # print("targets.shape: ", targets.shape)
            loss = self.criterion(output, (targets, lanes))
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            preds = output.data.cpu().numpy()
            targets = targets.cpu().numpy()
            preds = np.argmax(
                preds[:, 0:-1, :, :], axis=1
            )  # since the last channel is lane lines
            # Add batch sample into evaluator
            # debug:
            # print("targets.shape: ", targets.shape)
            # print("preds.shape: ", preds.shape)
            self.evaluator.add_batch(targets, preds)  # original
            if __image:
                images = images.cpu().numpy()
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names, self.args.save_dir)
            if self.args.examine:
                self.vs.predict_examine(
                    preds, targets, images, names, self.args.save_dir
                )

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Validation:")
        # print(
        #     "[Epoch: %d, numImages: %5d]"
        #     % (epoch, i * self.args.batch_size + image.data.shape[0])
        # )
        print(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU
            )
        )
        print("Loss: %.3f" % test_loss)
Example #22
class ArchitectureSearcher(object):
    def __init__(self, args):
        self.args = args

        #Define Saver
        self.saver = Saver(args)
        #call the saver method that creates a file where training
        #information (dataset, epochs, ...) is saved
        self.saver.save_experiment_config()

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        ## TODO: figure out what this is
        weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        model = AutoDeeplab(self.nclass, 10, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)

        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)
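        # Editor's note (hedged): this mirrors the DARTS/Auto-DeepLab bi-level
        # scheme: SGD above updates the network weights w on train_loaderA,
        # while this Adam optimizer updates the architecture parameters
        # (alphas/betas) on a held-out split (train_loaderB) once
        # epoch >= args.alpha_epoch (see training() below).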

        # Define Evaluator
        ## TODO: figure out what this is
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)

        self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                # self.model.load_state_dict(checkpoint['state_dict'])
                copy_state_dict(self.model.state_dict(),
                                checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            #plt.imshow(image[0].permute(2,1,0))
            #plt.show()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            # reset w.grad for every requires_grad=True parameter
            self.optimizer.zero_grad()
            # compute the mask prediction for an image drawn from dataset A
            output = self.model(image)
            # compute loss A (segmentation loss) between output and target
            loss = self.criterion(output, target)
            # compute the loss gradient w.r.t. every requires_grad=True
            # parameter and store it in each parameter's .grad
            loss.backward()
            # update the network weights w (bound to this optimizer) using w.grad
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search = image_search.cuda()
                    target_search = target_search.cuda()

                # reset alpha/beta .grad for every requires_grad=True parameter
                self.architect_optimizer.zero_grad()
                # compute the mask prediction for an image drawn from dataset B
                output_search = self.model(image_search)
                # compute loss B (segmentation loss) between output and target
                arch_loss = self.criterion(output_search, target_search)
                # compute the loss gradient and store it in alpha/beta .grad
                arch_loss.backward()
                # update the architecture parameters alpha/beta (bound to this
                # optimizer) using their gradients
                self.architect_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

            # state_dict() lets us save, update, alter, and restore PyTorch
            # model and optimizer state
            state_dict = self.model.state_dict()
            # save the checkpoint to disk; this Saver method creates model_best.pth
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
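# Editor's sketch (hedged): a typical driver for the searcher above; the
# argument names mirror the attributes this class reads but are assumptions.
def _run_search_sketch(args):
    searcher = ArchitectureSearcher(args)
    for epoch in range(args.start_epoch, args.epochs):
        # w-step every epoch; the alpha/beta step kicks in once
        # epoch >= args.alpha_epoch (see training() above)
        searcher.training(epoch)
        # tracks the best mIoU and checkpoints it
        searcher.validation(epoch)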
Example #23
class Trainer(object):
    def __init__(self, config, args):
        self.args = args
        self.config = config

        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': config.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            #cudnn.benchmark = True
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                # map_location belongs to torch.load, not load_state_dict
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Building_Acc()
        #Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        #FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + image.data.shape[0]))
        print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            print('Saving state, epoch:', epoch)
            torch.save(
                self.model.module.state_dict(), self.args.save_folder +
                'models/' + 'epoch' + str(epoch) + '.pth')
            loss_file = {'Acc': Acc, 'IoU': IoU, 'mIoU': mIoU}
            with open(
                    os.path.join(self.args.save_folder, 'eval',
                                 'epoch' + str(epoch) + '.json'), 'w') as f:
                json.dump(loss_file, f)
Example #24
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
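        # Editor's note (hedged): calculate_weigths_labels in this repo family
        # typically derives per-class weights from label-pixel frequencies,
        # roughly w_c = 1 / ln(1.02 + p_c), and caches them in the .npy file
        # checked above so later runs can skip a full pass over the loader.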
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
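# Editor's sketch (hedged): the usual entry point for this Trainer, pieced
# together from the attributes it reads; the no_val flag and epoch bounds are
# taken from the code above, the rest is assumed argparse wiring.
def _main_sketch(args):
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()  # flush TensorBoard logs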
Example #25
    def warm_up(self, warmup_epochs):
        if warmup_epochs <= 0:
            self.logger.log('=> warmup close', mode='warm')
            #print('\twarmup close')
            return
        # set optimizer and scheduler in warm_up phase
        lr_max = self.arch_search_config.warmup_lr
        data_loader = self.run_manager.run_config.train_loader
        scheduler_params = self.run_manager.run_config.optimizer_config[
            'scheduler_params']
        optimizer_params = self.run_manager.run_config.optimizer_config[
            'optimizer_params']
        momentum, nesterov, weight_decay = optimizer_params[
            'momentum'], optimizer_params['nesterov'], optimizer_params[
                'weight_decay']
        eta_min = scheduler_params['eta_min']
        optimizer_warmup = torch.optim.SGD(self.net.weight_parameters(),
                                           lr_max,
                                           momentum,
                                           weight_decay=weight_decay,
                                           nesterov=nesterov)
        # set initial_learning_rate in weight_optimizer
        #for param_groups in self.run_manager.optimizer.param_groups:
        #    param_groups['lr'] = lr_max
        lr_scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_warmup, warmup_epochs, eta_min)
        iter_per_epoch = len(data_loader)
        total_iteration = warmup_epochs * iter_per_epoch

        self.logger.log('=> warmup begin', mode='warm')

        # TODO: in inverse warm-up, invert network_arch_params and all cell_arch_params first.
        self.net.set_inverse_weight()

        epoch_time = AverageMeter()
        end_epoch = time.time()
        for epoch in range(self.warmup_epoch, warmup_epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Warmup Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='warm')

            lr_scheduler_warmup.step(epoch)
            warmup_lr = lr_scheduler_warmup.get_lr()
            self.net.train()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1, warmup_epochs)
            time_left = epoch_time.average * (warmup_epochs - epoch)
            common_log = '[Warmup the {:}] Left={:} LR={:}'.format(
                epoch_str,
                str(timedelta(seconds=time_left)) if epoch != 0 else None,
                warmup_lr)
            self.logger.log(common_log, mode='warm')
            end = time.time()

            # single_path init
            _, network_index = self.net.get_network_arch_hardwts_with_constraint()
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)
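            # Editor's note (hedged): this resembles GDAS-style single-path
            # sampling -- hard one-hot weights (hardwts) select one candidate
            # per layer, so each forward pass executes a single sub-network
            # rather than the whole super-network.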

            for i, (datas, targets) in enumerate(data_loader):
                #print(i)
                #print(self.net.single_path)
                #if i == 59: # used for debug
                #    break
                if torch.cuda.is_available():
                    datas = datas.to(self.run_manager.device,
                                     non_blocking=True)
                    targets = targets.to(self.run_manager.device,
                                         non_blocking=True)
                else:
                    raise ValueError('CPU execution is not supported')
                data_time.update(time.time() - end)

                # TODO: update one architecture sufficiently
                # 1. get hardwts and index
                # 2. sample single_path, and set single_path
                # 3. get arch_sample_frequency
                # 4. update single_path per '{:}'.format(sample_arch_frequency) frequency
                '''
                if (i+1) % self.arch_search_config.sample_arch_frequency == 0:
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint()
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(self.run_manager.run_config.nb_layers, aspp_index, network_index)
                    '''
                logits = self.net.single_path_forward(datas, single_path)
                ce_loss = self.run_manager.criterion(logits, targets)
                entropy_reg = self.net.calculate_entropy(single_path)
                loss = self.run_manager.add_regularization_loss(
                    ce_loss, entropy_reg)
                # measure metrics and update
                evaluator = Evaluator(self.run_manager.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                acc = evaluator.Pixel_Accuracy()
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                losses.update(loss.data.item(), datas.size(0))
                accs.update(acc, datas.size(0))
                mious.update(miou, datas.size(0))
                fscores.update(fscore, datas.size(0))

                self.net.zero_grad()
                loss.backward()
                self.run_manager.optimizer.step()

                if (i + 1) % self.arch_search_config.sample_arch_frequency == 0 \
                        or (i + 1) == iter_per_epoch:
                    valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                    if torch.cuda.is_available():
                        valid_datas = valid_datas.to(self.run_manager.device,
                                                     non_blocking=True)
                        valid_targets = valid_targets.to(
                            self.run_manager.device, non_blocking=True)
                    else:
                        raise ValueError('CPU execution is not supported')

                    _, network_index = self.net.get_network_arch_hardwts_with_constraint()  # set self.hardwts again
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(
                        valid_datas, single_path)

                    ce_loss = self.run_manager.criterion(logits, valid_targets)
                    entropy_reg = self.net.calculate_entropy(single_path)
                    loss = self.run_manager.add_regularization_loss(
                        ce_loss, entropy_reg)

                    # metrics and update
                    valid_evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    valid_evaluator.add_batch(valid_targets, logits)
                    acc = valid_evaluator.Pixel_Accuracy()
                    miou = valid_evaluator.Mean_Intersection_over_Union()
                    fscore = valid_evaluator.Fx_Score()
                    valid_losses.update(loss.data.item(), datas.size(0))
                    valid_accs.update(acc.item(), datas.size(0))
                    valid_mious.update(miou.item(), datas.size(0))
                    valid_fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()  # release computational graph
                    self.arch_optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % self.run_manager.run_config.train_print_freq == 0 \
                        or i + 1 == iter_per_epoch:
                    Wstr = '|*WARM-UP*|' + time_string() + \
                        '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1, iter_per_epoch)
                    Tstr = '|Time     | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                        batch_time=batch_time, data_time=data_time)
                    Bstr = '|Base     | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                        loss=losses, acc=accs, miou=mious, fscore=fscores)
                    Astr = '|Arch     | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                        loss=valid_losses,
                        acc=valid_accs,
                        miou=valid_mious,
                        fscore=valid_fscores)
                    self.logger.log(
                        Wstr + '\n' + Tstr + '\n' + Bstr + '\n' + Astr, 'warm')

            #torch.cuda.empty_cache()
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()
            '''
            # TODO: whether to perform validation after each epoch in the warmup phase?
            valid_loss, valid_acc, valid_miou, valid_fscore = self.validate()
            valid_log = 'Warmup Valid\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tF\t{5:6.4f}'\
                .format(epoch+1, warmup_epochs, valid_loss, valid_acc, valid_miou, valid_fscore)
                        #'\tflops\t{6:}M\tparams{7:}M'\
            valid_log += 'Train\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tFscore\t{5:6.4f}'
            self.run_manager.write_log(valid_log, 'valid')
            '''

            # continue the warmup phase
            self.warmup = epoch + 1 < warmup_epochs
            self.warmup_epoch += 1
            #self.start_epoch = self.warmup_epoch
            # To save checkpoint in warmup phase at specific frequency.

            # TODO: in inverse_alpha_warm_up, should check semantics of resume checkpoint.

            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch + 1) == warmup_epochs:
                state_dict = self.net.state_dict()
                # rm architecture parameters because, in warm_up phase, arch_parameters are not updated.
                #for key in list(state_dict.keys()):
                #    if 'cell_arch_parameters' in key or 'network_arch_parameters' in key or 'aspp_arch_parameters' in key:
                #        state_dict.pop(key)
                checkpoint = {
                    'state_dict': state_dict,
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    lr_scheduler_warmup.state_dict(),
                    'arch_optimizer': self.arch_optimizer.state_dict(),
                    'warmup': self.warmup,
                    'warmup_epoch': epoch + 1,
                }
                filename = self.logger.path(mode='warm', is_best=False)
                save_path = save_checkpoint(checkpoint,
                                            filename,
                                            self.logger,
                                            mode='warm')
Example #26
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define weight
        self.temporal_weight = args.temporal_weight
        self.spatial_weight = args.spatial_weight

        # Define network
        temporal_model = Model(name='vgg16_bn', num_classes=101,
                               is_flow=True).get_model()
        spatial_model = Model(name='vgg16_bn', num_classes=101,
                              is_flow=False).get_model()

        # Define Optimizer
        #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        temporal_optimizer = torch.optim.Adam(temporal_model.parameters(),
                                              lr=args.temporal_lr)
        spatial_optimizer = torch.optim.Adam(spatial_model.parameters(),
                                             lr=args.spatial_lr)

        # Define Criterion
        self.temporal_criterion = nn.BCELoss().cuda()
        self.spatial_criterion = nn.BCELoss().cuda()

        self.temporal_model, self.temporal_optimizer = temporal_model, temporal_optimizer
        self.spatial_model, self.spatial_optimizer = spatial_model, spatial_optimizer

        # Define Evaluator
        self.top1_eval = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.temporal_model = torch.nn.DataParallel(
                self.temporal_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.temporal_model)
            self.temporal_model = self.temporal_model.cuda()

            self.spatial_model = torch.nn.DataParallel(
                self.spatial_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.spatial_model)
            self.spatial_model = self.spatial_model.cuda()

        # Resuming checkpoint
        self.best_accuracy = 0.0
        '''
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
                #self.model.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            #self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_accuracy = checkpoint['best_accuracy']
            print("=> loaded checkpoint '{}' (epoch {}), best prediction {}"
                  .format(args.resume, checkpoint['epoch'], self.best_accuracy))
        '''

    def training(self, epoch):
        train_loss = 0.0
        self.temporal_model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            rgbs, flows, targets = sample['rgb'], sample['flow'], sample['label']
            targets = targets.view(-1, 1).float()
            if self.args.cuda:
                rgbs, flows, targets = rgbs.cuda(), flows.cuda(), targets.cuda()

            self.temporal_optimizer.zero_grad()
            self.spatial_optimizer.zero_grad()

            temporal_output = self.temporal_model(flows)
            spatial_output = self.spatial_model(rgbs)

            temporal_loss = self.temporal_criterion(temporal_output, targets)
            spatial_loss = self.spatial_criterion(spatial_output, targets)

            temporal_loss.backward()
            spatial_loss.backward()

            self.temporal_optimizer.step()
            self.spatial_optimizer.step()

            train_loss += temporal_loss.item()
            train_loss += spatial_loss.item()

            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_temporal_loss_iter',
                                   temporal_loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_spatial_loss_iter',
                                   spatial_loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            #if i % (num_img_tr // 10) == 0:
            #    global_step = i + num_img_tr * epoch
            #    self.summary.visualize_image(self.writer, images, targets.squeeze(1).cpu().numpy(), output.squeeze(1).data.cpu().numpy(), global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + rgbs.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.temporal_model.eval()
        self.spatial_model.eval()

        self.top1_eval.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            rgbs, flows, targets = sample['rgb'], sample['flow'], sample['label']
            targets = targets.view(-1, 1).float()
            if self.args.cuda:
                rgbs, flows, targets = rgbs.cuda(), flows.cuda(), targets.cuda()
            with torch.no_grad():
                temporal_output = self.temporal_model(flows)
                spatial_output = self.spatial_model(rgbs)

            temporal_loss = self.temporal_criterion(temporal_output, targets)
            spatial_loss = self.spatial_criterion(spatial_output, targets)

            test_loss += temporal_loss.item()
            test_loss += spatial_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

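            # Editor's note (hedged): late fusion of the two streams -- the
            # final score is a weighted sum of the temporal (flow) and spatial
            # (RGB) outputs, using args.temporal_weight and args.spatial_weight.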
            pred = (temporal_output.data.cpu().numpy() * self.temporal_weight +
                    spatial_output.data.cpu().numpy() * self.spatial_weight)
            targets = targets.cpu().numpy()
            # Add batch sample into evaluator
            self.top1_eval.add_batch(targets, pred)

        # Fast test during the training
        top1_acc = self.top1_eval.Accuracy()

        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/Acc', top1_acc, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + rgbs.data.shape[0]))
        print("Top1: acc:{}, best accuracy:{}".format(top1_acc,
                                                      self.best_accuracy))
        print("Sensitivity:{}, Specificity:{}".format(
            self.top1_eval.Sensitivity(), self.top1_eval.Specificity()))
        print("Confusion Maxtrix:\n{}".format(
            self.top1_eval.Confusion_Matrix()))
        print('Loss: %.3f' % test_loss)

        if top1_acc > self.best_accuracy:
            is_best = True
            self.best_accuracy = top1_acc
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'temporal_state_dict':
                    self.temporal_model.module.state_dict(),
                    'temporal_optimizer': self.temporal_optimizer.state_dict(),
                    'spatial_state_dict':
                    self.spatial_model.module.state_dict(),
                    'spatial_optimizer': self.spatial_optimizer.state_dict(),
                    'best_accuracy': self.best_accuracy,
                    'sensitivity': self.top1_eval.Sensitivity(),
                    'specificity': self.top1_eval.Specificity(),
                }, is_best)
Example #27
    def train(self, fix_net_weights=False):

        # the run config provides valid_batch_size; drop_last is ignored here.
        data_loader = self.run_manager.run_config.train_loader
        iter_per_epoch = len(data_loader)
        total_iteration = iter_per_epoch * self.run_manager.run_config.epochs
        self.update_scheduler = self.arch_search_config.get_update_schedule(
            iter_per_epoch)

        if fix_net_weights:  # used to debug
            data_loader = [(0, 0)] * iter_per_epoch
            print('Train phase disabled for debugging')

        # arch_parameter update frequency and times in each iteration.
        #update_schedule = self.arch_search_config.get_update_schedule(iter_per_epoch)

        # pay attention here, total_epochs include warmup epochs
        epoch_time = AverageMeter()
        end_epoch = time.time()
        # TODO: in the train phase, set the inverse weights back (i.e. restore normal arch weights)
        self.net.set_inverse_weight()
        # TODO : use start_epochs
        for epoch in range(self.start_epoch,
                           self.run_manager.run_config.epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Train Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='search')
            self.run_manager.scheduler.step(epoch)
            train_lr = self.run_manager.scheduler.get_lr()
            arch_lr = self.arch_optimizer.param_groups[0]['lr']
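            # Editor's note (hedged): the next call anneals the Gumbel-softmax
            # temperature tau linearly from tau_max down to tau_min over
            # training, gradually sharpening the architecture distribution.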
            self.net.set_tau(self.arch_search_config.tau_max -
                             (self.arch_search_config.tau_max -
                              self.arch_search_config.tau_min) * (epoch) /
                             (self.run_manager.run_config.epochs))
            tau = self.net.get_tau()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            #valid_data_time = AverageMeter()
            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            self.net.train()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            time_left = epoch_time.average * (
                self.run_manager.run_config.epochs - epoch)
            common_log = '[*Train-Search* the {:}] Left={:} WLR={:} ALR={:} tau={:}'\
                .format(epoch_str, str(timedelta(seconds=time_left)) if epoch != 0 else None, train_lr, arch_lr, tau)
            self.logger.log(common_log, 'search')

            end = time.time()

            # single_path init
            _, network_index = self.net.get_network_arch_hardwts_with_constraint()
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)

            for i, (datas, targets) in enumerate(data_loader):
                #print(self.net.single_path)
                #print(i)
                #if i == 59: break
                if not fix_net_weights:
                    if torch.cuda.is_available():
                        datas = datas.to(self.run_manager.device,
                                         non_blocking=True)
                        targets = targets.to(self.run_manager.device,
                                             non_blocking=True)
                    else:
                        raise ValueError('CPU execution is not supported')
                    data_time.update(time.time() - end)
                    '''
                    if (i + 1) % self.arch_search_config.sample_arch_frequency == 0:
                        _, network_index = self.net.get_network_arch_hardwts_with_constraint()
                        _, aspp_index = self.net.get_aspp_hardwts_index()
                        single_path = self.net.sample_single_path(self.run_manager.run_config.nb_layers, aspp_index, network_index)
                    '''
                    logits = self.net.single_path_forward(
                        datas, single_path)  # super network gdas forward
                    # loss
                    ce_loss = self.run_manager.criterion(logits, targets)
                    entropy_reg = self.net.calculate_entropy(
                        single_path
                    )  # todo: pay attention, entropy is unnormalized, should use small lambda
                    #print('entropy_reg:', entropy_reg)
                    loss = self.run_manager.add_regularization_loss(
                        ce_loss, entropy_reg)
                    #loss = self.run_manager.criterion(logits, targets)
                    # metrics and update
                    evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    acc = evaluator.Pixel_Accuracy()
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    losses.update(loss.data.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()

                    self.run_manager.optimizer.step()

                    if (i + 1) % self.arch_search_config.sample_arch_frequency == 0 \
                            or (i + 1) == iter_per_epoch:
                        # at the i-th iteration, update arch_parameters
                        # update_scheduler[i] times
                        valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                        if torch.cuda.is_available():
                            valid_datas = valid_datas.to(
                                self.run_manager.device, non_blocking=True)
                            valid_targets = valid_targets.to(
                                self.run_manager.device, non_blocking=True)
                        else:
                            raise ValueError('CPU execution is not supported')

                        _, network_index = self.net.get_network_arch_hardwts_with_constraint()  # set self.hardwts again
                        _, aspp_index = self.net.get_aspp_hardwts_index()
                        single_path = self.net.sample_single_path(
                            self.run_manager.run_config.nb_layers, aspp_index,
                            network_index)
                        logits = self.net.single_path_forward(
                            valid_datas, single_path)

                        ce_loss = self.run_manager.criterion(
                            logits, valid_targets)
                        entropy_reg = self.net.calculate_entropy(single_path)
                        loss = self.run_manager.add_regularization_loss(
                            ce_loss, entropy_reg)

                        # metrics and update
                        valid_evaluator = Evaluator(
                            self.run_manager.run_config.nb_classes)
                        valid_evaluator.add_batch(valid_targets, logits)
                        acc = valid_evaluator.Pixel_Accuracy()
                        miou = valid_evaluator.Mean_Intersection_over_Union()
                        fscore = valid_evaluator.Fx_Score()
                        valid_losses.update(loss.data.item(), datas.size(0))
                        valid_accs.update(acc.item(), datas.size(0))
                        valid_mious.update(miou.item(), datas.size(0))
                        valid_fscores.update(fscore.item(), datas.size(0))

                        self.net.zero_grad()
                        loss.backward()  # release computational graph
                        # update arch_parameters per '{:}'.format(arch_param_update_frequency)
                        self.arch_optimizer.step()

                    # batch_time of one iter of train and valid.
                    batch_time.update(time.time() - end)
                    end = time.time()

                    # in other case, calculate metrics normally
                    # train_print_freq == sample_arch_freq
                    if (i + 1) % self.run_manager.run_config.train_print_freq == 0 \
                            or (i + 1) == iter_per_epoch:
                        Wstr = '|*Search*|' + time_string() + \
                            '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1, iter_per_epoch)
                        Tstr = '|Time    | {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                            batch_time=batch_time, data_time=data_time)
                        Bstr = '|Base    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=losses, acc=accs, miou=mious, fscore=fscores)
                        Astr = '|Arch    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=valid_losses,
                            acc=valid_accs,
                            miou=valid_mious,
                            fscore=valid_fscores)
                        self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr +
                                        '\n' + Astr,
                                        mode='search')

            # update visdom
            if self.vis is not None:
                self.vis.visdom_update(epoch, 'loss',
                                       [losses.average, valid_losses.average])
                self.vis.visdom_update(epoch, 'accuracy',
                                       [accs.average, valid_accs.average])
                self.vis.visdom_update(epoch, 'miou',
                                       [mious.average, valid_mious.average])
                self.vis.visdom_update(
                    epoch, 'f1score', [fscores.average, valid_fscores.average])

            #torch.cuda.empty_cache()
            # update epoch_time
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            epoch_str = '{:03d}/{:03d}'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            log = '[{:}] train :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n' \
                  '[{:}] valid :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n'.format(
                epoch_str, losses.average, accs.average, mious.average, fscores.average,
                epoch_str, valid_losses.average, valid_accs.average, valid_mious.average, valid_fscores.average
            )
            self.logger.log(log, mode='search')

            self.logger.log(
                '<<<---------->>> Super Network decoding <<<---------->>> ',
                mode='search')
            actual_path, cell_genotypes = self.net.network_cell_arch_decode()
            #print(cell_genotypes)
            new_genotypes = []
            for _index, genotype in cell_genotypes:
                xlist = []
                print(_index, genotype)
                for edge_genotype in genotype:
                    for (node_str, select_index) in edge_genotype:
                        xlist.append((node_str, self.run_manager.run_config.
                                      conv_candidates[select_index]))
                new_genotypes.append((_index, xlist))
            log_str = 'The {:} decode network:\n' \
                      'actual_path = {:}\n' \
                      'genotype:\n'.format(epoch_str, actual_path)
            for _index, genotype in new_genotypes:
                log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
            self.logger.log(log_str, mode='network_space', display=False)

            # TODO: save the best network ckpt
            # 1. save network_arch_parameters and cell_arch_parameters
            # 2. save weight_parameters
            # 3. weight_optimizer.state_dict
            # 4. arch_optimizer.state_dict
            # 5. training process
            # 6. monitor_metric and the best_value
            # get best_monitor in valid phase.
            val_monitor_metric = get_monitor_metric(
                self.run_manager.monitor_metric, valid_losses.average,
                valid_accs.average, valid_mious.average, valid_fscores.average)
            is_best = self.run_manager.best_monitor < val_monitor_metric
            self.run_manager.best_monitor = max(self.run_manager.best_monitor,
                                                val_monitor_metric)
            # 1. if is_best : save_current_ckpt
            # 2. if can be divided : save_current_ckpt

            #self.run_manager.save_model(epoch, {
            #    'arch_optimizer': self.arch_optimizer.state_dict(),
            #}, is_best=True, checkpoint_file_name=None)
            # TODO: revisit checkpoint_save semantics
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch +
                        1) == self.run_manager.run_config.epochs or is_best:
                checkpoint = {
                    'state_dict':
                    self.net.state_dict(),
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    self.run_manager.scheduler.state_dict(),
                    'arch_optimizer':
                    self.arch_optimizer.state_dict(),
                    'best_monitor': (self.run_manager.monitor_metric,
                                     self.run_manager.best_monitor),
                    'warmup':
                    False,
                    'start_epochs':
                    epoch + 1,
                }
                checkpoint_arch = {
                    'actual_path': actual_path,
                    'cell_genotypes': cell_genotypes,
                }
                filename = self.logger.path(mode='search', is_best=is_best)
                filename_arch = self.logger.path(mode='arch', is_best=is_best)
                save_checkpoint(checkpoint,
                                filename,
                                self.logger,
                                mode='search')
                save_checkpoint(checkpoint_arch,
                                filename_arch,
                                self.logger,
                                mode='arch')
Example #28
class Trainer(object):

    def __init__(self, args, modelConfig, inputH5Path):

        # Get training parameters
        hyperpars = args["hyperparameters"]
        archpars = args["architecture"]

        # Get model config
        structList = modelConfig["structList"]
        nclass = len(structList) + 1  # + 1 for background class

        args["nclass"] = nclass
        args["inputH5Path"] = inputH5Path
        if torch.cuda.device_count() and torch.cuda.is_available():
            print('Using GPU...')
            args["cuda"] = True
            deviceCount = torch.cuda.device_count()
            print('GPU device count: ', deviceCount)
        else:
            print('Using CPU...')
            args["cuda"] = False

        # Use default args where missing
        defPars = {'fineTune': False, 'resumeFromCheckpoint': None, 'validate': True,
                   'evalInterval': 1}
        defHyperpars = {'startEpoch': 0}
        for key, val in defPars.items():
            args.setdefault(key, val)
        for key, val in defHyperpars.items():
            hyperpars.setdefault(key, val)

        args["hyperparameters"] = hyperpars
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloaders
        kwargs = {'num_workers': 1, 'pin_memory': True}
        train_set = customData(self.args, split='Train')
        self.train_loader = DataLoader(train_set, batch_size=hyperpars["batchSize"], shuffle=True, drop_last=True,
                                       **kwargs)
        val_set = customData(self.args, split='Val')
        self.val_loader = DataLoader(val_set, batch_size=hyperpars["batchSize"], shuffle=False, drop_last=False,
                                     **kwargs)
        test_set = customData(self.args, split='Test')
        self.test_loader = DataLoader(test_set, batch_size=hyperpars["batchSize"], shuffle=False, drop_last=False,
                                      **kwargs)

        # Define network
        model = DeepLab(num_classes=args["nclass"],
                        backbone='resnet',
                        output_stride=archpars["outStride"],
                        sync_bn=archpars["sync_bn"],
                        freeze_bn=archpars["freeze_bn"],
                        model_path=args["modelSavePath"])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': hyperpars["lr"]},
                        {'params': model.get_10x_lr_params(), 'lr': hyperpars["lr"] * 10}]

        # Define Optimizer
        optimizer_type = args["optimizer"]
        if optimizer_type.lower() == 'sgd':
            optimizer = torch.optim.SGD(train_params, momentum=hyperpars["momentum"],
                                        weight_decay=hyperpars["weightDecay"], nesterov=hyperpars["nesterov"])
        elif optimizer_type.lower() == 'adam':
            optimizer = torch.optim.Adam(train_params, lr=hyperpars["lr"],
                                         betas=(0.9, 0.999), eps=1e-08,
                                         weight_decay=hyperpars["weightDecay"])
        else:
            raise ValueError('Unsupported optimizer: {}'.format(optimizer_type))

        # Initialize weights
        print('Initializing weights...')
        initWeights = args["initWeights"]
        if initWeights["method"] == "classBalanced":
            # Use class balanced weights
            print('Using class-balanced weights.')
            class_weights_path = os.path.join(inputH5Path, 'classWeights.npy')
            if os.path.isfile(class_weights_path):
                print('reading weights from ' + class_weights_path)
                weight = np.load(class_weights_path)
            else:
                weight = calculate_weights_labels(inputH5Path, self.train_loader, args["nclass"])
                np.save(class_weights_path, weight)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        # Define loss function
        self.criterion = SegmentationLosses(weight=weight, cuda=args["cuda"]).build_loss(mode=args["lossType"])
        self.model, self.optimizer = model, optimizer

        # Define evaluator
        self.evaluator = Evaluator(args["nclass"])

        # Define lr scheduler
        self.scheduler = LR_Scheduler(hyperpars["lrScheduler"], hyperpars["lr"],
                                      hyperpars["maxEpochs"], len(self.train_loader))

        # Use GPU(s) if available
        if args["cuda"]:
            self.model = torch.nn.DataParallel(self.model, list(range(deviceCount)))
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resume from previous checkpoint
        self.best_pred = 0.0
        if args["resumeFromCheckpoint"] is not None:
            if not os.path.isfile(args["resumeFromCheckpoint"]):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args["resumeFromCheckpoint"]))
            checkpoint = torch.load(args["resumeFromCheckpoint"])
            args["startEpoch"] = checkpoint['epoch']
            if args["cuda"]:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            # Skip restoring optimizer state when fine-tuning:
            if not args["fineTune"]:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args["fineTune"]:
            args["startEpoch"] = 0

    def training(self, epoch):
        args = self.args
        hyperpars = args["hyperparameters"]
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if args["cuda"]:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show inference results
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, args, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * hyperpars["batchSize"] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if not args["validate"]:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args["cuda"]:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args["hyperparameters"]["batchSize"] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            print('Best model yet!')
            try:
                state_dict = self.model.module.state_dict()
            except AttributeError:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        if (epoch + 1) == self.args["hyperparameters"]["maxEpochs"]:
            is_best = False
            try:
                state_dict = self.model.module.state_dict()
            except AttributeError:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                'filename': 'last_checkpoint.pth.tar',
            }, is_best)
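
# A minimal driver sketch for the Trainer above (hypothetical, not part of the
# original class): `args`, `modelConfig`, and `inputH5Path` are assumed to be
# built as __init__ expects, with "startEpoch"/"maxEpochs" under
# args["hyperparameters"] and "validate"/"evalInterval" filled by the defaults.
trainer = Trainer(args, modelConfig, inputH5Path)
hyperpars = trainer.args["hyperparameters"]
for epoch in range(hyperpars["startEpoch"], hyperpars["maxEpochs"]):
    trainer.training(epoch)
    if trainer.args["validate"] and epoch % trainer.args["evalInterval"] == 0:
        trainer.validation(epoch)

Example #29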
class Trainer(object):
    def __init__(self, args):
        self.args = args

        """ Define Saver """
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        """ Define Tensorboard Summary """
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}

        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                """ if so, which trainloader to use? """
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        if args.cuda:
            self.criterion = self.criterion.cuda()

        """ Define network """
        if self.args.network == 'supernet':
            model = Model_search(self.nclass, 12, self.args)
        elif self.args.network == 'path_dense_supernet':
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            model = Model_layer_search(self.nclass, 12, self.args, alphas=cell_arch)

        elif self.args.network == 'path_baseline_supernet':
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            model = Model_layer_search_baseline(self.nclass, 12, self.args, alphas=cell_arch)

        else:
            raise ValueError('Unknown network: {}'.format(self.args.network))


        optimizer = torch.optim.SGD(
                model.weight_parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay
            )

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                    lr=args.arch_lr, betas=(0.9, 0.999),
                                                    weight_decay=args.arch_weight_decay)

        """ Define Evaluator """
        self.evaluator_1 = Evaluator(self.nclass)
        self.evaluator_2 = Evaluator(self.nclass)

        """ Define lr scheduler """
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        """ Using cuda """
        if args.cuda:
            self.model = self.model.cuda()

        """ mixed precision """
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            """ fix for current pytorch version with opt_level 'O1' """
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        """ Hack to fix BN fprop without affine transformation """
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
                self.model, [self.optimizer, self.architect_optimizer], opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('amp initialization finished')


        """ Using data parallel"""
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        """ Resuming checkpoint """
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            """ if the weights are wrapped in module object we have to clean it """
            if args.clean_module:
                self.model.load_state_dict(checkpoint['state_dict'])
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if (torch.cuda.device_count() > 1):
                    copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
                else:
                    copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])


    def training(self, epoch):
        train_loss = 0.0
        search_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output_1, output_2 = self.model(image)
            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            loss = loss_1 + loss_2
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()
            
            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))  # Python 3 next(); note this restarts loaderB each step
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(), target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search_1, output_search_2 = self.model(image_search)

                arch_loss_1 = self.criterion(output_search_1, target_search)
                arch_loss_2 = self.criterion(output_search_2, target_search)
                arch_loss = arch_loss_1 + arch_loss_2
                if self.use_amp:
                    with amp.scale_loss(arch_loss, self.architect_optimizer) as arch_scaled_loss:
                       arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()
                search_loss += arch_loss.item()
                
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f --Search loss: %.3f' \
                % (train_loss/(i+1), search_loss/(i+1)))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        self.decoder_save(epoch, miou=None, evaluation=False)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator_1.reset()
        self.evaluator_2.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output_1, output_2 = self.model(image)
            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            loss = loss_1 + loss_2
            test_loss += loss.item()

            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred_1 = output_1.argmax(dim=1).cpu().numpy()
            pred_2 = output_2.argmax(dim=1).cpu().numpy()
            target_np = target.cpu().numpy()

            """ Add batch sample into evaluator """
            self.evaluator_1.add_batch(target_np, pred_1)
            self.evaluator_2.add_batch(target_np, pred_2)
        mIoU_1 = self.evaluator_1.Mean_Intersection_over_Union()
        mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union()

        """ FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() """
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/classifier_1/mIoU', mIoU_1, epoch)
        self.writer.add_scalar('val/classifier_2/mIoU', mIoU_2, epoch)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        print('Loss: %.3f' % test_loss)
        new_pred = (mIoU_1 + mIoU_2)/2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        """ decode the arch """
        self.decoder_save(epoch, miou=new_pred, evaluation=True)


    def decoder_save(self, epoch, miou=None, evaluation=False):

        num = str(epoch)
        if evaluation:
            num = num + '_eval'
        dir_name = os.path.join(self.saver.experiment_dir, num)
        os.makedirs(dir_name, exist_ok=True)

        decoder = Decoder(None,
                          self.model.betas,
                          self.args.B)

        result_paths, result_paths_space = decoder.viterbi_decode()

        betas = self.model.betas.data.cpu().numpy()

        network_path_filename = os.path.join(dir_name, 'network_path')
        beta_filename = os.path.join(dir_name, 'betas')

        np.save(network_path_filename, result_paths)
        np.save(beta_filename, betas)

        if miou is not None:
            with open(os.path.join(dir_name, 'miou.txt'), 'w') as f:
                f.write(str(miou))
        if evaluation:
            self.writer.add_text('network_path', str(result_paths), epoch+1000)
            self.writer.add_text('miou', str(miou), epoch+1000)
        else:
            self.writer.add_text('network_path', str(result_paths), epoch)
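
# A minimal driver sketch for the search Trainer above (hypothetical entry
# point; `args` fields mirror those read in __init__). training() updates
# weights on train_loaderA every epoch and, from args.alpha_epoch on, also
# updates architecture parameters on train_loaderB; validation() additionally
# Viterbi-decodes and saves the current network path.
trainer = Trainer(args)
for epoch in range(args.epochs):
    trainer.training(epoch)
    trainer.validation(epoch)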
Example #30
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
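
# Sketch of a resume/fine-tune run for the Trainer above (hypothetical; the
# argparse fields mirror those read in __init__: args.resume, args.ft,
# args.start_epoch, args.no_val). Checkpoint state is restored inside __init__.
trainer = Trainer(args)
for epoch in range(args.start_epoch, args.epochs):
    trainer.training(epoch)
    if not args.no_val:
        trainer.validation(epoch)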
Example #31
class Trainer(object):
    def __init__(self, args):
        warnings.filterwarnings('ignore')
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(
            args.backbone, args.dataset, args.exp)
        if args.dataset == 'pascal':
            raise NotImplementedError
        elif args.dataset == 'cityscapes':
            kwargs = {
                'num_workers': args.workers,
                'pin_memory': True,
                'drop_last': True
            }
            dataset_loader, num_classes = dataloaders.make_data_loader(
                args, **kwargs)
            args.num_classes = num_classes
        elif args.dataset == 'marsh':
            kwargs = {
                'num_workers': args.workers,
                'pin_memory': True,
                'drop_last': True
            }
            dataset_loader, val_loader, test_loader, num_classes = dataloaders.make_data_loader(
                args, **kwargs)
            args.num_classes = num_classes
        else:
            raise ValueError('Unknown dataset: {}'.format(args.dataset))

        if args.backbone == 'autodeeplab':
            model = Retrain_Autodeeplab(args)
        else:
            raise ValueError('Unknown backbone: {}'.format(args.backbone))

        if args.criterion == 'Ohem':
            args.thresh = 0.7
            args.crop_size = [args.crop_size, args.crop_size] if isinstance(
                args.crop_size, int) else args.crop_size
            args.n_min = int((args.batch_size / len(args.gpu) *
                              args.crop_size[0] * args.crop_size[1]) // 16)
        criterion = build_criterion(args)

        model = nn.DataParallel(model).cuda()
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
        optimizer = optim.SGD(model.module.parameters(),
                              lr=args.base_lr,
                              momentum=0.9,
                              weight_decay=0.0001)

        max_iteration = len(dataset_loader) * args.epochs
        scheduler = Iter_LR_Scheduler(args, max_iteration, len(dataset_loader))

        start_epoch = 0

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))
                self.best_pred = checkpoint['best_pred']
            else:
                raise ValueError('=> no checkpoint found at {0}'.format(
                    args.resume))
        # merged section: saver, summary, and loader wiring
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        #kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataset_loader, val_loader, test_loader, num_classes

        self.criterion = criterion
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = scheduler
        self.scheduler = LR_Scheduler(
            "poly", args.lr, args.epochs,
            len(self.train_loader))  # poly scheduler used in place of Iter_LR_Scheduler above

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            cur_iter = epoch * len(self.train_loader) + i
            #self.scheduler(self.optimizer, cur_iter)  # Iter_LR_Scheduler call; did not work well, replaced by the poly scheduler below
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                #print("I was here!!")
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        # save checkpoint every epoch
        is_best = False
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Classwise_IoU:")
        print(IoU)
        print('Loss: %.3f' % test_loss)
        print(self.evaluator.confusion_matrix)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
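
# Sketch: a held-out test pass reusing the Evaluator, mirroring validation()
# above but over self.test_loader (hypothetical helper, not in the original).
# Note this evaluator variant returns (mIoU, per-class IoU).
def testing(trainer):
    trainer.model.eval()
    trainer.evaluator.reset()
    with torch.no_grad():
        for sample in trainer.test_loader:
            image, target = sample['image'].cuda(), sample['label'].cuda()
            output = trainer.model(image)
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            trainer.evaluator.add_batch(target.cpu().numpy(), pred)
    mIoU, IoU = trainer.evaluator.Mean_Intersection_over_Union()
    print('Test mIoU: %.4f' % mIoU)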