Example #1
File: eval.py Project: syed-cbot/RFNet
    def __init__(self, args):
        self.args = args
        self.time_train = []

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        _, self.val_loader, _, self.num_class = make_data_loader(
            args, **kwargs)
        print('num_classes: ' + str(self.num_class))

        # Define evaluator
        self.evaluator = Evaluator(self.num_class)

        # Define network
        self.resnet = resnet18(pretrained=True, efficient=False, use_bn=True)
        self.model = RFNet(self.resnet,
                           num_classes=self.num_class,
                           use_bn=True)

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()
            cudnn.benchmark = True  # accelerate fixed-size inference
        print('Model loaded successfully!')

        # Load weights
        assert os.path.exists(
            args.weight_path), 'weight-path:{} doesn\'t exist!'.format(
                args.weight_path)
        self.new_state_dict = torch.load(
            os.path.join(args.weight_path, 'model_best.pth'))

        self.model = load_my_state_dict(self.model.module,
                                        self.new_state_dict['state_dict'])
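
Every example on this page drives the same confusion-matrix style Evaluator: reset(), add_batch(gt, pred) with integer label maps, then the metric getters. The class itself never appears here, so the following is a minimal sketch of the interface those calls imply; the implementation is an assumption modeled on the common pytorch-deeplab-xception metrics, not the original code.

# Sketch of the Evaluator assumed by these examples (method names taken
# from the calls above; confusion-matrix body is an assumption).
import numpy as np

class Evaluator:
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class))

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))

    def add_batch(self, gt_image, pre_image):
        # both arguments are integer label maps of identical shape
        assert gt_image.shape == pre_image.shape
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask].astype(int)
        count = np.bincount(label, minlength=self.num_class ** 2)
        self.confusion_matrix += count.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Pixel_Accuracy_Class(self):
        acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        return np.nanmean(acc)

    def Mean_Intersection_over_Union(self):
        denom = (self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0)
                 - np.diag(self.confusion_matrix))
        return np.nanmean(np.diag(self.confusion_matrix) / denom)

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = self.confusion_matrix.sum(axis=1) / self.confusion_matrix.sum()
        denom = (self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0)
                 - np.diag(self.confusion_matrix))
        iou = np.diag(self.confusion_matrix) / denom
        return (freq[freq > 0] * iou[freq > 0]).sum()

Examples #19 and #22 additionally call Acc_of_each_class, IOU_of_each_class, and Fx_Score, which are per-class and F-score variants of the same confusion-matrix arithmetic and are omitted from this sketch.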
Example #2
File: trnval.py Project: DotWang/DFC2020
    def __init__(self, args, student_model, teacher_model, src_loader,
                 trg_loader, val_loader, optimizer, teacher_optimizer):

        self.args = args
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.src_loader = src_loader
        self.trg_loader = trg_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.teacher_optimizer = teacher_optimizer
        # Define Evaluator
        self.evaluator = Evaluator(args.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                          args.epochs, len(trn_loader))
        #self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[3, 6, 9, 12], gamma=0.5)
        # fine-tuning schedule
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones=[20],
                                                              gamma=0.5)
        self.best_pred = 0
        self.init_weight = 0.98
        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.evaluator = Evaluator(self.args.nclass)
Example #3
    def __init__(self, model_path, source, target, cuda=False):
        self.source_set = spacenet.Spacenet(city=source,
                                            split='test',
                                            img_root=config.img_root)
        self.target_set = spacenet.Spacenet(city=target,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)
        self.target_loader = DataLoader(self.target_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Example #4
    def __init__(self, model_path, config, cuda=False):
        self.target=config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset, split='test', img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set, batch_size=16, shuffle=False, num_workers=2)

        self.target_set = []
        self.target_loader = []
        # load other domains
        for city in self.target:
            tmp = spacenet.Spacenet(city=city, split='test', img_root=config.img_root)
            self.target_set.append(tmp)
            self.target_loader.append(DataLoader(tmp, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                backbone=config.backbone,
                output_stride=config.out_stride,
                sync_bn=config.sync_bn,
                freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Example #5
File: eval.py Project: jamycheung/ISSAFE
 def __init__(self, args, logger):
     self.args = args
     self.logger = logger
     self.time_train = []
     self.args.evaluate = True
     self.args.merge = True
     kwargs = {'num_workers': args.workers, 'pin_memory': False}
     _, self.val_loader, _, self.num_class = make_data_loader(
         args, **kwargs)
     print('num_classes: ' + str(self.num_class))
     self.resize = args.crop_size if args.crop_size else [512, 1024]
     self.evaluator = Evaluator(self.num_class, self.logger)
     self.model = EDCNet(self.args.rgb_dim,
                         args.event_dim,
                         num_classes=self.num_class,
                         use_bn=True)
     if args.cuda:
         self.model = torch.nn.DataParallel(self.model,
                                            device_ids=self.args.gpu_ids)
         self.model = self.model.to(self.args.device)
         cudnn.benchmark = True
     print('Model loaded successfully!')
     assert os.path.exists(
         args.weight_path), 'weight-path:{} doesn\'t exist!'.format(
             args.weight_path)
     self.new_state_dict = torch.load(args.weight_path,
                                      map_location='cuda:0')
     self.model = load_my_state_dict(self.model.module,
                                     self.new_state_dict['state_dict'])
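
Examples #1 and #5 both call load_my_state_dict, which never appears on this page. Below is a minimal sketch of the lenient partial-loading helper those calls imply; only the signature is taken from the examples, the body is an assumption (modeled on the ERFNet-style helper).

# Hypothetical body for load_my_state_dict: copy only the checkpoint
# entries whose name and shape match the model, so checkpoints with
# extra or missing keys still load.
def load_my_state_dict(model, state_dict):
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue  # key absent from the current model: skip it
        if own_state[name].shape != param.shape:
            continue  # shape mismatch (e.g. different num_classes): skip it
        own_state[name].copy_(param)
    return model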
Example #6
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.target = config.all_dataset
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            # restored from the commented-out lines so `test`/`train` are defined
            test_img_root = '/home/home1/swarnakr/main/DomainAdaptation/satellite/' + city + '/' + 'test'
            test = spacenet.Spacenet(city=city, split='test', img_root=test_img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train_img_root = '/home/home1/swarnakr/main/DomainAdaptation/satellite/' + city + '/' + 'train'
            train = spacenet.Spacenet(city=city, split='train', img_root=train_img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Example #7
def validation(epoch, model, args, criterion, nclass, test_tag=False):
    model.eval()

    losses = 0.0

    evaluator = Evaluator(nclass)
    evaluator.reset()
    if test_tag:
        num_img = args.data_dict['num_valid']
    else:
        num_img = args.data_dict['num_test']
    for i in range(num_img):
        if test_tag:
            inputs = torch.FloatTensor(args.data_dict['valid_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['valid_mask'][i]).cuda()
        else:
            inputs = torch.FloatTensor(args.data_dict['test_data'][i]).cuda()
            target = torch.FloatTensor(args.data_dict['test_mask'][i]).cuda()

        with torch.no_grad():
            output = model(inputs)
        loss_val = criterion(output, target)
        print('epoch: {0}\t'
              'iter: {1}/{2}\t'
              'loss: {loss:.4f}'.format(epoch + 1,
                                        i + 1,
                                        num_img,
                                        loss=loss_val))
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        evaluator.add_batch(target, pred)

        losses += loss_val.item()  # accumulate as a Python float

        if test_tag:
            # save input, target, and prediction as NIfTI volumes
            pred_save_dir = './pred/'
            sitk.WriteImage(sitk.GetImageFromArray(inputs.cpu().numpy()),
                            pred_save_dir + 'input_{}.nii.gz'.format(i))
            sitk.WriteImage(sitk.GetImageFromArray(target),
                            pred_save_dir + 'target_{}.nii.gz'.format(i))
            sitk.WriteImage(
                sitk.GetImageFromArray(pred),
                pred_save_dir + 'pred_{}_{}.nii.gz'.format(i, epoch))

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    if test_tag:
        print('Test:')
    else:
        print('Validation:')
    print('[Epoch: %d, numImages: %5d]' % (epoch, num_img))
    print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
        Acc, Acc_class, mIoU, FWIoU))
    print('Loss: %.3f' % losses)
Example #8
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define weight
        self.temporal_weight = args.temporal_weight
        self.spatial_weight = args.spatial_weight

        # Define network
        temporal_model = Model(name='vgg16_bn', num_classes=101,
                               is_flow=True).get_model()
        spatial_model = Model(name='vgg16_bn', num_classes=101,
                              is_flow=False).get_model()

        # Define Optimizer
        #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        temporal_optimizer = torch.optim.Adam(temporal_model.parameters(),
                                              lr=args.temporal_lr)
        spatial_optimizer = torch.optim.Adam(spatial_model.parameters(),
                                             lr=args.spatial_lr)

        # Define Criterion
        self.temporal_criterion = nn.BCELoss().cuda()
        self.spatial_criterion = nn.BCELoss().cuda()

        self.temporal_model, self.temporal_optimizer = temporal_model, temporal_optimizer
        self.spatial_model, self.spatial_optimizer = spatial_model, spatial_optimizer

        # Define Evaluator
        self.top1_eval = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.temporal_model = torch.nn.DataParallel(
                self.temporal_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.temporal_model)
            self.temporal_model = self.temporal_model.cuda()

            self.spatial_model = torch.nn.DataParallel(
                self.spatial_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.spatial_model)
            self.spatial_model = self.spatial_model.cuda()

        # Resuming checkpoint
        self.best_accuracy = 0.0
Example #9
File: train.py Project: jamycheung/ISSAFE
    def __init__(self, args):
        self.args = args
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.logger = self.saver.create_logger()

        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.model = EDCNet(args.rgb_dim, args.event_dim, num_classes=self.nclass, use_bn=True)
        train_params = [{'params': self.model.random_init_params(),
                         'lr': 10*args.lr, 'weight_decay': 10*args.weight_decay},
                        {'params': self.model.fine_tune_params(),
                         'lr': args.lr, 'weight_decay': args.weight_decay}]
        self.optimizer = torch.optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay)
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.to(self.args.device)
        if args.use_balanced_weights:
            root_dir = Path.db_root_dir(args.dataset)[0] if isinstance(Path.db_root_dir(args.dataset), list) else Path.db_root_dir(args.dataset)
            classes_weights_path = os.path.join(root_dir,
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass, classes_weights_path)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion_event = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='event')
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader), warmup_epochs=5)

        self.evaluator = Evaluator(self.nclass, self.logger)
        self.saver.save_model_summary(self.model)
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0
Example #10
    def __init__(self, backbone, num_classes, train_data_loader, test_data_loader, use_cuda , model):

        self.train_data_loader = train_data_loader
        self.test_data_loader = test_data_loader

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
        self.prunner = FilterPrunner(self.model,use_cuda)
        self.model.train()
        self.evaluator = Evaluator(num_classes)
        self.use_cuda = use_cuda
Example #11
 def define_evaluators(self):
     if self.args.dataset == "isi_intensity":
         weights = np.array([1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1]) > 0.5
     elif (self.args.dataset == "isi_rgb"
           or self.args.dataset == "isi_rgb_temporal"
           or self.args.dataset == 'isi_multiview' or self.args.dataset
           == 'isi_multiview_2018') and self.args.skip_classes is not None:
         weights = self.args.skip_weights > 0.5
     else:
         weights = None
     self.evaluator = Evaluator(self.nclass, weights)
Example #12
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # PATH = args.path
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass,backbone=args.backbone,output_stride=args.out_stride,cuda = args.cuda)

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(),args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
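
Several examples construct LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(train_loader)) and later call it once per iteration as self.scheduler(self.optimizer, i, epoch, self.best_pred). A sketch of the 'poly' mode that call pattern implies is below; the decay rule is an assumption based on the usual DeepLab schedule, and the 'cos'/'step' modes some examples select are omitted.

# Sketch of the iteration-wise LR_Scheduler assumed above ('poly' mode only).
class LR_Scheduler:
    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch, warmup_epochs=0):
        assert mode == 'poly', 'this sketch only covers the poly schedule'
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch
        self.warmup_iters = warmup_epochs * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
        t = epoch * self.iters_per_epoch + i
        if self.warmup_iters > 0 and t < self.warmup_iters:
            lr = self.base_lr * t / self.warmup_iters              # linear warmup
        else:
            lr = self.base_lr * (1 - t / self.total_iters) ** 0.9  # poly decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr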
Example #13
 def __init__(self, network, train_dataloader, eval_dataloader, criterion,
              optimizer, visualizer, experiment_name, config):
     self.config = config
     self.network = network
     self.train_dataloader = train_dataloader
     self.eval_dataloader = eval_dataloader
     self.criterion = criterion
     self.optimizer = optimizer
     self.visualizer = visualizer
     self.experiment_name = experiment_name
     self.evaluator = Evaluator(config['n_classes'])
Example #14
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1])
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': model.get_10x_lr_params(), 'lr': config.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.T, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            # map_location is an argument of torch.load, not load_state_dict
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))
Example #15
File: utils.py Project: wtsitp/DFQ
def forward_all(net_inference, dataloader, visualize=False, opt=None):
    evaluator = Evaluator(21)
    evaluator.reset()
    with torch.no_grad():
        for ii, sample in enumerate(dataloader):
            image, label = sample['image'].cuda(), sample['label'].cuda()

            activations = net_inference(image)

            image = image.cpu().numpy()
            label = label.cpu().numpy().astype(np.uint8)

            logits = activations if isinstance(activations, torch.Tensor) \
                else activations[list(activations.keys())[-1]]
            pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8)

            evaluator.add_batch(label, pred)

            # print(label.shape, pred.shape)
            if visualize:
                for jj in range(sample["image"].size()[0]):
                    segmap_label = decode_segmap(label[jj], dataset='pascal')
                    segmap_pred = decode_segmap(pred[jj], dataset='pascal')

                    img_tmp = np.transpose(image[jj], axes=[1, 2, 0])
                    img_tmp *= (0.229, 0.224, 0.225)
                    img_tmp += (0.485, 0.456, 0.406)
                    img_tmp *= 255.0
                    img_tmp = img_tmp.astype(np.uint8)

                    cv2.imshow('image', img_tmp[:, :, [2, 1, 0]])
                    cv2.imshow('gt', segmap_label)
                    cv2.imshow('pred', segmap_pred)
                    cv2.waitKey(0)

    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    print("Acc: {}".format(Acc))
    print("Acc_class: {}".format(Acc_class))
    print("mIoU: {}".format(mIoU))
    print("FWIoU: {}".format(FWIoU))
    if opt is not None:
        with open("seg_result.txt", 'a+') as ww:
            ww.write(
                "{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n"
                .format(opt.dataset, opt.quantize, opt.relu, opt.equalize,
                        opt.absorption, opt.correction, opt.clip_weight,
                        opt.distill_range))
            ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(
                Acc, Acc_class, mIoU, FWIoU))
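
forward_all only needs a model and a dataloader whose samples are {'image': ..., 'label': ...} dicts. A hypothetical invocation follows; build_deeplab and make_voc_loader are placeholder names standing in for whatever builders the surrounding project provides, not APIs shown on this page.

# Hypothetical driver for forward_all (placeholder builders).
net = build_deeplab(num_classes=21).cuda().eval()        # placeholder model builder
val_loader = make_voc_loader(split='val', batch_size=8)  # yields {'image', 'label'} dicts
forward_all(net, val_loader, visualize=False)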
Example #16
    def __init__(self, args):
        self.args = args

        self.saver = Saver(args)
        self.saver.save_experiment_config()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # self.model = OCRNet(self.nclass)
        self.model = build_model(2, [32, 32], '44330020')
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)
        if args.use_balanced_weights:
            weight = torch.tensor([0.2, 0.8], dtype=torch.float32)
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        if args.cuda:
            self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the model above is moved to CUDA but never wrapped in
            # DataParallel, so load into it directly (no .module)
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0
Example #17
    def __init__(self, args):
        self.args = args
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_dataloader(
            args.dataset, args.base_size, args.crop_size, args.batch_size,
            args.overfit, **kwargs)
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.evaluator = Evaluator(self.nclass)

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(constants.DATASET_ROOT,
                                                args.dataset,
                                                'class_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weights_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        checkpoint = torch.load(args.resume)
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        print(
            f'=> loaded checkpoint {args.resume} (epoch {checkpoint["epoch"]})'
        )

        self.visualizations_folder = os.path.join(
            os.path.dirname(os.path.realpath(args.resume)),
            constants.VISUALIZATIONS_FOLDER)
        if not os.path.exists(self.visualizations_folder):
            os.makedirs(self.visualizations_folder)
Example #18
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Dataloader
        if args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2,4,23,3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # Adam takes no momentum argument; use the scaled learning rate
            optimizer = torch.optim.Adam(fpn.parameters(), lr=self.lr,
                                         weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr, momentum=0,
                                        weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()


        # Resuming checkpoint
        self.best_pred = 0.0
        self.lr_stage = [68, 93]
        self.lr_staget_ind = 0 
Example #19
def evals(arch='res18'):
    """
    class IoU & mIoU, Acc & mAcc
    """
    trainset, valset, testset = build_datasets(dataset='SUNRGBD',
                                               base_size=512,
                                               crop_size=512)

    # load model
    if arch == 'res18':
        model = BiSeNet(37, context_path='resnet18', in_planes=32)
        load_state_dict(
            model,
            ckpt_path=
            'runs/SUNRGBD/kd_pi_lr1e-3_Jul28_002404/checkpoint.pth.tar')
    elif arch == 'res101':
        model = BiSeNet(37, context_path='resnet101', in_planes=64)
        load_state_dict(
            model,
            ckpt_path=
            'runs/SUNRGBD/res101_inp64_deconv_Jul26_205859/checkpoint.pth.tar')
    else:
        raise NotImplementedError

    model.eval()
    model.cuda()

    evaluator = Evaluator(testset.num_classes)
    evaluator.reset()

    print('imgs:', len(testset))
    for sample in tqdm(testset):  # samples are already transformed
        image, target = sample['img'], sample['target']
        image = image.unsqueeze(0).cuda()
        pred = model(image)
        pred = F.interpolate(pred,
                             size=(512, 512),
                             mode='bilinear',
                             align_corners=True)
        pred = torch.argmax(pred, dim=1).squeeze().cpu().numpy()
        target = target.numpy()
        evaluator.add_batch(target, pred)

    print('PixelAcc:', evaluator.Pixel_Accuracy())

    print('mAcc')  # mean accuracy over classes
    Accs = evaluator.Acc_of_each_class()
    print(np.nanmean(Accs))  # mAcc, mean of non-NaN elements
    approx_print(Accs)

    print('mIoU')
    IOUs = evaluator.IOU_of_each_class()
    print(np.nanmean(IOUs))  # mIoU
    approx_print(IOUs)
Example #20
    def __init__(self, model_path, config, bn, save_path, save_batch,
                 sample_number, trial=100, cuda=False):
        self.bn = bn
        self.target=config.all_dataset
        self.target.remove(config.dataset)
        self.sample_number = sample_number
        # load source domain
        #self.source_set = spacenet.Spacenet(city=config.dataset, split='test', img_root=config.img_root, needs to be changed)
        #self.source_loader = DataLoader(self.source_set, batch_size=16, shuffle=False, num_workers=2)
        self.source_loader = None
        self.save_path = save_path
        self.save_batch = save_batch
        self.trial = trial
        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city, split='val', img_root=config.img_root, gt_root = config.gt_root, mean_std=config.mean_std, if_augment=config.if_augment, repeat_count=config.repeat_count)
            self.target_set.append(test)
            self.target_loader.append(DataLoader(test, batch_size=16, shuffle=False, num_workers=2))

            train = spacenet.Spacenet(city=city, split='train', img_root=config.img_root, gt_root = config.gt_root, mean_std=config.mean_std, if_augment=config.if_augment, repeat_count=config.repeat_count, sample_number= sample_number)
            self.target_trainset.append(train)
            self.target_trainloader.append(DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

            
        self.model = DeepLab(num_classes=2,
                backbone=config.backbone,
                output_stride=config.out_stride,
                sync_bn=config.sync_bn,
                freeze_bn=config.freeze_bn)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        #print(self.checkpoint.keys())
        #self.model.load_state_dict(self.checkpoint)
        self.model.load_state_dict(self.checkpoint['model'])

        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Example #21
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        val_set = pascal.VOCSegmentation(args, split='val')
        self.nclass = val_set.NUM_CLASSES
        self.val_loader = DataLoader(val_set,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     **kwargs)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            print('device_ids', self.args.gpu_ids)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
Example #22
def train(train_queue, valid_queue, model, architect, criterion, optimizer,
          lr):
    objs = utils.AvgrageMeter()  # running average of the loss
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    for step, (input, target) in enumerate(train_queue):
        # each step draws one batch; batch size 64 (256 data pairs)
        model.train()
        n = input.size(0)

        input = input.to(device)
        target = target.to(device)

        # get a random minibatch from the search queue with replacement
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.to(device)
        target_search = target_search.to(device)

        architect.step(input,
                       target,
                       input_search,
                       target_search,
                       lr,
                       optimizer,
                       unrolled=args.unrolled)

        optimizer.zero_grad()
        logits = model(input)
        logits = logits.to(device)
        loss = criterion(logits, target)
        evaluator = Evaluator(dataset_classes)  # fresh evaluator per batch
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        #prec = utils.Accuracy(logits, target)
        #prec1 = utils.MIoU(logits, target, dataset_classes)
        # add_batch expects hard label maps, not raw logits
        pred = torch.argmax(logits, dim=1).cpu().numpy()
        evaluator.add_batch(target.cpu().numpy(), pred)
        miou = evaluator.Mean_Intersection_over_Union()
        fscore = evaluator.Fx_Score()
        acc = evaluator.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
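
utils.AvgrageMeter (spelled this way in the DARTS codebase this function comes from) is a simple running-average tracker; a minimal sketch of the interface used above:

# Running-average meter matching the objs/accs/MIoUs/fscores usage above.
class AvgrageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0.0
        self.sum = 0.0
        self.cnt = 0

    def update(self, val, n=1):
        # val is a per-batch average; n is the batch size
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt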
Example #23
File: val_best.py Project: bmhopkinson/CSN
class Trainer(object):
    def __init__(self,args):
        warnings.filterwarnings('ignore')
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(args.backbone, args.dataset, args.exp)
        if args.dataset == 'pascal':
            raise NotImplementedError
        elif args.dataset == 'cityscapes':
            kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
            dataset_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
            args.num_classes = num_classes
        elif args.dataset == 'marsh' :
            kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
            dataset_loader,val_loader, test_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
            args.num_classes = num_classes
        else:
            raise ValueError('Unknown dataset: {}'.format(args.dataset))

        if args.backbone == 'autodeeplab':
            model = Retrain_Autodeeplab(args)
            model.load_state_dict(torch.load(r"./run/marsh/deeplab-autodeeplab/model_best.pth.tar")['state_dict'], strict=False)
        else:
            raise ValueError('Unknown backbone: {}'.format(args.backbone))

        # the DataParallel wrap happens below, so use model.parameters() here
        optimizer = optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9,
                              weight_decay=0.0001)

        if args.criterion == 'Ohem':
            args.thresh = 0.7
            args.crop_size = [args.crop_size, args.crop_size] if isinstance(args.crop_size, int) else args.crop_size
            args.n_min = int((args.batch_size / len(args.gpu) * args.crop_size[0] * args.crop_size[1]) // 16)
        criterion = build_criterion(args)

        model = nn.DataParallel(model).cuda()
        ## merge
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        #kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataset_loader,val_loader, test_loader, num_classes

        self.criterion = criterion
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = scheduler
        self.scheduler = LR_Scheduler("poly", args.lr, args.epochs, len(self.train_loader))  # removed None second argument
Example #24
    def __init__(self, args):
        self.args = args
        self.args.batchnorm_function = torch.nn.BatchNorm2d
        # Define Dataloader
        self.nclass = self.args.num_classes
        # Define network
        model = generate_net(self.args)

        self.model = model
        self.evaluator = Evaluator(self.nclass)
        self.criterion = SegmentationLosses(cuda=True).build_loss(mode='ce')
        # Using cuda
        if self.args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        _, _, _ = load_pretrained_mode(self.model,
                                       checkpoint_path=self.args.resume)
Example #25
    def __init__(self, args):
        self.args = args

        # configure datasetpath
        self.baseroot = None
        if args.dataset == 'pascal':
            self.baseroot = '/path/to/your/VOCdevkit/VOC2012/'
        ''' not supported:
        # if you want to train on these datasets,
        # modify this section accordingly; refer to /dataloader/datasets/pascal
        # and implement the corresponding dataset constructor

        elif args.dataset == 'cityscapes':
            self.baseroot = '/path/to/your/cityscapes/'
        elif args.dataset == 'sbd':
            self.baseroot = '/path/to/your/sbd/'
        elif args.dataset == 'coco':
            self.baseroot = '/path/to/your/coco/'
        '''

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.test_loader, self.nclass = make_data_loader(
            self.baseroot, args, **kwargs)

        #define net model
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=False,
                             freeze_bn=False).cuda()

        # self.model.module.load_state_dict(torch.load('./model_best.pth.tar', map_location='cpu'))
        self.evaluator = Evaluator(self.nclass)

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])

        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
Example #26
    def __init__(self,args):
        self.args = args
        self.nclass  = 4
        self.save_fold = 'brain_re/brain_cedice'
        mkdir(self.save_fold)
        self.name = self.save_fold.split('/')[-1].split('_')[-1]
        #===for brain==========================
        # self.nclass = 4
        # self.save_fold = 'brain_re'
        #======================================
        net = segModel(self.args,self.nclass)
        net.build_model()
        model = net.model
        #load params
        resume = args.resume
        self.model = torch.nn.DataParallel(model)
        self.model = self.model.cuda()
        print('==>Load model...')
        if resume is not None:
            checkpoint = torch.load(resume)
            # model.load_state_dict(checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
        self.model = model
        print('==> loading loss function...')
        self.criterion = SegmentationLosses(cuda=args.cuda).build_loss(mode=args.loss_type)

        #define evaluator
        self.evaluator = Evaluator(self.nclass)

        #get data path
        root_path = Path.db_root_dir(self.args.dataset)
        if self.args.dataset == 'drive':
            folder = 'test'
            self.test_img = os.path.join(root_path, folder, 'images')
            self.test_label = os.path.join(root_path, folder, '1st_manual')
            self.test_mask = os.path.join(root_path, folder, 'mask')
        elif self.args.dataset == 'brain':
            path = root_path+'/Bra-pickle'
            valid_path = '../data/Brain/test.csv'
            self.valid_set = get_dataset(path,valid_path)
        print('loading test data...')

        #define data
        self.test_loader = None
Example #27
    def __init__(self, args):
        self.args = args

        # Define network
        model = DeepLab(num_classes=32,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        #         self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(32)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        time_start = time.time()
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #28
    def __init__(self, args):
        if not os.path.isfile(args.model):
            raise RuntimeError("no checkpoint found at '{}'".format(args.model))
        self.args = args
        self.color_map = get_pascal_labels()
        self.test_loader, self.ids, self.nclass = make_data_loader(args)

        #Define model
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False)
        
        self.model = model
        device = torch.device('cpu')
        checkpoint = torch.load(args.model, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.evaluator = Evaluator(self.nclass)
Example #29
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        self.bn = bn
        self.city = config.dataset  #all_dataset
        self.save_path = save_path
        self.save_batch = save_batch

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        if True:  # originally: for city in self.target:
            train = spacenet.Spacenet(city=self.city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

        #if DA images
        self.checkpoint = torch.load(model_path)
        #'./train_log/' + self.config.dataset + '_da_' + city + '.pth')
        self.model.load_state_dict(self.checkpoint)
        if self.cuda:
            self.model = self.model.cuda()
Example #30
    def __init__(self, config,  checkpoint_path='./snapshots/checkpoint_best.pth.tar'):
        self.config = config
        self.checkpoint_path = checkpoint_path

#        with open(self.config_file_path) as f:

        self.categories_dict = {"background": 0, "short_sleeve_top": 1, "long_sleeve_top": 2, "short_sleeve_outwear": 3,
                "long_sleeve_outwear": 4, "vest": 5, "sling": 6, "shorts": 7, "trousers": 8,
                "skirt": 9,  "short_sleeve_dress": 10, "long_sleeve_dress": 11,
                "vest_dress": 12, "sling_dress": 13}

#        self.categories_dict = {"background": 0, "meningioma": 1, "glioma": 2, "pituitary": 3}
        self.categories_dict_rev = {v: k for k, v in self.categories_dict.items()}
        
        self.model = self.load_model()
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)

        self.num_classes = self.config['network']['num_classes']
        self.evaluator = Evaluator(self.num_classes)
        self.criterion = SegmentationLosses(weight=None, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
Example #31
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #32
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
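
For completeness, this is roughly how a Trainer like the one in Example #32 is driven. parse_args is a hypothetical stand-in for the repository's argparse setup; the attribute names (start_epoch, epochs, no_val) are taken from the code above.

# Hypothetical entry point for the Trainer in Example #32.
if __name__ == '__main__':
    args = parse_args()  # hypothetical stand-in for the repo's CLI setup
    trainer = Trainer(args)
    start = getattr(args, 'start_epoch', 0)  # set by the resume/ft branches above
    for epoch in range(start, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()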