Example #1
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader1,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(21, 12, self.criterion)  # hard-coded: 21 classes, 12 layers

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader1))

        # DARTS-style bi-level search: the Architect updates the architecture
        # parameters on batches drawn from the second (search) train loader
        self.architect = Architect(self.model, args)
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader1)
        num_img_tr = len(self.train_loader1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            # draw a fresh batch from the search split for the architecture step
            search = next(iter(self.train_loader2))
            image_search, target_search = search['image'], search['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                image_search, target_search = image_search.cuda(
                ), target_search.cuda()

            self.architect.step(image_search, target_search)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:  # avoid ZeroDivisionError on tiny loaders
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
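A minimal driver for the search trainer above might look like the following sketch; the `args` fields it reads (`start_epoch`, `epochs`, `no_val`) are the ones the Trainer itself uses, but the surrounding argparse setup is assumed rather than taken from the original source.

# Hypothetical usage sketch -- not part of the original example.
def main(args):
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)        # weight step + architecture step per batch
        if not args.no_val:
            trainer.validation(epoch)  # updates trainer.best_pred on a new best mIoU
    trainer.writer.close()             # flush TensorBoard logs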
Example #2
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        if self.args.norm == 'gn':
            norm = gn
        elif self.args.norm == 'bn':
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == 'abn':
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            raise ValueError("Unsupported norm: '{}'".format(self.args.norm))

        # Define network
        if self.args.model == 'deeplabv3+':
            model = DeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        elif self.args.model == 'deeplabv3':
            model = DeepLabv3(Norm=self.args.norm,
                              backbone=args.backbone,
                              output_stride=args.out_stride,
                              num_classes=self.nclass,
                              freeze_bn=args.freeze_bn)
        elif self.args.model == 'fpn':
            model = FPN(args=args,
                        num_classes=self.nclass)
        else:
            raise ValueError("Unknown model: {}".format(self.args.model))

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.ft:
                args.start_epoch = 0
            else:
                args.start_epoch = checkpoint['epoch']
            
            if args.cuda:
                # filtered load: keep only checkpoint keys present in the model
                pretrained_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                # same filtered load for the non-DataParallel model
                pretrained_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        elif args.decoder is not None:
            if not os.path.isfile(args.decoder):
                raise RuntimeError("=> no checkpoint for decoder found at '{}'" .format(args.decoder))
            checkpoint = torch.load(args.decoder)
            args.start_epoch = 0  # loading only the decoder implies fine-tuning, so restart from epoch 0
            if args.cuda:
                decoder_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in decoder_dict.items():
                    if 'aspp' not in k:  # keep only the ASPP/decoder weights
                        continue
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                raise NotImplementedError("Decoder-only loading is implemented for CUDA only.")
            

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            # per-iteration logging and image visualization are disabled in
            # this example; only the per-epoch loss is written below

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch, inference=False):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
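The two resume branches above repeat the same key-filtered state_dict load; a small helper capturing that pattern could look like this sketch (the helper name is illustrative, not from the original code):

# Hypothetical helper factoring out the filtered load used above.
def load_matching_weights(model, pretrained_dict):
    state_dict = model.state_dict()
    # keep only checkpoint entries whose keys exist in the current model
    matched = {k: v for k, v in pretrained_dict.items() if k in state_dict}
    state_dict.update(matched)
    model.load_state_dict(state_dict)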
Example #3
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # init D (discriminator); num_classes=19 is hard-coded (Cityscapes)
        model_D = FCDiscriminator(num_classes=19)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                                       lr=1e-4,
                                       betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                'dataloders', 'datasets',
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        self.model, self.optimizer = model, optimizer
        self.model_D, self.optimizer_D = model_D, optimizer_D

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model_D = torch.nn.DataParallel(self.model_D,
                                                 device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            patch_replication_callback(self.model_D)
            self.model = self.model.cuda()
            self.model_D = self.model_D.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        # labels for adversarial training
        source_label = 0
        target_label = 1

        loss_seg_value = 0.0
        loss_adv_target_value = 0.0
        loss_D_value = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            src_image, src_label, tgt_image = sample['src_image'], sample[
                'src_label'], sample['tgt_image']
            if self.args.cuda:
                src_image, src_label, tgt_image = src_image.cuda(
                ), src_label.cuda(), tgt_image.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            self.scheduler(self.optimizer_D, i, epoch, self.best_pred)
            self.optimizer_D.zero_grad()

            ## train G

            # don't accumulate grads in D
            for param in self.model_D.parameters():
                param.requires_grad = False

            # train with source
            src_output = self.model(src_image)
            loss_seg = self.criterion(src_output, src_label)
            loss_seg.backward()
            loss_seg_value += loss_seg.item()

            # train with target
            tgt_output = self.model(tgt_image)
            D_out = self.model_D(F.softmax(tgt_output, dim=1))  # softmax over the class dim

            loss_adv_target = self.bce_loss(
                D_out, torch.full_like(D_out, source_label))
            loss_adv_target.backward()
            loss_adv_target_value += loss_adv_target.item()

            ## train D

            # bring back requires_grad
            for param in self.model_D.parameters():
                param.requires_grad = True

            # train with source
            src_output = src_output.detach()

            D_out = self.model_D(F.softmax(src_output, dim=1))

            loss_D = self.bce_loss(
                D_out, torch.full_like(D_out, source_label))
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with target
            tgt_output = tgt_output.detach()
            D_out = self.model_D(F.softmax(tgt_output, dim=1))

            loss_D = self.bce_loss(
                D_out, torch.full_like(D_out, target_label))
            loss_D.backward()
            loss_D_value += loss_D.item()

            self.optimizer.step()
            self.optimizer_D.step()

            tbar.set_description(
                'Seg_loss: %.3f d_loss: %.3f d_inv_loss: %.3f' %
                (loss_seg_value / (i + 1), loss_adv_target_value /
                 (i + 1), loss_D_value / (i + 1)))

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:  # avoid ZeroDivisionError on tiny loaders
                global_step = i + num_img_tr * epoch
                image = torch.cat([src_image, tgt_image], dim=0)
                output = torch.cat([src_output, tgt_output], dim=0)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, src_label, output,
                                             global_step)

        self.writer.add_scalar('train/Seg_loss', loss_seg_value, epoch)
        self.writer.add_scalar('train/d_loss', loss_adv_target_value, epoch)
        self.writer.add_scalar('train/d_inv_loss', loss_D_value, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' %
              (loss_seg_value + loss_adv_target_value + loss_D_value))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, _ = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
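The generator/discriminator alternation in training() above follows the usual freeze-then-unfreeze recipe; a condensed, self-contained sketch of one step (illustrative names, NCHW segmentation logits; the segmentation loss and scheduler/zero_grad bookkeeping are omitted) is:

# Condensed sketch of the G/D alternation -- not the original code.
import torch
import torch.nn.functional as F

def adversarial_step(model_D, bce_loss, src_out, tgt_out, opt, opt_D):
    # G step: ask D to classify target predictions as "source" (label 0)
    for p in model_D.parameters():
        p.requires_grad = False
    d_tgt = model_D(F.softmax(tgt_out, dim=1))
    bce_loss(d_tgt, torch.zeros_like(d_tgt)).backward()

    # D step: detach generator outputs; source -> 0, target -> 1
    for p in model_D.parameters():
        p.requires_grad = True
    d_src = model_D(F.softmax(src_out.detach(), dim=1))
    d_tgt = model_D(F.softmax(tgt_out.detach(), dim=1))
    (bce_loss(d_src, torch.zeros_like(d_src)) +
     bce_loss(d_tgt, torch.ones_like(d_tgt))).backward()

    opt.step()      # updates the segmentation network (G)
    opt_D.step()    # updates the discriminator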
Example #4
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # both groups use the same lr here (the usual 10x head multiplier is disabled)
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr}]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(train_params, lr=args.lr,
                                     weight_decay=args.weight_decay, amsgrad=True)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # this variant deliberately does not restore the epoch counter
            #args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)       
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:  # avoid ZeroDivisionError on tiny loaders
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
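This variant swaps the repo's usual SGD schedule for Adam and flattens the 1x/10x learning-rate split. If one wanted to keep differential rates under Adam, per-group 'lr' entries override the optimizer-level default, as in this sketch (a hypothetical alternative, not the original configuration):

# Hypothetical: differential learning rates with Adam instead of SGD
# (model/args as defined in the trainer above).
train_params = [
    {'params': model.get_1x_lr_params(), 'lr': args.lr},        # backbone
    {'params': model.get_10x_lr_params(), 'lr': args.lr * 10},  # head/decoder
]
optimizer = torch.optim.Adam(train_params, lr=args.lr,
                             weight_decay=args.weight_decay, amsgrad=True)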
Example #5
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = MyDeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        freeze_bn=args.freeze_bn)
        self.model = model
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        #optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Adam with library-default lr (1e-3); note that train_params above is
        # built but not passed here, so the 1x/10x groups go unused
        optimizer = torch.optim.Adam(self.model.parameters(), betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=0, amsgrad=False)

        # hand-tuned per-class loss weights for 8 classes (class 0 -> 1, rest -> 10)
        weight = [1, 10, 10, 10, 10, 10, 10, 10]
        weight = torch.tensor(weight, dtype=torch.float)
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda, num_classes=self.nclass).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()
        
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        '''
        # Get the layer names of the current model
        layer_name = list(self.model.state_dict().keys())
        #print(self.model.state_dict()[layer_name[3]])
        # Load a generic pretrained model
        pretrained = './pretrained_model/deeplab-mobilenet.pth.tar'
        pre_ckpt = torch.load(pretrained)
        key_name = list(checkpoint['state_dict'].keys())  # layer names of the pretrained model
        # The class count differs, so the last two layers are copied over separately
        pre_ckpt['state_dict'][key_name[-2]] = checkpoint['state_dict'][key_name[-2]]
        pre_ckpt['state_dict'][key_name[-1]] = checkpoint['state_dict'][key_name[-1]]
        self.model.module.load_state_dict(pre_ckpt['state_dict'])     # , strict=False)
        #print(self.model.state_dict()[layer_name[3]])
        print("Pretrained model loaded OK")
        '''
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # image visualization is disabled in this example

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        filename='checkpoint_{}_{:.4f}.pth.tar'.format(epoch, train_loss)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, filename=filename)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
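Example #5's main twist is the hard-coded class weighting. A standalone equivalent with plain torch.nn.CrossEntropyLoss, shown here as a sketch (ignore_index=255 is an assumption carried over from the usual defaults in this codebase family), makes the weighting explicit:

# Hypothetical standalone equivalent of the weighted criterion above.
import torch
import torch.nn as nn

# class 0 down-weighted 10x relative to the 7 foreground classes
weight = torch.tensor([1., 10., 10., 10., 10., 10., 10., 10.])
criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)  # assumed ignore label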
Example #6
class MyTrainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}

        if args.dataset == "fashion_clothes":
            train_set = fashion.FashionDataset(
                args, Path.db_root_dir("fashion_clothes"), mode='train')
            val_set = fashion.FashionDataset(
                args, Path.db_root_dir("fashion_clothes"), mode='test')
            self.nclass = train_set.nclass

            print("Train size {}, val size {}".format(len(train_set),
                                                      len(val_set)))

            self.train_loader = DataLoader(dataset=train_set,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)
            self.val_loader = DataLoader(dataset=val_set,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         **kwargs)
            self.test_loader = None

            assert self.nclass == 7
        else:
            raise NotImplementedError(
                "Only the 'fashion_clothes' dataset is wired up in this trainer.")

        self.best_pred = 0.0

        if args.model == 'deeplabv3+':
            model = DeepLab(backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

            if args.resume is not None:
                if not os.path.isfile(args.resume):
                    raise RuntimeError("=> no checkpoint found at '{}'".format(
                        args.resume))
                if args.cuda:
                    checkpoint = torch.load(args.resume)
                else:
                    checkpoint = torch.load(args.resume, map_location='cpu')
                args.start_epoch = checkpoint['epoch']

                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))

            # Freeze the backbone
            if args.freeze_backbone:
                set_parameter_requires_grad(model.backbone, False)

            ###### NEW DECODER ######
            # Different fine-tuning modes
            if args.ft_type == 'decoder':
                set_parameter_requires_grad(model, False)
                model.decoder = build_decoder(self.nclass, 'resnet',
                                              nn.BatchNorm2d)
                train_params = [{
                    'params': model.get_1x_lr_params(),
                    'lr': args.lr
                }, {
                    'params': model.get_10x_lr_params(),
                    'lr': args.lr * 10
                }]

            elif args.ft_type == 'last_layer':
                set_parameter_requires_grad(model, False)
                model.decoder.last_conv[8] = nn.Conv2d(
                    in_channels=256, out_channels=self.nclass, kernel_size=1)
                model.decoder.last_conv[8].reset_parameters()
                train_params = [{
                    'params': model.get_1x_lr_params(),
                    'lr': args.lr
                }, {
                    'params': model.get_10x_lr_params(),
                    'lr': args.lr * 10
                }]
            elif args.ft_type == 'all':
                # reset the last layer so the output matches our class count
                model.decoder.last_conv[8] = nn.Conv2d(
                    in_channels=256, out_channels=self.nclass, kernel_size=1)
                model.decoder.last_conv[8].reset_parameters()

                train_params = [{
                    'params': model.get_1x_lr_params(),
                    'lr': args.lr
                }, {
                    'params': model.get_10x_lr_params(),
                    'lr': args.lr * 10
                }]
            else:
                raise ValueError("Unknown ft_type: {}".format(args.ft_type))

            # Define Optimizer
            optimizer = torch.optim.SGD(train_params,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            weight = calculate_weigths_labels(args.dataset, self.train_loader,
                                              self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
            print("weight is {}".format(weight))
        else:
            weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # always fine-tune from epoch 0, even when weights were resumed above
        args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:  # avoid ZeroDivisionError on tiny loaders
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def visualize_validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            # image, target, output are on the GPU at this point

            self.summary.visualize_pregt(self.writer, self.args.dataset, image,
                                         target, output, i)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Visualizing:')
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Final Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

    def output_validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            # image, target, output are on the GPU at this point

            # save the predicted segmentation to disk
            self.summary.save_pred(self.args.dataset, output, i)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Saving predictions:')
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Final Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

    def train_loop(self):
        try:
            for epoch in range(self.args.start_epoch, self.args.epochs):
                self.training(epoch)
                if not self.args.no_val and epoch % self.args.eval_interval == (
                        self.args.eval_interval - 1):
                    self.validation(epoch)
        except KeyboardInterrupt:
            print('Early Stopping')
        finally:
            self.visualize_validation()
            self.writer.close()
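
A minimal driver sketch for this search Trainer; the `main` helper below is hypothetical, and it assumes an `args` namespace carrying the fields referenced above (`start_epoch`, `epochs`, `no_val`, `eval_interval`, `resume`, ...):

def main(args):
    # run search training with periodic validation; on completion (or Ctrl-C),
    # train_loop() visualizes the final validation predictions and closes the writer
    trainer = Trainer(args)
    trainer.train_loop()
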
Example #7
class Evaluation(object):
    def __init__(self, args):

        self.args = args

        # Define Dataloader
        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.network == 'searched_dense':
            """ 40_5e_lr_38_31.91  """
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab',
                                     'genotype.npy')
            cell_arch = np.load(cell_path)
            network_arch = [0, 1, 2, 3, 2, 2, 2, 2, 1, 2, 3, 2]
            low_level_layer = 0

            model = Model_2(network_arch, cell_arch, self.nclass, args,
                            low_level_layer)

        elif args.network == 'searched_baseline':
            cell_path = os.path.join(args.saved_arch_path, 'searched_baseline',
                                     'genotype.npy')
            cell_arch = np.load(cell_path)
            network_arch = [0, 1, 2, 2, 3, 2, 2, 1, 2, 1, 1, 2]
            low_level_layer = 1
            model = Model_2_baseline(network_arch, cell_arch, self.nclass,
                                     args, low_level_layer)

        elif args.network.startswith('autodeeplab'):
            network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab',
                                     'genotype.npy')
            cell_arch = np.load(cell_path)
            low_level_layer = 2

            if args.network == 'autodeeplab-dense':
                model = Model_2(network_arch, cell_arch, self.nclass, args,
                                low_level_layer)

            elif args.network == 'autodeeplab-baseline':
                model = Model_2_baseline(network_arch, cell_arch, self.nclass,
                                         args, low_level_layer)

        else:
            raise ValueError('unknown network: {}'.format(args.network))

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=255).cuda()
        self.model = model

        # Define Evaluator
        self.evaluator_1 = Evaluator(self.nclass)
        self.evaluator_2 = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # strip the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

    def validation(self):
        self.model.eval()
        self.evaluator_1.reset()
        self.evaluator_2.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        time_meter_1 = AverageMeter()
        time_meter_2 = AverageMeter()

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output_1, output_2 = self.model(image)

            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            test_loss += loss_1.item() + loss_2.item()

            pred_1 = torch.argmax(output_1, dim=1).cpu().numpy()
            pred_2 = torch.argmax(output_2, dim=1).cpu().numpy()
            target_np = target.cpu().numpy()

            # Add batch sample into evaluator
            self.evaluator_1.add_batch(target_np, pred_1)
            self.evaluator_2.add_batch(target_np, pred_2)

        mIoU_1 = self.evaluator_1.Mean_Intersection_over_Union()
        mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union()

        print('Validation:')
        print("mIoU_1:{}, mIoU_2: {}".format(mIoU_1, mIoU_2))

    def testing_entropy(self):
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        """ Define Tensorboard Summary """
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.model.eval()
        self.evaluator_1.reset()
        self.evaluator_2.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output_1, avg_confidence, max_confidence = self.model.forward_testing_entropy(
                    image)

            loss_1 = self.criterion(output_1, target)

            entropy = normalized_shannon_entropy(output_1)

            self.writer.add_scalar('avg_confidence/i', avg_confidence.item(),
                                   i)
            self.writer.add_scalar('max_confidence/i', max_confidence.item(),
                                   i)
            self.writer.add_scalar('entropy/i', entropy.item(), i)
            self.writer.add_scalar('loss/i', loss_1.item(), i)

            self.summary.visualize_image(self.writer, self.args.dataset, image,
                                         target, output_1, i)

        print('testing confidence')
        self.writer.close()

    def dynamic_inference(self,
                          entropy=False,
                          confidence_mode=False,
                          pool_threshold=False,
                          entropy_threshold=False):
        self.model.eval()
        self.evaluator_1.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output, earlier_exit, confidence = self.model.forward_dynamic_inference(
                    image, entropy, confidence_mode, pool_threshold,
                    entropy_threshold)

            loss = self.criterion(output, target)
            test_loss += loss.item()

            pred = torch.argmax(output, dim=1).cpu().numpy()

            # Add batch sample into evaluator
            self.evaluator_1.add_batch(target.cpu().numpy(), pred)

        mIoU = self.evaluator_1.Mean_Intersection_over_Union()

        print('Validation:')
        print("mIoU_1:".format(mIoU_1))

    def mac(self):
        self.model.eval()
        with torch.no_grad():
            flops, params = get_model_complexity_info(
                self.model, (3, 1025, 2049),
                as_strings=True,
                print_per_layer_stat=False)
            print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
            print('{:<30}  {:<8}'.format('Number of parameters: ', params))
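
A usage sketch for this Evaluation class; `obtain_evaluation_args` is a hypothetical helper standing in for whatever builds the `args` namespace consumed by `__init__` above:

if __name__ == '__main__':
    args = obtain_evaluation_args()  # hypothetical: must provide network, saved_arch_path, resume, ...
    args.cuda = torch.cuda.is_available()
    evaluation = Evaluation(args)
    evaluation.validation()  # mIoU of both exits on the validation set
    evaluation.mac()         # FLOPs and parameter count via ptflops
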
Example #8
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        # kwargs = {'num_workers': args.workers, 'pin_memory': True}
        kwargs = {'num_workers': 0, 'pin_memory': True}

        # the NIR input adds a fourth image channel
        if args.nir:
            input_channels = 4
        else:
            input_channels = 3

        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)
        # Define network
        model = DeepLab(num_classes=3,
                        backbone=args.backbone,
                        in_channels=input_channels,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
                # manually override the weights of the three classes
                weight[0] = 1
                weight[1] = 4
                weight[2] = 2
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                # place_holder_target = target
                # place_holder_output = output
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def calculate_scores(self):
        train_loss = 0.0
        self.model.eval()

        loader = self.val_loader
        shape = self.model.module.backbone.layer4[2].conv2.weight.shape
        gradient_embeddings = np.zeros((shape[0] * shape[1], len(loader)))
        # gradient_embeddings = np.zeros((512*33*33, len(loader)))
        # tbar = tqdm(self.train_loader)
        tbar = tqdm(loader)
        num_img_tr = len(loader)
        for i, sample in enumerate(tbar):
            # activations = collections.defaultdict(list)
            # def save_activation(name, mod, input, output):
            #     activations[name].append(output.cpu())
            #
            # for name, m in self.model.named_modules():
            #     if name == 'module.backbone.layer4.2.relu':
            #         m.register_forward_hook(partial(save_activation, name))
            #
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # score = activations['module.backbone.layer4.2.relu'][0].data.numpy().flatten()
            score = self.model.module.backbone.layer4[2].conv2.weight.grad
            score = score.cpu().data.numpy()[:, :, 0, 0].flatten()
            gradient_embeddings[:, i] = score
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr)
        with open('feature_maps/backbone_layer_4_conv2_gradients_trained.npy',
                  'wb') as f:
            print('[INFO] saving the gradients into hard disk')
            np.save(f, gradient_embeddings)
        # with open('gradients.npy', 'rb') as f:
        #     print('[INFO] Loading the gradients from hard disk')
        #     gradient_embeddings = np.load(f)
        # print('[INFO] Calculating TSNE embeddings')
        #
        # embeddings = tsne.fit(gradient_embeddings)
        # # embeddings = TSNE(n_components=2).fit_transform(gradient_embeddings)
        # print('[INFO] Plotting the embeddings')
        # plt.scatter(embeddings[0], embeddings[1])
        # plt.show()
        self.writer.add_scalar('total_score', train_loss)
        # print('[Epoch: %d, numImages: %5d]' % (i, i * self.args.batch_size + image.data.shape[0]))
        # print('Loss: %.3f' % train_loss)

    def pred_single_image(self, path, counter):
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        activations = collections.defaultdict(list)

        def save_activation(name, mod, input, output):
            activations[name].append(output.cpu())

        for name, m in self.model.named_modules():
            if type(m) == nn.ReLU:
                m.register_forward_hook(partial(save_activation, name))

        input = cv2.imread(path)
        # bkg = cv2.createBackgroundSubtractorMOG2()
        # back = bkg.apply(input)
        # cv2.imshow('back', back)
        # cv2.waitKey()
        input = cv2.resize(input, (513, 513), interpolation=cv2.INTER_CUBIC)
        image = Image.open(img_path).convert('RGB')  # width x height x 3
        # _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = np.array(Image.open(img_path), dtype=np.uint8)
        _tmp[_tmp == 255] = 1
        _tmp[_tmp == 0] = 0
        _tmp[_tmp == 128] = 2
        _tmp = Image.fromarray(_tmp)

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=mean, std=std),
            tr.ToTensor()
        ])
        sample = {'image': image, 'label': _tmp}
        sample = composed_transforms(sample)

        image, target = sample['image'], sample['label']

        image = torch.unsqueeze(image, dim=0)
        if self.args.cuda:
            image, target = image.cuda(), target.cuda()
        with torch.no_grad():
            output = self.model(image)

        see = Analysis('module.decoder.last_conv.6', activations)
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        pred = np.reshape(pred, (513, 513))
        # prediction = np.append(target, pred, axis=1)
        prediction = pred

        rgb = np.zeros((prediction.shape[0], prediction.shape[1], 3))

        r = prediction.copy()
        g = prediction.copy()
        b = prediction.copy()

        g[g != 1] = 0
        g[g == 1] = 255

        r[r != 2] = 0
        r[r == 2] = 255
        b = np.zeros(b.shape)

        rgb[:, :, 0] = b
        rgb[:, :, 1] = g
        rgb[:, :, 2] = r

        prediction = np.append(input, rgb.astype(np.uint8), axis=1)
        result = np.append(input, prediction.astype(np.uint8), axis=1)
        cv2.line(rgb, (513, 0), (513, 1020), (255, 255, 255), thickness=1)
        cv2.imwrite(
            '/home/robot/git/pytorch-deeplab-xception/run/cropweed/deeplab-resnet/experiment_41/samples/synthetic_{}.png'
            .format(counter), prediction)
Example #9
class Infer(object):
    def __init__(self, args, ori_img_lst, init_mask_lst):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.ori_img_lst = ori_img_lst
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.test_loader, self.nclass = make_data_loader_demo(
            args, args.test_folder, ori_img_lst, init_mask_lst, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn,
                        use_iou=args.use_maskiou)

        self.model = model

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self, epoch):
        self.model.eval()
        tbar = tqdm(self.test_loader, desc='\r')
        images = []
        ious = np.zeros([len(self.test_loader), self.nclass])
        for i, sample in enumerate(tbar):
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                output, output_iou = self.model(image)
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            images.append(pred.squeeze(axis=0))
            if self.args.use_maskiou:
                output_iou = output_iou.data.cpu().numpy()
                output_iou[output_iou == 0] = np.nan
                ious[i, :] = output_iou
        predictions = {'images': images, 'ious': ious}
        self.saver.save_demo_result(predictions, self.ori_img_lst,
                                    self.args.test_folder)

        print('\nTest:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.test_batch_size + image.data.shape[0]))
Example #10
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader1,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')


        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader1))

        self.architect = Architect(self.model, args)
#        print(self.model.arch_parameters()[2])
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        
        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # debug: print the names of the model's direct parameters
        for name, _ in self.model._parameters.items():
            print(name)
Example #11
class Server(segmentation_pb2_grpc.SegmentationServicer):
    def __init__(self, *args, **kwargs):
        super(Server, self).__init__(*args, **kwargs)
        args, _ = parser.parse_known_args()

        args.cuda = not args.no_cuda and torch.cuda.is_available()
        if args.cuda:
            try:
                args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
            except ValueError:
                raise ValueError(
                    'Argument --gpu_ids must be a comma-separated list of integers only'
                )

        if args.sync_bn is None:
            if args.cuda and len(args.gpu_ids) > 1:
                args.sync_bn = True
            else:
                args.sync_bn = False

        # default settings for epochs, batch_size and lr
        if args.epochs is None:
            epochs = {
                'coco': 30,
                'cityscapes': 200,
                'pascal': 50,
            }
            args.epochs = epochs[args.dataset.lower()]

        if args.batch_size is None:
            args.batch_size = 4 * len(args.gpu_ids)

        if args.test_batch_size is None:
            args.test_batch_size = args.batch_size

        if args.lr is None:
            lrs = {
                'coco': 0.1,
                'cityscapes': 0.01,
                'pascal': 0.007,
            }
            args.lr = lrs[args.dataset.lower()] / (
                4 * len(args.gpu_ids)) * args.batch_size

        if args.checkname is None:
            args.checkname = 'deeplab-' + str(args.backbone)
        print(args)
        torch.manual_seed(args.seed)

        self.initialize_model(args)

    def initialize_model(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # Define Dataloader
        kwargs = {'num_workers': self.args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            self.args, **kwargs)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.args.backbone,
                        output_stride=self.args.out_stride,
                        sync_bn=self.args.sync_bn,
                        freeze_bn=self.args.freeze_bn)
        self.model = model
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if self.args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume)
        self.args.start_epoch = checkpoint['epoch']
        if self.args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        self.evaluator.reset()

    def process(self, request):

        if not hasattr(request, "image_encoded"):
            raise ValueError(
                "One of current request doesn't contain encoded image.")

        image_encoded = request.image_encoded

        def _decode_image(content):
            image_parser = ImageFile.Parser()
            image_parser.feed(content)
            return image_parser.close()

        image = transforms.ToTensor()(_decode_image(image_encoded))
        image = image.unsqueeze(0)  # add a batch dimension
        if self.args.cuda:
            image = image.cuda()
        with torch.no_grad():
            output = self.model(image)

        return output

    def recvFeature(self, request, context):
        output = self.process(request)
        pred = output.data.cpu().numpy()
        pred = Image.fromarray(
            np.argmax(pred, axis=1).squeeze(0).astype(np.uint8), mode='L')
        request.image_segmentation_class_encoded = pred.tobytes()

        return request

    def recvFeatures(self, request, context):
        for r in request:
            output = self.process(r)
            pred = output.data.cpu().numpy()
            pred = Image.fromarray(
                np.argmax(pred, axis=1).squeeze(0).astype(np.uint8), mode='L')
            r.image_segmentation_class_encoded = pred.tobytes()

        return request
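
For context, a minimal sketch of how this servicer might be wired into a gRPC server, assuming the generated segmentation_pb2_grpc module exposes the usual add_SegmentationServicer_to_server helper:

import grpc
from concurrent import futures

def serve(port=50051):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    segmentation_pb2_grpc.add_SegmentationServicer_to_server(Server(), server)
    server.add_insecure_port('[::]:{}'.format(port))
    server.start()
    server.wait_for_termination()
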
Example #12
class Trainer(object):
    def __init__(self,
                 batch_size=32,
                 optimizer_name="Adam",
                 lr=1e-3,
                 weight_decay=1e-5,
                 epochs=200,
                 model_name="model01",
                 gpu_ids=None,
                 resume=None,
                 tqdm=None):
        """
        args:
            batch_size = (int) batch_size of training and validation
            lr = (float) learning rate of optimization
            weight_decay = (float) weight decay of optimization
            epochs = (int) The number of epochs of training
            model_name = (string) The name of training model. Will be folder name.
            gpu_ids = (List) List of gpu_ids. (e.g. gpu_ids = [0, 1]). Use CPU, if it is None. 
            resume = (Dict) Dict of some settings. (resume = {"checkpoint_path":PATH_of_checkpoint, "fine_tuning":True or False}). 
                     Learn from scratch, if it is None.
            tqdm = (tqdm Object) progress bar object. Set your tqdm please.
                   Don't view progress bar, if it is None.
        """
        # Set params
        self.batch_size = batch_size
        self.epochs = epochs
        self.start_epoch = 0
        self.use_cuda = (gpu_ids is not None) and torch.cuda.is_available()
        self.tqdm = tqdm
        self.use_tqdm = tqdm is not None
        # ------------------------- #
        # Define Utils. (No need to Change.)
        """
        These are Project Modules.
        You may not have to change these.
        
        Saver: Save model weight. / <utils.saver.Saver()>
        TensorboardSummary: Write tensorboard file. / <utils.summaries.TensorboardSummary()>
        Evaluator: Calculate some metrics (e.g. Accuracy). / <utils.metrics.Evaluator()>
        """
        ## ***Define Saver***
        self.saver = Saver(model_name, lr, epochs)
        self.saver.save_experiment_config()

        ## ***Define Tensorboard Summary***
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # ------------------------- #
        # Define Training components. (You have to Change!)
        """
        These are important settings for training.
        You have to change these.
        
        make_data_loader: This creates some <Dataloader>s. / <dataloader.__init__>
        Modeling: You have to define your Model. / <modeling.modeling.Modeling()>
        Evaluator: You have to define Evaluator. / <utils.metrics.Evaluator()>
        Optimizer: You have to define Optimizer. / <utils.optimizer.Optimizer()>
        Loss: You have to define Loss function. / <utils.loss.Loss()>
        """
        ## ***Define Dataloader***
        self.train_loader, self.val_loader, self.test_loader, self.num_classes = make_data_loader(
            batch_size)

        ## ***Define Your Model***
        self.model = Modeling(self.num_classes)

        ## ***Define Evaluator***
        self.evaluator = Evaluator(self.num_classes)

        ## ***Define Optimizer***
        self.optimizer = Optimizer(self.model.parameters(),
                                   optimizer_name=optimizer_name,
                                   lr=lr,
                                   weight_decay=weight_decay)

        ## ***Define Loss***
        self.criterion = Loss()

        # ------------------------- #
        # Some settings
        """
        You don't have to touch the code below.
        
        Using cuda: Enable to use cuda if you want.
        Resuming checkpoint: You can resume training if you want.
        """
        ## ***Using cuda***
        if self.use_cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=gpu_ids).cuda()

        ## ***Resuming checkpoint***
        """You can ignore bellow code."""
        self.best_pred = 0.0
        if resume is not None:
            if not os.path.isfile(resume["checkpoint_path"]):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    resume["checkpoint_path"]))
            checkpoint = torch.load(resume["checkpoint_path"])
            self.start_epoch = checkpoint['epoch']
            if self.use_cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if resume["fine_tuning"]:
                # restore the optimizer state and restart from epoch 0 when fine-tuning
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.start_epoch = 0
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume["checkpoint_path"], checkpoint['epoch']))

    def _run_epoch(self,
                   epoch,
                   mode="train",
                   leave_progress=True,
                   use_optuna=False):
        """
        run training or validation 1 epoch.
        You don't have to change almost of this method.
        
        args:
            epoch = (int) How many epochs this time.
            mode = {"train" or "val"}
            leave_progress = {True or False} Can choose whether leave progress bar or not.
            use_optuna = {True or False} Can choose whether use optuna or not.
        
        Change point (if you need):
        - Evaluation: You can change metrics of monitoring.
        - writer.add_scalar: You can change metrics to be saved in tensorboard.
        """
        # ------------------------- #
        leave_progress = leave_progress and not use_optuna
        # Initializing
        epoch_loss = 0.0
        ## Set model mode & tqdm (progress bar; it wrap dataloader)
        assert (mode == "train") or (
            mode == "val"
        ), "argument 'mode' can be 'train' or 'val.' Not {}.".format(mode)
        if mode == "train":
            data_loader = self.tqdm(
                self.train_loader,
                leave=leave_progress) if self.use_tqdm else self.train_loader
            self.model.train()
            num_dataset = len(self.train_loader)
        elif mode == "val":
            data_loader = self.tqdm(
                self.val_loader,
                leave=leave_progress) if self.use_tqdm else self.val_loader
            self.model.eval()
            num_dataset = len(self.val_loader)
        ## Reset confusion matrix of evaluator
        self.evaluator.reset()

        # ------------------------- #
        # Run 1 epoch
        for i, sample in enumerate(data_loader):
            ## ***Get Input data***
            inputs, target = sample["input"], sample["label"]
            if self.use_cuda:
                inputs, target = inputs.cuda(), target.cuda()

            ## ***Calculate Loss <Train>***
            if mode == "train":
                self.optimizer.zero_grad()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            ## ***Calculate Loss <Validation>***
            elif mode == "val":
                with torch.no_grad():
                    output = self.model(inputs)
                loss = self.criterion(output, target)
            epoch_loss += loss.item()
            ## ***Report results***
            if self.use_tqdm:
                data_loader.set_description('{} loss: {:.3f}'.format(
                    mode, epoch_loss / (i + 1)))
            ## ***Add batch results into evaluator***
            target = target.cpu().numpy()
            output = torch.argmax(output, dim=1).data.cpu().numpy()
            self.evaluator.add_batch(target, output)

        ## **********Evaluate Score**********
        """You can add new metrics! <utils.metrics.Evaluator()>"""
        Acc = self.evaluator.Accuracy()

        if not use_optuna:
            ## ***Save eval into Tensorboard***
            self.writer.add_scalar('{}/loss_epoch'.format(mode),
                                   epoch_loss / (i + 1), epoch)
            self.writer.add_scalar('{}/Acc'.format(mode), Acc, epoch)
            print('Total {} loss: {:.3f}'.format(mode,
                                                 epoch_loss / num_dataset))
            print("{0} Acc:{1:.2f}".format(mode, Acc))

        # Return score to watch. (update checkpoint or optuna's objective)
        return Acc

    def run(self, leave_progress=True, use_optuna=False):
        """
        Run all epochs of training and validation.
        """
        epoch_iter = range(self.start_epoch, self.epochs)
        if self.use_tqdm:
            epoch_iter = self.tqdm(epoch_iter)
        for epoch in epoch_iter:
            print(pycolor.GREEN + "[Epoch: {}]".format(epoch) + pycolor.END)

            ## ***Train***
            print(pycolor.YELLOW + "Training:" + pycolor.END)
            self._run_epoch(epoch,
                            mode="train",
                            leave_progress=leave_progress,
                            use_optuna=use_optuna)
            ## ***Validation***
            print(pycolor.YELLOW + "Validation:" + pycolor.END)
            score = self._run_epoch(epoch,
                                    mode="val",
                                    leave_progress=leave_progress,
                                    use_optuna=use_optuna)
            print("---------------------")
            if score > self.best_pred:
                print("Model improved best score from {:.4f} to {:.4f}.".format(
                    self.best_pred, score))
                self.best_pred = score
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                })
        self.writer.close()
        return self.best_pred
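
A usage sketch following the constructor docstring above; the checkpoint path and hyperparameters are placeholders:

from tqdm import tqdm

trainer = Trainer(batch_size=32,
                  optimizer_name="Adam",
                  lr=1e-3,
                  epochs=200,
                  model_name="model01",
                  gpu_ids=[0],
                  resume=None,  # or {"checkpoint_path": "...", "fine_tuning": False}
                  tqdm=tqdm)
best_score = trainer.run(leave_progress=True, use_optuna=False)
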
Example #13
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # When the model is initialized, track_running_stats is True, so
        # running_mean and running_var are created and then loaded from the
        # pretrained model. The model uses batch stats in training mode
        # (which makes optimization easier), while the running stats would
        # only be consumed in eval mode.
        for child in model.modules():
            if type(child) == nn.BatchNorm2d:
                # use batch stats for both train and eval modes; if the running
                # stats are not None they are still updated, but in this toy
                # example we never use them
                child.track_running_stats = False
            if type(child) == nn.Dropout:
                child.p = 0  # no dropout

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        if args.densecrfloss > 0:
            self.densecrflosslayer = DenseCRFLoss(
                weight=args.densecrfloss,
                sigma_rgb=args.sigma_rgb,
                sigma_xy=args.sigma_xy,
                scale_factor=args.rloss_scale)
            print(self.densecrflosslayer)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        self.trainLoss = []
        self.miou = []
        self.mean_entropy = []
        self.mid_entropy = []
        self.celoss = []
        self.crfloss = []

    def training(self, epoch, args):
        train_loss = 0.0
        train_celoss = 0.0
        train_crfloss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)  #number of batches
        softmax = nn.Softmax(dim=1)
        for i, sample in enumerate(tbar):
            image, target, gt = (sample['image'], sample['label'],
                                 sample['groundtruth'])
            croppings = (target != 254).float()
            if self.args.cuda:
                croppings = croppings.cuda()
            target[target == 254] = 255
            # target[target == 255] = 0  # only for full CE
            gt[gt == 255] = 0  # gt feeds the affinity matrix; no unsure regions needed
            # Pixels labeled 255 are unlabeled; padded regions are labeled 254.
            # See RandomScaleCrop in dataloaders/custom_transforms.py for the
            # details of the data preprocessing.
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            outputT = self.model(image)

            #miou
            output_miou = outputT.clone().detach().cpu().numpy()
            output_miou = np.argmax(output_miou, axis=1)
            gt_miou = gt.clone().numpy()

            self.evaluator.reset()
            self.evaluator.add_batch(gt_miou, output_miou)
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            self.miou.append(mIoU)

            celoss = self.criterion(outputT, target)

            if self.args.densecrfloss == 0:
                loss = celoss
            else:
                T = 1.0
                output = outputT / T

                # per-pixel Shannon entropy of the softmax distribution
                logsoftmax = nn.LogSoftmax(dim=1)
                logp = logsoftmax(output)
                p = softmax(output)
                logp = logp.cpu().detach().numpy()
                p = p.cpu().detach().numpy()
                entropy = np.sum(-p * logp, axis=1)

                self.mean_entropy.append(np.mean(entropy[0]).item())
                self.mid_entropy.append(np.median(entropy[0]).item())

                #if epoch<=30:
                #	pass
                #else:
                #	h = output.register_hook(Znormalization)
                #probs = softmax(output)
                #denormalized_image = denormalizeimage(sample['image'], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                #densecrfloss = self.densecrflosslayer(denormalized_image,probs,croppings)
                #gt_tensor = gt.unsqueeze(1).repeat(1,3,1,1)
                # one-hot ground-truth tensor (the explicit batch indices
                # below assume a batch size of 2)
                gt_tensor = torch.zeros_like(output)
                gt_tensor[0, 0, gt[0, ...] == 0] = 1
                gt_tensor[1, 0, gt[0, ...] == 0] = 1
                gt_tensor[0, 1, gt[1, ...] == 1] = 1
                gt_tensor[1, 1, gt[1, ...] == 1] = 1
                gt_tensor = gt_tensor.cuda()

                temperature = 1.0  # (unused)
                ####################################################################################################
                #element-wise log
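                # For each class c, log(1 - S_c) is computed in a numerically
                # stable way as logsumexp_{k != c}(output_k) - logsumexp_k(output_k),
                # i.e. the log-probability mass of "not class c" (21 classes here).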

                logsoftmax = nn.LogSoftmax(dim=1)
                logS = logsoftmax(output)

                part2 = torch.logsumexp(output, dim=1, keepdim=True)
                part1 = torch.logsumexp(output[:, 1:, :, :],
                                        dim=1,
                                        keepdim=True)

                for d in range(1, 20):
                    newtmp = torch.cat(
                        (output[:, :d, :, :], output[:, d + 1:, :, :]), dim=1)
                    newtmp2 = torch.logsumexp(newtmp, dim=1, keepdim=True)
                    part1 = torch.cat((part1, newtmp2), dim=1)
                part1 = torch.cat(
                    (part1,
                     torch.logsumexp(output[:, :20, :, :], dim=1,
                                     keepdim=True)),
                    dim=1)

                log1_S = part1 - part2

                # element-wise log implementation2
                #probs = softmax(output)

                densecrfloss = self.densecrflosslayer(
                    gt_tensor, logS, log1_S, croppings)  # use groundtruth

                ######################################################################################################

                ##### class variance regularizer #####
                '''
                variance = 0
                count = 0
                S1num = (gt[0]==0).sum()
                S2num = (gt[0]==1).sum()
                for i in range(output.size()[0]): # i stands for batch
                	variance += torch.sum(torch.var(output[i,:,gt[i]==0],dim=1))
                	variance += torch.sum(torch.var(output[i,:,gt[i]==1],dim=1))
                	count += 1
                Variance = args.densecrfloss * variance / count
                
                loss = celoss + Variance
                '''
                ######################################

                if self.args.cuda:
                    densecrfloss = densecrfloss.cuda()
                loss = celoss + densecrfloss
                train_crfloss += densecrfloss.item()
                #train_crfloss += Variance.item()
            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()
            train_celoss += celoss.item()

            tbar.set_description(
                'Train loss: %.3f = CE loss %.3f + CRF loss: %.3f' %
                (train_loss / (i + 1), train_celoss / (i + 1), train_crfloss /
                 (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch (currently disabled)
            if False:  # i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, outputT,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        self.trainLoss.append(train_loss)
        self.celoss.append(train_celoss)
        self.crfloss.append(train_crfloss)

        #if self.args.no_val:
        if self.args.save_interval:
            # save checkpoint every interval epoch
            is_best = False
            if (epoch + 1) % self.args.save_interval == 0:
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    },
                    is_best,
                    filename='checkpoint_epoch_{}.pth.tar'.format(
                        str(epoch + 1)))

    def validation(self, epoch):
        self.model.eval()
        # the running stats are still being updated, but we simply do not use them
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            # target[target == 254] = 255
            target[target == 255] = 0  # only for the groundtruth affinity toy experiment
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)

            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        #print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        #print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        return new_pred
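These Trainer methods are normally driven by a short epoch loop in a main() script. A minimal sketch, assuming the usual args fields seen throughout these examples (start_epoch, epochs, no_val, and eval_interval are assumptions, not part of this snippet):

trainer = Trainer(args)
for epoch in range(args.start_epoch, args.epochs):
    trainer.training(epoch)
    # validate every eval_interval epochs unless validation is disabled
    if not args.no_val and epoch % args.eval_interval == 0:
        trainer.validation(epoch)
trainer.writer.close()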
Example #14
class Trainer(object):
    def __init__(self, args):
        self.args = args

        self.saver = Saver(args)
        self.saver.save_experiment_config()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        self.model = OCRNet(self.nclass)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)
        if args.use_balanced_weights:
            weight = torch.tensor([0.2, 0.8], dtype=torch.float32)
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        if args.cuda:
            self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output, _ = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

        print('[Epoch:{},num_images:{}]'.format(
            epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss:{}'.format(train_loss))

        if self.args.no_val:
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                _, output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        road_iou, mIOU = self.evaluator.Mean_Intersection_over_Union()
        FWIOU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:\n')
        print('[Epoch:{},num_image:{}]'.format(
            epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Acc:{},Acc_class:{},mIOU:{},road_iou:{},fwIOU:{}'.format(
            Acc, Acc_class, mIOU, road_iou, FWIOU))
        print('Loss:{}'.format(test_loss))

        new_pred = road_iou
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
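Note that Example #14 selects checkpoints by a single class IoU (road_iou) rather than the mean. For reference, a hedged sketch of the per-class IoU that evaluators like the one above typically derive from a confusion matrix (rows = ground truth, columns = prediction; the actual Evaluator internals are assumed):

import numpy as np

def per_class_iou(confusion: np.ndarray) -> np.ndarray:
    tp = np.diag(confusion)
    fp = confusion.sum(axis=0) - tp  # predicted as c but labeled otherwise
    fn = confusion.sum(axis=1) - tp  # labeled c but predicted otherwise
    return tp / np.maximum(tp + fp + fn, 1)  # guard against empty classes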
Example #15
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define TensorBoard summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclasses = make_data_loader(
            args, **kwargs)

        if self.args.task == 'segmentation':
            self.vs = Vs(args.dataset)
            self.evaluator = Evaluator(self.nclasses['val'])

        # Define Network
        model = Model(args, self.nclasses['train'], self.nclasses['val'])

        if args.model is None:
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        elif 'deeplab' in args.model:
            train_params = [{
                'params': model.backbone.parameters(),
                'lr': args.lr
            }, {
                'params': model.deeplab.parameters(),
                'lr': args.lr * 10
            }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_index)

        self.model, self.optimizer, self.criterion = model, optimizer, criterion

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Loading Classifier (SPNet style)
        if args.call is None or args.cseen is None or args.cunseen is None:
            raise NotImplementedError(
                "Classifiers for 'all', 'seen', 'unseen' should be loaded")
        else:
            if not os.path.isfile(args.cseen):
                raise RuntimeError(
                    "=> no checkpoint for classifier found at '{}'".format(
                        args.cseen))

            self.model.load_train(args.cseen)

            if args.test_set == 'unseen':
                ctest = args.cunseen
            elif args.test_set == 'seen':
                ctest = args.cseen
            else:
                ctest = args.call

            if not os.path.isfile(ctest):
                raise RuntimeError(
                    "=> no checkpoint for clasifier found at '{}'".format(
                        ctest))

            self.model.load_test(ctest)

            print("Classifiers checkpoint successfully loaded from {}, {}".
                  format(args.cseen, ctest))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("{}: No such checkpoint exists".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            pretrained_dict = checkpoint['state_dict']
            model_dict = {}
            state_dict = self.model.state_dict()
            for k, v in pretrained_dict.items():
                if 'classifier' in k: continue
                if k in state_dict:
                    model_dict[k] = v
            state_dict.update(model_dict)
            self.model.load_state_dict(state_dict)

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.best_pred = checkpoint['best_pred']
            print("Loading {} (epoch {}) successfully done".format(
                args.resume, checkpoint['epoch']))

        # Using CUDA
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()
        else:
            raise RuntimeError("CUDA SHOULD BE SUPPORTED")

        if args.ft:
            args.start_epoch = 0

    def train(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'].cuda(), sample['label']
            target = target.cuda().long()
            trues = torch.from_numpy(np.array([True] * image.shape[0])).cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image, trues)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[epoch: %d, loss: %.3f]" % (epoch, train_loss))

        if self.args.no_val:
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def val(self, epoch):
        if self.args.task == 'classification':
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
        elif self.args.task == 'segmentation':
            self.evaluator.reset()

        self.model.eval()

        tbar = tqdm(self.val_loader, desc='\r')
        miou = 0.0
        count = 0
        for i, sample in enumerate(tbar):
            images, targets, names = sample['image'].cuda(
            ), sample['label'].cuda().long(), sample['name']
            falses = torch.from_numpy(np.array([False] *
                                               images.shape[0])).cuda()
            with torch.no_grad():
                outputs = self.model(images, falses)

            loss = self.criterion(outputs, targets)
            #test_loss += loss.item()

            # Score record
            if self.args.task == 'classification':
                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
            elif self.args.task == 'segmentation':
                preds = torch.argmax(outputs, dim=1)
                count += preds.shape[0]
                miou += get_iou(preds,
                                targets,
                                n_classes=self.nclasses['test'],
                                ignore0=self.args.test_set in ('seen', 'unseen'))
                #self.evaluator.add_batch(targets.cpu().numpy(), preds.cpu().numpy())
                tbar.set_description('mIoU: %.3f' % (miou / count))

        # Fast test during the training

        if self.args.task == 'classification':
            _top1 = top1.avg
            _top5 = top5.avg
            self.writer.add_scalar('val/top1', _top1, epoch)
            self.writer.add_scalar('val/top5', _top5, epoch)
            print("Top-1: %.3f, Top-5: %.3f" % (_top1, _top5))
            new_score = _top1
        elif self.args.task == 'segmentation':
            self.writer.add_scalar('val/total_miou', miou, epoch)
            print('mIoU: %.3f' % miou)
            '''
			acc = self.evaluator.Pixel_Accuracy()
			acc_class = self.evaluator.Pixel_Accuracy_Class()
			miou = self.evaluator.Mean_Intersection_over_Union()
			fwiou = self.evaluator.Frequency_Weighted_Intersection_over_Union()
			self.writer.add_scalar('val/Acc', acc, epoch)
			self.writer.add_scalar('val/Acc_class', acc_class, epoch)
			self.writer.add_scalar('val/mIoU', miou, epoch)
			self.writer.add_scalar('val/fwIoU', fwiou, epoch)
			print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU:{}".format(acc, acc_class, miou, fwiou))
			'''
            new_score = miou

        if new_score >= self.best_pred:
            is_best = True
            self.best_pred = float(new_score)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
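The val() above relies on an AverageMeter for the classification branch. A minimal sketch of the conventional implementation (the utility actually used in this codebase is assumed to match this interface):

class AverageMeter:
    """Tracks the current value, running sum, and running average."""
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count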
Example #16
class evaluation(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.train_hard_mining_loader, self.val_loader, self.val_save_loader, self.arg_loader, self.test_loader, self.val_loader_for_compare, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=True,
                        freeze_bn=args.freeze_bn)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    # evaluate the model on validation dataset
    def validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print('[numImages: %5d]' %
              (i * self.args.batch_size + image.data.shape[0]))
        # print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print("Acc:", Acc)
        print("Acc_class:", Acc_class)
        print("mIoU:", mIoU)
        print("fwIoU:", FWIoU)
        print('Loss: %.3f' % test_loss)

    # save the segmentation of the test datasets
    # (change the target directory in pascal.py)
    def test_save(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')
        for i, sample in enumerate(tbar):
            image = sample[0]
            image_id = sample[1]
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                output = self.model(image)
            prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()
            prediction = prediction.astype('uint8')
            im = PIL.Image.fromarray(prediction)
            im.save(image_id[0])

    # save the segmentation of the validation datasets at their original size
    # (change the output directory below)
    def validation_save(self):
        self.model.eval()
        self.evaluator.reset()
        filedir = 'C:\\Users\\Shuang\\Desktop\\val_res'
        tbar = tqdm(self.val_save_loader, desc='\r')
        for i, sample in enumerate(tbar):
            image, target, image_id = sample['image'], sample['label'], sample[
                'id']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()
            im = PIL.Image.fromarray(prediction.astype('uint8'))
            h = target.shape[1]
            w = target.shape[2]
            ratio = 513. / np.max([w, h])
            if w < h:
                m = int(w * ratio)
                im = im.crop((0, 0, m, 513))
            elif w >= h:
                m = int(h * ratio)
                im = im.crop((0, 0, 513, m))
            im = im.resize((w, h), PIL.Image.BILINEAR)

            if not os.path.isdir(filedir):
                os.makedirs(filedir)
            im.save(os.path.join(filedir, image_id[0] + ".png"))

    def validation_resize(self):
        self.model.eval()
        self.evaluator.reset()
        # NOTE: filedir was undefined in this method in the original; the same
        # output directory as validation_save is assumed
        filedir = 'C:\\Users\\Shuang\\Desktop\\val_res'
        tbar = tqdm(self.val_save_loader, desc='\r')
        for i, sample in enumerate(tbar):
            image, target, image_id = sample['image'], sample['label'], sample[
                'id']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()
            im = PIL.Image.fromarray(prediction.astype('uint8'))
            h = target.shape[1]
            w = target.shape[2]
            ratio = 513. / np.max([w, h])
            if w < h:
                m = int(w * ratio)
                im = im.crop((0, 0, m, 513))
            elif w >= h:
                m = int(h * ratio)
                im = im.crop((0, 0, 513, m))
            im = im.resize((w, h), PIL.Image.BILINEAR)

            if not os.path.isdir(filedir):
                os.makedirs(filedir)
            im.save(os.path.join(filedir, image_id[0] + ".png"))

    # calculate the mIoU between the result and the label
    # (change the directory in pascal.py)
    def compare(self):

        tbar = tqdm(self.val_loader_for_compare, desc='\r')
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image = image.numpy().astype(np.int64)
            target = target.numpy().astype(np.float32)
            self.evaluator.add_batch(target, image)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Compare the result and label:')
        print('[numImages: %5d]' %
              (i * self.args.batch_size + image.data.shape[0]))
        # print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print("Acc:", Acc)
        print("Acc_class:", Acc_class)
        print("mIoU:", mIoU)
        print("fwIoU:", FWIoU)

    # hard mining: rewrite the train list for the next epoch
    def hard_mining(self):
        iou_id = []
        tbar = tqdm(self.val_loader_for_compare, desc='\r')
        for i, sample in enumerate(tbar):
            image, target, image_id = sample['image'], sample['label'], sample[
                'id']
            image = image.numpy().astype(np.int64)
            target = target.numpy().astype(np.float32)
            self.evaluator.one_add_batch(target, image)
            IoU = self.evaluator.One_Intersection_over_Union()
            IoU = float(IoU)
            iou_id.append([IoU, image_id])

        iou_id.sort()
        print(iou_id)
        filename = 'F:/pingan/VOCdevkit/VOC2012/ImageSets/Segmentation/arg1.txt'
        # mode 'w' creates the file if it does not exist, so no 'touch' is needed
        f = open(filename, 'w')
        for i in range(10):
            f.write(iou_id[i][1][0] + "\n")
        f.close()
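hard_mining() ranks images by a single-sample IoU (One_Intersection_over_Union), whose internals are not shown here. A hedged sketch of one common definition, averaging IoU over the classes present in either mask (the evaluator's actual method may differ):

import numpy as np

def one_iou(target: np.ndarray, pred: np.ndarray, n_classes: int) -> float:
    ious = []
    for c in range(n_classes):
        union = np.logical_or(target == c, pred == c).sum()
        if union == 0:
            continue  # class absent from both masks
        inter = np.logical_and(target == c, pred == c).sum()
        ious.append(inter / union)
    return float(np.mean(ious)) if ious else 0.0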
Example #17
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # class weights are estimated on the first training split (loader A)
                weight = calculate_weigths_labels(args.dataset, self.train_loaderA, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(num_classes=self.nclass, num_layers=12, criterion=self.criterion,
                            filter_multiplier=self.args.filter_multiplier,
                            block_multiplier=self.args.block_multiplier, step=self.args.step)
        optimizer = torch.optim.SGD(
                model.weight_parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay
            )

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                    lr=args.arch_lr, betas=(0.9, 0.999),
                                                    weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        # TODO: figure out whether len(self.train_loader) should be divided by two here and in other modules as well
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()


        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
                self.model, [self.optimizer, self.architect_optimizer], opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('cuda finished')


        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a DataParallel module we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # strip the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])


            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(), checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                # draw one batch from the architecture-search split; iter()
                # restarts the loader each step, so shuffling provides variety
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(), target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                if self.use_amp:
                    with amp.scale_loss(arch_loss, self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            #self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            #torch.cuda.empty_cache()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
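Example #17 depends on NVIDIA apex for mixed precision; apex.amp has since been superseded by the torch.cuda.amp API built into PyTorch >= 1.6. A hedged sketch of the equivalent scaled-loss step (names mirror the example; this is not part of the original code):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_step(model, criterion, optimizer, image, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(image)
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # scaled backward, like amp.scale_loss
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()
    return loss.item()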
Example #18
class Trainer(object):
    def __init__(self, args):
        self.args = args

        self.half = args.half

        self.prev_pred = 0.0
        self.bad_count = 0

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        #train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, {'params': model.get_10x_lr_params(), 'lr': args.lr*10}]
        #optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

        optimizer = torch.optim.SGD(params=model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        #optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0

        self.model.train()

        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        if self.half:
            self.model.half()

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.bad_count)
            if self.bad_count >= 2:
                self.bad_count = 0

            self.optimizer.zero_grad()
            if self.half:
                output = self.model(image.half()).float()
            else:
                output = self.model(image)

            # print(output.shape, target.shape)
            loss = self.criterion(output, target)
            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.6f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/lr', self.scheduler.lr, epoch)

            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        if self.half:
            self.model.float()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.6f Prev best pred: %.6f' % (train_loss, self.best_pred))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def Dice(self, gt, pred):
        intersection = (pred * gt).sum()
        union = pred.sum() + gt.sum() + 1e-15

        return 2.0 * intersection / union

    def validation(self, epoch):
        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        dice = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.6f' % (test_loss / (i + 1)))

            # output = F.softmax(output, dim=1)
            # #output = F.interpolate(output, size=(1280, 1918), mode='bilinear', align_corners=False)
            # output = output.cpu().numpy()[0][1]
            # output = output > 0.5
            # target = target.cpu().numpy().squeeze()

            target = target.cpu().numpy().squeeze()
            output = output.data.cpu().numpy()
            output = np.argmax(output, axis=1).squeeze()

            # print(output.shape)
            assert target.shape == output.shape

            dice += self.Dice(target, output)
        dice /= tbar.total

        self.writer.add_scalar('val/Dice', dice, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        #print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, Dice: {}".format(Acc, Acc_class, mIoU, FWIoU, Dice))
        print('Loss: %.6f Dice: %.6f' % (test_loss, dice))

        new_pred = dice

        if new_pred <= self.prev_pred:
            self.bad_count += 1
        else:
            self.bad_count = 0
        self.prev_pred = new_pred

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        print('Best Pred %.5f' % self.best_pred)
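The Dice helper above can be sanity-checked on a toy pair of binary masks (hypothetical data; Dice = 2 * |A intersect B| / (|A| + |B|)):

import numpy as np

gt = np.array([[1, 1, 0], [0, 1, 0]], dtype=np.float32)    # |A| = 3
pred = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.float32)  # |B| = 3, overlap = 2
dice = 2.0 * (pred * gt).sum() / (pred.sum() + gt.sum() + 1e-15)
print(dice)  # 4 / 6 ~= 0.667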
class Predictor(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _check_dir(self, dir_name):
        if not os.path.exists(dir_name):
            os.mkdir(dir_name)

    def _check_out_dir(self, scales=[]):
        self._check_dir('output')
        for scale in scales:
            fp = 'output/scale_{}'.format(scale)
            self._check_dir(fp)

    def predict(self,
                filename='',
                scales=[0.8, 1, 1.2, 1.4, 1.6, 1.8, 2, 2.2, 2.4]):
        self.model.eval()

        # check the output directories
        self._check_out_dir(scales)

        # load the input image
        image = Image.open(filename)
        # run inference at each scale
        for scale in scales:
            in_transform = transforms.Compose([
                transforms.ToTensor(), lambda x: torch.unsqueeze(x, dim=0),
                lambda x: torch.nn.Upsample(scale_factor=scale)(x)
            ])
            img = in_transform(image)
            img = img.cuda()

            with torch.no_grad():
                output = self.model(img)
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # predict mask
            mask = pred[0]
            mask = VOCSegmentation.fill_colormap(mask)
            mask = Image.fromarray(mask)

            file_path = 'output/scale_{}/{}.jpg'
            fn = filename.split('/')[-1].replace('.jpg', '')

            mask.save(file_path.format(scale, fn), format='jpeg')
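Predictor.predict saves one mask per scale. A common refinement is to fuse the scales instead: resize each scale's logits back to a common resolution, average them, and take a single argmax. A hedged sketch (model and a normalized NCHW img tensor as in predict() are assumed):

import torch
import torch.nn.functional as F

def multi_scale_logits(model, img, scales=(0.8, 1.0, 1.2)):
    h, w = img.shape[-2:]
    acc = 0
    for s in scales:
        scaled = F.interpolate(img, scale_factor=s, mode='bilinear',
                               align_corners=False)
        with torch.no_grad():
            out = model(scaled)
        # bring the logits back to the input resolution before averaging
        acc = acc + F.interpolate(out, size=(h, w), mode='bilinear',
                                  align_corners=False)
    return acc / len(scales)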
Example #20
class Val(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self):
        self.model.eval()
        # NOTE: `test_ds` was undefined in the original; the class's own
        # test loader is assumed here
        for i, sample in enumerate(self.test_loader):
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                output = self.model(image)

            prediction = output.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
            prediction = prediction.astype('uint8')

            for img in range(2):  # assumes a batch size of 2
                predict = prediction[img]
                im = PIL.Image.fromarray(predict)
                im.save(os.path.join('test', 'result', str(i) + str(img) + '.png'))
                mask = im.convert('P')
                mask.putpalette(palette)  # putpalette modifies in place and returns None
                mask.save(os.path.join('test', 'show', str(i) + str(img) + '.png'))
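The palette used in Val.test is never defined in this snippet. PIL's putpalette expects a flat [r, g, b, r, g, b, ...] list; a hedged sketch using the first entries of the PASCAL VOC colormap (assumed, since the dataset is not stated):

palette = [0, 0, 0,      # class 0: background
           128, 0, 0,    # class 1
           0, 128, 0]    # class 2
palette += [0] * (768 - len(palette))  # pad to 256 RGB triples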
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        if args.densecrfloss > 0:
            self.densecrflosslayer = DenseCRFLoss(weight=args.densecrfloss,
                                                  sigma_rgb=args.sigma_rgb,
                                                  sigma_xy=args.sigma_xy,
                                                  scale_factor=args.rloss_scale)
            print(self.densecrflosslayer)
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_celoss = 0.0
        train_crfloss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        softmax = nn.Softmax(dim=1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            croppings = (target != 254).float()
            target[target == 254] = 255
            # pixels labeled 255 are unlabeled; padded regions are labeled 254
            # (see RandomScaleCrop in dataloaders/custom_transforms.py for the
            # data-preprocessing details)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            
            celoss = self.criterion(output, target)
            
            if self.args.densecrfloss == 0:
                loss = celoss
            else:
                probs = softmax(output)
                denormalized_image = denormalizeimage(sample['image'],
                                                      mean=(0.485, 0.456, 0.406),
                                                      std=(0.229, 0.224, 0.225))
                densecrfloss = self.densecrflosslayer(denormalized_image, probs,
                                                      croppings)
                if self.args.cuda:
                    densecrfloss = densecrfloss.cuda()
                loss = celoss + densecrfloss
                train_crfloss += densecrfloss.item()
            loss.backward()
        
            self.optimizer.step()
            train_loss += loss.item()
            train_celoss += celoss.item()
            
            tbar.set_description('Train loss: %.3f = CE loss %.3f + CRF loss: %.3f'
                                 % (train_loss / (i + 1), train_celoss / (i + 1), train_crfloss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:  # max() guards loaders with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        #if self.args.no_val:
        if self.args.save_interval:
            # save checkpoint every interval epoch
            is_best = False
            if (epoch + 1) % self.args.save_interval == 0:
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best, filename='checkpoint_epoch_{}.pth.tar'.format(str(epoch+1)))


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            target[target == 254] = 255
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = F.softmax(output, dim=1)  # dense CRF post-processing expects class probabilities
            pred = pred.data.cpu().numpy()

            if self.args.post_process:
                pool = mp.Pool(mp.cpu_count())
                image = image.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
                pred = pool.map(dense_crf_wrapper, zip(image, pred))
                pool.close()

            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
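A note on the CE + CRF combination in training() above: the cross-entropy term consumes raw logits, while the dense CRF regularizer consumes softmax probabilities together with the denormalized image and the valid-pixel mask. The sketch below mirrors that wiring; the real DenseCRFLoss is not shown in this example, so a placeholder smoothness term (hypothetical, for illustration only) stands in to keep the snippet self-contained.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PlaceholderCRFLoss(nn.Module):
    """Stand-in for DenseCRFLoss: a crude total-variation smoothness penalty."""
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, images, probs, roi):
        # Penalize probability changes between neighboring pixels (proxy only).
        dx = (probs[:, :, :, 1:] - probs[:, :, :, :-1]).abs().mean()
        dy = (probs[:, :, 1:, :] - probs[:, :, :-1, :]).abs().mean()
        return self.weight * (dx + dy)

criterion = nn.CrossEntropyLoss(ignore_index=255)   # unlabeled pixels are 255
crf_layer = PlaceholderCRFLoss(weight=2e-9)

logits = torch.randn(2, 21, 64, 64, requires_grad=True)  # fake network output
target = torch.randint(0, 21, (2, 64, 64))               # dense labels
image = torch.rand(2, 3, 64, 64) * 255                   # "denormalized" image
croppings = torch.ones(2, 64, 64)                        # valid-pixel mask

celoss = criterion(logits, target)
loss = celoss + crf_layer(image, F.softmax(logits, dim=1), croppings)
loss.backward()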
Example #22
def main():

    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size)
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.test_batch_size,
                              resize=args.img_size)

    train_loader = DataLoader(mini,
                              args.meta_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    valid_loader = DataLoader(mini_valid,
                              args.meta_test_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs_theta, train_accs_w = meta_train(train_loader, maml, device,
                                                    epoch, writer)
        logging.info(
            '[Epoch: {}]\t Train acc_theta: {}\t Train acc_w: {}'.format(
                epoch, train_accs_theta, train_accs_w))
        test_accs_theta, test_accs_w = meta_test(valid_loader, maml, device,
                                                 epoch, writer)
        logging.info(
            '[Epoch: {}]\t Test acc_theta: {}\t Test acc_w: {}'.format(
                epoch, test_accs_theta, test_accs_w))

        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the best meta model.
        new_pred = test_accs_w[-1]
        if new_pred > best_pred:
            is_best = True
            best_pred = new_pred
        else:
            is_best = False
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict_w': maml.module.state_dict()
                if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'state_dict_theta': maml.model.arch_parameters(),
                'best_pred': best_pred,
            }, is_best)
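A side note on the parameter count logged in main() above: the filter/map pair counts individual weights (scalars), not tensors, despite the log message. A more direct, equivalent sketch (toy model, illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 21, 1))

# numel() per trainable tensor, summed: the same quantity as the filter/map pair.
num = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total trainable parameters: {}'.format(num))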
Example #23
class trainNew(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(num_classes=self.nclass,
                            num_layers=12,
                            criterion=self.criterion,
                            filter_multiplier=self.args.filter_multiplier)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loader),
                                      min_lr=args.min_lr)
        # TODO: Figure out whether len(self.train_loader) should be divided by two (in other modules as well)
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a DataParallel module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
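The clean_module branch above exists because nn.DataParallel prefixes every state-dict key with 'module.'. A standalone sketch of that round trip (a plain nn.Linear is used as a stand-in model):

from collections import OrderedDict

import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)          # keys become 'module.weight', 'module.bias'
state_dict = wrapped.state_dict()

# Strip the prefix so an unwrapped model can load the checkpoint.
new_state_dict = OrderedDict(
    (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items())
nn.Linear(4, 2).load_state_dict(new_state_dict)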
Example #24
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network G and D (the output of Deeplab is score for each class, and the score is
        # passed through softmax layer before going into PatchGAN)
        #================================== network ==============================================#
        network_G = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

        softmax_layer = torch.nn.Softmax(dim=1)

        network_D = networks.define_D(24,
                                      64,
                                      netD='basic',
                                      n_layers_D=3,
                                      norm='batch',
                                      init_type='normal',
                                      init_gain=0.02,
                                      gpu_ids=self.args.gpu_ids)
        #=========================================================================================#

        train_params = [{
            'params': network_G.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': network_G.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        #================================== network ==============================================#
        optimizer_G = torch.optim.SGD(train_params,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay,
                                      nesterov=args.nesterov)
        optimizer_D = torch.optim.Adam(network_D.parameters(),
                                       lr=0.0002,
                                       betas=(0.5, 0.999))
        #=========================================================================================#

        # Define whether to use class balanced weights for criterion
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        #=================== GAN criterion and Segmentation criterion ======================================#
        self.criterionGAN = networks.GANLoss('vanilla').to(
            args.gpu_ids[0])  ### set device manually
        self.criterionSeg = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        #===================================================================================================#

        self.network_G, self.softmax_layer, self.network_D = network_G, softmax_layer, network_D
        self.optimizer_G, self.optimizer_D = optimizer_G, optimizer_D

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.network_G = torch.nn.DataParallel(
                self.network_G, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.network_G)
            self.network_G = self.network_G.cuda()

        #====================== no resume ===================================================================#
        # Resuming checkpoint
        self.best_pred = 0.0
        #        if args.resume is not None:
        #            if not os.path.isfile(args.resume):
        #                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
        #            checkpoint = torch.load(args.resume)
        #            args.start_epoch = checkpoint['epoch']
        #            if args.cuda:
        #                self.network_G.module.load_state_dict(checkpoint['state_dict'])
        #            else:
        #                self.network_G.load_state_dict(checkpoint['state_dict'])
        #            if not args.ft:
        #                self.optimizer.load_state_dict(checkpoint['optimizer'])
        #            self.best_pred = checkpoint['best_pred']
        #            print("=> loaded checkpoint '{}' (epoch {})"
        #                  .format(args.resume, checkpoint['epoch']))
        #=======================================================================================================#

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        G_Seg_loss = 0.0
        G_GAN_loss = 0.0
        D_fake_loss = 0.0
        D_real_loss = 0.0

        #======================== train mode to set batch normalization =======================================#
        self.network_G.train()
        self.network_D.train()
        #======================================================================================================#

        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer_G, i, epoch,
                           self.best_pred)  # tune learning rate

            #================================= GAN training process (pix2pix) ============================================#

            # prepare tensors
            output_score = self.network_G(
                image)  # score map for each class in pixels
            output = self.softmax_layer(output_score)  # label for each pixel

            target_one_hot = self.make_one_hot(
                target,
                C=21)  # change target to one-hot coding to feed into PatchGAN

            fake_AB = torch.cat((image, output), 1)
            real_AB = torch.cat((image, target_one_hot), 1)

            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #

            # freeze G, unfreeze D
            self.set_requires_grad(self.network_G, False)
            self.set_requires_grad(self.softmax_layer, False)
            self.set_requires_grad(self.network_D, True)

            # reset D grad
            self.optimizer_D.zero_grad()

            # fake input
            pred_fake = self.network_D(fake_AB.detach())
            loss_D_fake = self.criterionGAN(pred_fake, False)

            # real input
            pred_real = self.network_D(real_AB)
            loss_D_real = self.criterionGAN(pred_real, True)

            # combine loss and calculate gradients
            loss_D = (loss_D_fake + loss_D_real) / (2.0 * self.args.batch_size)

            loss_D.backward()
            self.optimizer_D.step()

            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #

            # unfreeze G, freeze D
            self.set_requires_grad(self.network_G, True)
            self.set_requires_grad(self.softmax_layer, True)
            self.set_requires_grad(self.network_D, False)

            # reset G grad
            self.optimizer_G.zero_grad()

            # fake input should let D predict 1
            pred_fake = self.network_D(fake_AB)
            loss_G_GAN = self.criterionGAN(pred_fake, True)

            # Segmentation loss G(A) = B
            loss_G_CE = self.criterionSeg(
                output_score, target
            ) * self.args.lambda_Seg  # weighted by lambda_Seg (the cross-entropy loss weight)

            # combine loss and calculate gradients
            # lambda = 0.1
            loss_G = loss_G_GAN * self.args.lambda_GAN / self.args.batch_size + loss_G_CE
            loss_G.backward()

            self.optimizer_G.step()

            # display G and D loss
            G_Seg_loss += loss_G_CE.item()
            G_GAN_loss += loss_G_GAN.item()
            D_fake_loss += loss_D_fake.item()
            D_real_loss += loss_D_real.item()

            #===================================================================================================#

            tbar.set_description(
                'G_Seg_loss: %.3f G_GAN_loss: %.3f D_fake_loss: %.3f D_real_loss: %.3f'
                % (G_Seg_loss / (i + 1), G_GAN_loss / (i + 1), D_fake_loss /
                   (i + 1), D_real_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_G_CE.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:  # max() guards loaders with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', G_Seg_loss, epoch)
        print('Training:')
        print('    [Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('    Train G_Seg_Loss: %.3f' % G_Seg_loss)


#======================================= no save checkpoint ==================#
#        if self.args.no_val:
#            # save checkpoint every epoch
#            is_best = False
#            self.saver.save_checkpoint({
#                'epoch': epoch + 1,
#                'state_dict': self.model.module.state_dict(),
#                'optimizer': self.optimizer.state_dict(),
#                'best_pred': self.best_pred,
#            }, is_best)
#=============================================================================#

    def validation(self, epoch):
        self.network_G.eval()
        self.evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.network_G(image)
            loss = self.criterionSeg(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('    [Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("    Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('    Test G_Seg_Loss: %.3f' % test_loss)

        new_pred = mIoU

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

            #============== only save checkpoint for best model ======================#
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict_G': self.network_G.module.state_dict(),
                    'state_dict_D': self.network_D.state_dict(),
                    'optimizer_G': self.optimizer_G.state_dict(),
                    'optimizer_D': self.optimizer_D.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
            #=======================================================#

    #========================== new method ===============================#
    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def make_one_hot(self, labels, C=21):
        labels = labels.clone()  # avoid mutating the caller's target in place
        labels[labels == 255] = 0
        labels = labels.unsqueeze(1)

        one_hot = torch.zeros(labels.size(0),
                              C,
                              labels.size(2),
                              labels.size(3),
                              device=labels.device)
        target = one_hot.scatter_(1, labels.long(), 1.0)

        return target
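The discriminator/generator alternation in training() above hinges on set_requires_grad plus detach(): D sees fake_AB.detach() so no gradients flow into G during the D step, and D is frozen during the G step. A compressed, self-contained sketch of that alternation (toy one-layer nets, illustrative names):

import torch
import torch.nn as nn

netG = nn.Conv2d(3, 21, 1)            # toy generator (segmentation head)
netD = nn.Conv2d(24, 1, 1)            # toy PatchGAN-style discriminator
optG = torch.optim.SGD(netG.parameters(), lr=0.01)
optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

image = torch.rand(2, 3, 8, 8)
real_maps = torch.rand(2, 21, 8, 8)   # stand-in for one-hot targets

fake = torch.softmax(netG(image), dim=1)
fake_AB = torch.cat((image, fake), 1)
real_AB = torch.cat((image, real_maps), 1)

# D step: gradients are blocked from reaching G via detach(), train D.
for p in netD.parameters():
    p.requires_grad = True
optD.zero_grad()
pred_fake = netD(fake_AB.detach())
pred_real = netD(real_AB)
loss_D = (bce(pred_fake, torch.zeros_like(pred_fake)) +
          bce(pred_real, torch.ones_like(pred_real))) / 2
loss_D.backward()
optD.step()

# G step: freeze D, train G to fool it.
for p in netD.parameters():
    p.requires_grad = False
optG.zero_grad()
loss_G = bce(netD(fake_AB), torch.ones_like(pred_fake))
loss_G.backward()
optG.step()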
Example #25
class trainNew(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        #                        output_stride=args.out_stride,
        #                        sync_bn=args.sync_bn,
        #                        freeze_bn=args.freeze_bn)
        # TODO: look into these
        # TODO: ALSO look into different param groups as done int deeplab below
        #        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        #
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  #TODO: use min_lr ?

        # TODO: Figure out whether len(self.train_loader) should be divided by two (in other modules as well)
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a DataParallel module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:  # max() guards loaders with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
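Every resume block in these trainers is the load half of the same save/load round trip over a plain dict; a minimal standalone sketch (the checkpoint path is illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save everything needed to resume training.
torch.save({
    'epoch': 5,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'best_pred': 0.42,
}, 'checkpoint.pth.tar')

# Resume: restore weights, optimizer state, and bookkeeping.
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch, best_pred = checkpoint['epoch'], checkpoint['best_pred']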
Example #26
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        # visualize with tensorboardX
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        #        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        kwargs = {'num_workers': 0, 'pin_memory': True}
        #self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.train_loader, self.val_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        if args.model_name == 'unet':
            #model = UNet_ac(args.n_channels, args.n_filters, args.n_class).cuda()
            #model = UNet_SNws(args.n_channels, args.n_filters, args.n_class,
            #                  args.using_movavg, args.using_bn).cuda()
            model = UNet_bn(args.n_channels, args.n_filters,
                            args.n_class).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
            #optimizer = torch.optim.AdamW(model.parameters(), lr=args.arch_lr, betas=(0.9, 0.999), weight_decay=args.weight_decay)
            #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        # elif args.model_name == 'hunet':
        #     model = HUNet(args.n_channels, args.n_filters, args.n_class, args.using_movavg, args.using_bn).cuda()
        #     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.model_name == 'unet3+':
            model = UNet3p_SNws(args.n_channels, args.n_filters, args.n_class,
                                args.using_movavg, args.using_bn).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'unet3+_aspp':
            #model = UNet3p_aspp(args.n_channels, args.n_filters, args.n_class, args.using_movavg, args.using_bn).cuda()
            #model = UNet3p_aspp_SNws(args.n_channels, args.n_filters, args.n_class, args.using_movavg, args.using_bn).cuda()
            #model = UNet3p_res_aspp_SNws(args.n_channels, args.n_filters, args.n_class, args.using_movavg, args.using_bn).cuda()
            model = UNet3p_res_edge_aspp_SNws(args.n_channels, args.n_filters,
                                              args.n_class, args.using_movavg,
                                              args.using_bn).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'unet3+_ocr':
            model = UNet3p_res_ocr_SNws(args.n_channels, args.n_filters,
                                        args.n_class, args.using_movavg,
                                        args.using_bn).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        # elif args.model_name == 'unet3+_resnest_aspp':
        #     model = UNet3p_resnest_aspp(args.n_channels, args.n_filters, args.n_class, args.using_movavg, args.using_bn).cuda()
        #     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.model_name == 'gscnn':
            model = GSCNN(args.n_channels, args.n_filters, args.n_class,
                          args.using_movavg, args.using_bn).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'pspnet':
            model = PSPNet(args.n_channels, args.n_filters,
                           args.n_class).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'segnet':
            model = Segnet(args.n_channels, args.n_filters,
                           args.n_class).cuda()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'hrnet':
            MODEL = {
                'ALIGN_CORNERS': True,
                'EXTRA': {
                    'FINAL_CONV_KERNEL': 1,
                    # EXTRA defines the model structure: 4 stages, each with its own parameters
                    'STAGE1': {
                        'NUM_MODULES': 1,  # number of HighResolutionModule repeats
                        'NUM_BRANCHES': 1,  # number of branches
                        'BLOCK': 'BOTTLENECK',
                        'NUM_BLOCKS': 4,
                        'NUM_CHANNELS': 64,
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE2': {
                        'NUM_MODULES': 1,
                        'NUM_BRANCHES': 2,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4],
                        'NUM_CHANNELS': [48, 96],
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE3': {
                        'NUM_MODULES': 4,
                        'NUM_BRANCHES': 3,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4, 4],
                        'NUM_CHANNELS': [48, 96, 192],
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE4': {
                        'NUM_MODULES': 3,
                        'NUM_BRANCHES': 4,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4, 4, 4],
                        'NUM_CHANNELS': [48, 96, 192, 384],
                        'FUSE_METHOD': 'SUM'
                    }
                }
            }
            model = HighResolutionNet(args.n_channels, args.n_filters,
                                      args.n_class, MODEL).cuda()
            # model.init_weights()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'hrnet+ocr':
            MODEL = {
                'ALIGN_CORNERS': True,
                'EXTRA': {
                    'FINAL_CONV_KERNEL': 1,
                    # EXTRA defines the model structure: 4 stages, each with its own parameters
                    'STAGE1': {
                        'NUM_MODULES': 1,  # number of HighResolutionModule repeats
                        'NUM_BRANCHES': 1,  # number of branches
                        'BLOCK': 'BOTTLENECK',
                        'NUM_BLOCKS': 4,
                        'NUM_CHANNELS': 64,
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE2': {
                        'NUM_MODULES': 1,
                        'NUM_BRANCHES': 2,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4],
                        'NUM_CHANNELS': [48, 96],
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE3': {
                        'NUM_MODULES': 4,
                        'NUM_BRANCHES': 3,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4, 4],
                        'NUM_CHANNELS': [48, 96, 192],
                        'FUSE_METHOD': 'SUM'
                    },
                    'STAGE4': {
                        'NUM_MODULES': 3,
                        'NUM_BRANCHES': 4,
                        'BLOCK': 'BASIC',
                        'NUM_BLOCKS': [4, 4, 4, 4],
                        'NUM_CHANNELS': [48, 96, 192, 384],
                        'FUSE_METHOD': 'SUM'
                    }
                }
            }
            # model = HighResolutionNet_OCR(args.n_channels, args.n_filters, args.n_class, MODEL).cuda()
            model = HighResolutionNet_OCR_SNws(args.n_channels, args.n_filters,
                                               args.n_class, MODEL).cuda()
            # model.init_weights()
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        elif args.model_name == 'deeplabv3+':
            # Define network
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
            backbone = model.backbone

            #             backbone.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            print('change the input channels', backbone.conv1)

            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]

            #optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
            #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
            optimizer = torch.optim.AdamW(train_params,
                                          weight_decay=args.weight_decay)
            #optimizer = torch.optim.AdamW(train_params, lr=args.arch_lr, betas=(0.9, 0.999), weight_decay=args.weight_decay)
        #elif args.model_name == 'autodeeplab':
        #model = AutoDeeplab(args.n_class, 12, self.criterion, crop_size=args.crop_size)
        #optimizer = torch.optim.AdamW(model.weight_parameters(), lr=args.lr, weight_decay=args.weight_decay)
        #optimizer = torch.optim.SGD(model.weight_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            print(self.args.gpu_ids)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # print(image.shape)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:  # max() guards loaders with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                # self.summary.visualize_image(self.writer, self.args.dataset,
                #                              image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss,
                               epoch)  # log the scalar value
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % (train_loss / len(tbar)))
        # print('Loss: %.3f' % (train_loss / i))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()  # reset the confusion matrix to all zeros
        tbar = tqdm(self.val_loader, desc='\r')  # '\r' is a carriage return
        val_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            #            image, target = sample[0], sample[1]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            val_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (val_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)  # argmax over the class dimension
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()  # frequency-weighted IoU
        self.writer.add_scalar('val/total_loss_epoch', val_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % (val_loss / len(tbar)))

        new_pred = FWIoU  # mIoU

        # log to a plain text file (truncated at the start of training)
        logfile = os.path.join('/home/wzj/mine_cloud_14/', 'log.txt')
        with open(logfile, 'a') as log_file:
            if epoch == 0:
                log_file.seek(0)
                log_file.truncate()
                log_file.write(self.args.model_name + '\n')
            log_file.write('Epoch: %d, ' % (epoch + 1))
            log_file.write(
                'Acc: {}, Acc_class: {}, mIoU: {}, fwIoU: {}, best_fwIoU: {}, '
                .format(Acc, Acc_class, mIoU, FWIoU,
                        max(self.best_pred, new_pred)))
            log_file.write('Loss: %.3f\n' % (val_loss / len(tbar)))

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
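The Evaluator's internals are not shown in any of these examples; assuming the usual confusion-matrix implementation (an assumption consistent with how reset()/add_batch() are used above), the reported metrics reduce to a few NumPy lines:

import numpy as np

def confusion_matrix(target, pred, num_classes):
    mask = (target >= 0) & (target < num_classes)      # drop ignore labels
    idx = num_classes * target[mask].astype(int) + pred[mask]
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def mean_iou(cm):
    inter = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - inter
    return np.nanmean(inter / union)                   # nan for absent classes

def fw_iou(cm):
    freq = cm.sum(axis=1) / cm.sum()
    iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
    return (freq[freq > 0] * iou[freq > 0]).sum()

cm = confusion_matrix(np.array([0, 1, 1, 2]), np.array([0, 1, 2, 2]), 3)
print(mean_iou(cm), fw_iou(cm))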
Example #27
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.logger = self.saver.create_logger()

        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.model = EDCNet(args.rgb_dim, args.event_dim, num_classes=self.nclass, use_bn=True)
        train_params = [{'params': self.model.random_init_params(),
                         'lr': 10*args.lr, 'weight_decay': 10*args.weight_decay},
                        {'params': self.model.fine_tune_params(),
                         'lr': args.lr, 'weight_decay': args.weight_decay}]
        self.optimizer = torch.optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay)
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.to(self.args.device)
        if args.use_balanced_weights:
            root_dir = Path.db_root_dir(args.dataset)[0] if isinstance(Path.db_root_dir(args.dataset), list) else Path.db_root_dir(args.dataset)
            classes_weights_path = os.path.join(root_dir,
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass, classes_weights_path)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion_event = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='event')
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader), warmup_epochs=5)

        self.evaluator = Evaluator(self.nclass, self.logger)
        self.saver.save_model_summary(self.model)
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            target = sample['label']
            image = sample['image']
            event = sample['event']
            if self.args.cuda:
                target = target.to(self.args.device)
                image = image.to(self.args.device)
                event = event.to(self.args.device)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output, output_event = self.model(image)
            loss = self.criterion(output, target)
            loss_event = self.criterion_event(output_event, event)
            loss += (loss_event * 0.1)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss/num_img_tr, epoch)
        self.logger.info('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + target.data.shape[0]))
        self.logger.info('Loss: %.3f' % (train_loss/num_img_tr))

        if self.args.no_val:
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_img_val = len(self.val_loader)
        for i, (sample, _) in enumerate(tbar):
            target = sample['label']
            image = sample['image']
            event = sample['event']
            if self.args.cuda:
                target = target.to(self.args.device)
                image = image.to(self.args.device)
                event = event.to(self.args.device)
            with torch.no_grad():
                output, output_event = self.model(image)
            loss = self.criterion(output, target)
            loss_event = self.criterion_event(output_event, event)
            loss += (loss_event * 20)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss / num_img_val, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.logger.info('Validation:')
        self.logger.info('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + target.shape[0]))
        self.logger.info("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        self.logger.info('Loss: %.3f' % (test_loss/num_img_val))

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
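A Trainer like the one above is normally driven by a small epoch loop. A minimal sketch, assuming an argparse-style args carrying the fields referenced in __init__ (start_epoch, epochs, no_val); this driver is not part of the example itself:

trainer = Trainer(args)
for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
    trainer.training(epoch)
    if not trainer.args.no_val:
        trainer.validation(epoch)
trainer.writer.close()  # flush logs; assumes the writer is a TensorBoard SummaryWriter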
Example #28
0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # mean/std used to denormalize depth images for visualization
        self.mean_depth = torch.as_tensor(0.12176, dtype=torch.float32, device='cpu')
        self.std_depth = torch.as_tensor(0.09752, dtype=torch.float32, device='cpu')
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # Define network
        resnet = resnet18(pretrained=True, efficient=False, use_bn=True)
        model = RFNet(resnet, num_classes=self.nclass, use_bn=True)

        train_params = [{'params': model.random_init_params()},
                        {'params': model.fine_tune_params(), 'lr': args.lr, 'weight_decay':args.weight_decay}]
        # Define Optimizer
        optimizer = torch.optim.Adam(train_params, lr=args.lr * 4,
                                    weight_decay=args.weight_decay * 4)
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        # Define loss function
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if self.args.depth:
                output = self.model(image, depth)
            else:
                output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.depth:
                    self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

                    depth_display = depth[0].cpu().unsqueeze(0)
                    depth_display = depth_display.mul_(self.std_depth).add_(self.mean_depth)
                    depth_display = depth_display.numpy()
                    depth_display = depth_display * 255
                    depth_display = depth_display.astype(np.uint8)
                    self.writer.add_image('Depth', depth_display, global_step)

                else:
                    self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (sample, img_path) in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            with torch.no_grad():
                if self.args.depth:
                    output = self.model(image, depth)
                else:
                    output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
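The depth branch of training() above denormalizes before logging: the loader standardizes depth with the mean/std set in __init__, and add_image wants a plain uint8 array. The same transform in isolation (the clip is an added safeguard, not in the original):

import numpy as np
import torch

mean_depth, std_depth = 0.12176, 0.09752         # the statistics from __init__
depth = torch.randn(1, 64, 64)                   # toy normalized depth map
display = depth.mul(std_depth).add(mean_depth)   # undo (x - mean) / std
display = (display.numpy() * 255).clip(0, 255).astype(np.uint8)
print(display.shape, display.dtype)              # (1, 64, 64) uint8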
class Tester(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # init D
        model_D = FCDiscriminator(num_classes=19)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                                       lr=1e-4,
                                       betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                'dataloders', 'datasets', args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        self.model, self.optimizer = model, optimizer
        self.model_D, self.optimizer_D = model_D, optimizer_D

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model_D = torch.nn.DataParallel(self.model_D,
                                                 device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            patch_replication_callback(self.model_D)
            self.model = self.model.cuda()
            self.model_D = self.model_D.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def imgsaver(self, img, imgname):
        im1 = np.uint8(img.transpose(1, 2, 0)).squeeze()
        #filename_list = sorted(os.listdir(self.args.test_img_root))

        valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
            32, 33
        ]
        class_map = dict(zip(range(19), valid_classes))
        im1_np = np.uint8(np.zeros([512, 512]))
        for _validc in range(19):
            im1_np[im1 == _validc] = class_map[_validc]
        saveim1 = Image.fromarray(im1_np, mode='L')
        saveim1 = saveim1.resize((1280, 640), Image.NEAREST)
        saveim1.save('result/' + imgname)

        palette = [[128, 64, 128], [244, 35, 232], [70, 70, 70],
                   [102, 102, 156], [190, 153, 153], [153, 153, 153],
                   [250, 170, 30], [220, 220, 0], [107, 142, 35],
                   [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                   [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                   [0, 0, 230], [119, 11, 32]]
        #[0,0,0]]
        class_color_map = dict(zip(range(19), palette))
        im2_np = np.uint8(np.zeros([512, 512, 3]))
        for _validc in range(19):
            im2_np[im1 == _validc] = class_color_map[_validc]
        saveim2 = Image.fromarray(im2_np)
        saveim2 = saveim2.resize((1280, 640), Image.NEAREST)
        saveim2.save('result/' + imgname[:-4] + '_color.png')
        # print('saving: '+filename_list[idx])

    def test(self, epoch):
        self.model.eval()
        tbar = tqdm(self.test_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                output = self.model(image)
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.imgsaver(pred, sample['name'][0])

        # Fast test during the training
        print('Test:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.test_batch_size + image.data.shape[0]))
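The remapping inside imgsaver() is worth seeing in isolation: the network predicts 19 contiguous train IDs, while the saved PNGs use the original Cityscapes label IDs from valid_classes. A toy run (the 2x2 pred array is made up for illustration):

import numpy as np

valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24,
                 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(range(19), valid_classes))

pred = np.array([[0, 1], [18, 5]], dtype=np.uint8)  # toy prediction in train IDs
labels = np.zeros_like(pred)
for train_id, label_id in class_map.items():
    labels[pred == train_id] = label_id
print(labels)  # [[ 7  8]
               #  [33 17]]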
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
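Every validation() above reads its score from Evaluator.Mean_Intersection_over_Union(), whose internals are not shown in these examples. A minimal sketch of the conventional computation, assuming the Evaluator accumulates a per-class confusion matrix (rows = ground truth, columns = prediction):

import numpy as np

def mean_iou(confusion):
    # per-class IoU = diagonal / (row sum + column sum - diagonal)
    intersection = np.diag(confusion).astype(np.float64)
    union = confusion.sum(axis=1) + confusion.sum(axis=0) - intersection
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = intersection / union
    return np.nanmean(iou)  # NaN entries (class absent everywhere) are skipped

confusion = np.array([[50, 2], [3, 45]])
print(mean_iou(confusion))  # ~0.905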
Example #31
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver (only rank 0 writes when distributed; training() below
        # also logs through self.writer on rank 0, so the summary must exist there too)
        if not args.distributed or args.local_rank == 0:
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            # Define Tensorboard Summary
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # PATH = args.path
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.val_loader, self.nclass = make_data_loader(args, **kwargs)
        # self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride, cuda=args.cuda,
                     extension=args.ext)

        # Define Optimizer
        # optimizer = torch.optim.SGD(model.parameters(),args.lr, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # model, optimizer = amp.initialize(model,optimizer,opt_level="O1")

        # Define Criterion
        weight = None
        # criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        # self.criterion = SegmentationCELosses(weight=weight, cuda=args.cuda)
        # self.criterion = FocalLoss(gamma=0, alpha=[0.2, 0.98], img_size=512*512)
        self.criterion1 = FocalLoss(gamma=5, alpha=[0.2, 0.98], img_size=512 * 512)
        self.criterion2 = disc_loss(delta_v=0.5, delta_d=3.0, param_var=1.0, param_dist=1.0,
                                    param_reg=0.001, EMBEDDING_FEATS_DIMS=21, image_shape=[512, 512])

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.val_loader), local_rank=args.local_rank)

        # Using cuda
        # if args.cuda:
        self.model = self.model.cuda()
        # if args.distributed:
        #     self.model = DistributedDataParallel(self.model)
        #     self.model = torch.nn.DataParallel(self.model)
        #     patch_replication_callback(self.model)

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            filename = 'checkpoint.pth.tar'
            args.resume = os.path.join(args.ckpt_dir, filename)
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # The checkpoint was saved from a DataParallel model, so strip the
            # 'module.' prefix from every key before loading into the bare model.
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            checkpoint['state_dict'] = new_state_dict
            self.model.load_state_dict(checkpoint['state_dict'])
            # if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)  # NOTE: __init__ above only builds val_loader
        num_img_tr = len(self.train_loader)
        max_instances = 1
        for i, sample in enumerate(tbar):
            # image, target = sample['image'], sample['label']
            image, target, ins_target = sample['image'], sample['bin_label'], sample['label']
            # _target = target.cpu().numpy()
            # if np.max(_target) > max_instances:
            #     max_instances = np.max(_target)
            #     print(max_instances)

            if self.args.cuda:
                # ins_target feeds criterion2 against a CUDA output, so move it too
                image, target, ins_target = image.cuda(), target.cuda(), ins_target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            # if i % 10==0:
            #     misc.imsave('/mfc/user/1623600/.temp6/train_{:s}_epoch:{}_i:{}.png'.format(str(self.args.distributed),epoch,i),np.transpose(image[0].cpu().numpy(),(1,2,0)))
            #     os.chmod('/mfc/user/1623600/.temp6/train_{:s}_epoch:{}_i:{}.png'.format(str(self.args.distributed),epoch,i),0o777)

            # self.criterion = DataParallelCriterion(self.criterion)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, ins_target)

            loss = loss1 + 10 * loss2
            # loss = loss1
            output = output[1]  # keep only the binary-segmentation head for logging
            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            #     scaled_loss.backward()

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if not self.args.distributed or self.args.local_rank == 0:
                self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                if not self.args.distributed or self.args.local_rank == 0:
                    self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        if not self.args.distributed or self.args.local_rank == 0:
            self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        if self.args.local_rank == 0:
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print('Loss: %.3f' % train_loss)

        if (not self.args.distributed or self.args.local_rank == 0) and self.args.no_val:
            # save checkpoint every epoch; the model is not wrapped in
            # DataParallel above, so fall back to the bare state_dict
            is_best = False
            state_dict = (self.model.module.state_dict()
                          if hasattr(self.model, 'module') else self.model.state_dict())
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        destination_path = os.path.join(self.args.path, 'seg_lane')
        if not os.path.isdir(destination_path):
            os.mkdir(destination_path, 0o777)

        postprocessor = LanePostprocess.LaneNetPostProcessor()

        aa_sequence = mvpuai.MSequence()

        for i, sample in enumerate(tbar):
            # image, target = sample['image'], sample['label']
            image, lbl_path, resized_img = sample['image'], sample['lbl_path'], sample['resized_img']

            # split the 512x2048 strip into four 512x512 crops plus the resized full frame
            img = [image[0][..., _ind * 512:(_ind + 1) * 512] for _ind in range(4)]
            img = np.stack(img + [resized_img[0]], axis=0)
            img = torch.from_numpy(img)
            if self.args.cuda:
                img = img.cuda()
                image = image.cuda()
            with torch.no_grad():
                output = self.model(img)

            pred = output[1]

            # stitch the four crop predictions and the upsampled full-frame
            # prediction back into a 1024x2048 binary segmentation map
            upsample = torch.nn.Upsample(size=[512, 2048])
            full_frame_pred = upsample(pred[4, ...].view([1, 2, 512, 512]))

            upsampled_final = torch.zeros(2, 1024, 2048, device=pred.device)
            upsampled_final[:, 512:, :512] = pred[0, ...]
            upsampled_final[:, 512:, 512:1024] = pred[1, ...]
            upsampled_final[:, 512:, 1024:1024 + 512] = pred[2, ...]
            upsampled_final[:, 512:, 1024 + 512:2048] = pred[3, ...]
            upsampled_final = upsampled_final.view([1, 2, 1024, 2048])
            upsampled_final[..., 512:, :] = full_frame_pred
            upsampled_final = torch.argmax(upsampled_final, dim=1)
            pred = upsampled_final.data.cpu().numpy()

            instance_seg = output[0]

            # same stitching for the 21-channel instance embedding
            upsample_instance = torch.nn.Upsample(size=[512, 2048])
            full_frame_instance = upsample_instance(instance_seg[4, ...].view([1, 21, 512, 512]))

            upsampled_final_instance = torch.zeros(21, 1024, 2048, device=instance_seg.device)
            upsampled_final_instance[:, 512:, :512] = instance_seg[0, ...]
            upsampled_final_instance[:, 512:, 512:1024] = instance_seg[1, ...]
            upsampled_final_instance[:, 512:, 1024:1024 + 512] = instance_seg[2, ...]
            upsampled_final_instance[:, 512:, 1024 + 512:2048] = instance_seg[3, ...]
            upsampled_final_instance = upsampled_final_instance.view([1, 21, 1024, 2048])
            upsampled_final_instance[..., 512:, :] = full_frame_instance
            instance_seg = upsampled_final_instance.data.cpu().numpy()
            # instance_seg = np.argmax(upsampled_final_instance, axis=1)

            resized_img = np.squeeze(resized_img)
            pred = np.squeeze(pred)
            instance_seg = np.squeeze(instance_seg)
            instance_seg = np.transpose(instance_seg, (1, 2, 0))

            postprocess_result = postprocessor.postprocess(
                binary_seg_result=pred,
                instance_seg_result=instance_seg,
                source_image=image
            )


            image = self.de_normalize(np.transpose(image[0].cpu().numpy(), (1, 2, 0)))

            # misc.imsave(destination_path + '/' + lbl_path[0],
            #             np.transpose(image.cpu().numpy(), (1, 2, 0)) + 3 * np.asarray(
            #                 np.stack((pred, pred, pred), axis=-1), dtype=np.uint8))

            # paste the denormalized frame into the bottom half of a full-size canvas
            show_source_image = np.zeros((1024, 2048, 3))
            show_source_image[512:, ...] = image
            image = show_source_image

            predicted_lanes = postprocess_result['lane_pts']
            # predicted_lanes = predicted_lanes[...,0]
            # bsp_lanes = []
            predicted_lanes = [np.asarray(pred_lane) for pred_lane in predicted_lanes]

            tensor_curvepts, tensor_cpts = inference(bsplineMat=predicted_lanes, i=i)

            # rasterize the fitted curves: each lane gets its own integer label
            tmp_mask = np.zeros(shape=(image.shape[0], image.shape[1]), dtype=np.uint8)
            src_lane_pts = np.asarray(tensor_curvepts)
            for lane_index, coords in enumerate(src_lane_pts):
                tmp_mask[tuple((np.int_(coords[:, 1]), np.int_(coords[:, 0])))] = lane_index + 1

            bsppts_mask = np.stack((tmp_mask, tmp_mask, tmp_mask), axis=-1)


            # misc.imsave(destination_path + '/mask_' + lbl_path[0],
            #             postprocess_result['mobis_mask_image'])
            # misc.imsave(destination_path + '/' + lbl_path[0],
            #             50*postprocess_result['mask_image']+50*postprocess_result['lanepts_mask'])
            misc.imsave(destination_path + '/' + lbl_path[0],
                        postprocess_result['mobis_mask_image'])
            try:
                os.chmod(destination_path + '/' + lbl_path[0], 0o777)
            except OSError:
                pass
            aa_sequence.add_frame(MFrame(i))
            for idx in range(tensor_cpts.shape[1]):
                _Obj = mvpuai.get_object_by_name(MString.Frame.Object.Type.LANE)
                _Obj.subclass_id = 1
                _Obj.instance_id = idx
                _list = []
                for ind in range(10):
                    _list.append(Point(int(tensor_cpts[0, idx, ind]), int(tensor_cpts[1, idx, ind])))

                _ctrl_pts = [[point.x, point.y] for point in _list]
                # b_spline = BSpline.Curve()
                # b_spline.degree = 4
                # b_spline.set_ctrlpts(_ctrl_pts)
                #
                # b_spline.knotvector = utilities.generate_knot_vector(b_spline.degree, len(_ctrl_pts))
                # b_spline.delta = 0.001
                # b_spline.evaluate()

                _cpts = []
                for _cpt in _ctrl_pts:
                    _cpts.append(_cpt[0])
                    _cpts.append(_cpt[1])

                _Obj.b_spline = _cpts
                aa_sequence.frame_list[-1].add_object(_Obj)

        self.write_json(aa_sequence)

    def de_normalize(self, img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # undo the ImageNet normalization, then scale back to [0, 255]
        img *= std
        img += mean
        img *= 255.0
        return img

    def write_json(self, aa_sequence):
        output_file_path = os.path.join(self.args.path, 'json', 'annotation_bs.json')
        mvpuai.write_json(output_file_path, aa_sequence)
        try:
            os.chmod(output_file_path, 0o777)
        except OSError:
            pass
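de_normalize() above inverts the standard ImageNet normalization before the frame is pasted into the display canvas. A round-trip check, under the assumption that the dataloader applied (img - mean) / std with those same constants:

import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

img = np.random.rand(4, 4, 3)          # toy HWC image in [0, 1]
normalized = (img - mean) / std        # what the loader presumably applies
restored = normalized * std + mean     # what de_normalize computes (before * 255)
assert np.allclose(restored * 255.0, img * 255.0)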