Пример #1
0
class Trainer(object):
    """Trainer skeleton that only prepares experiment bookkeeping.

    Construction builds a ``Saver`` from ``args``, calls its
    ``save_experiment_config()``, and opens a summary writer located at the
    saver's ``experiment_dir``.
    """

    def __init__(self, args):
        """Store ``args`` and create the saver, summary helper and writer.

        Args:
            args: Run configuration; handed unchanged to ``Saver``.
        """
        self.args = args

        # Saver comes first: the summary helper needs its experiment_dir.
        saver = self.saver = Saver(args)
        saver.save_experiment_config()

        # Summary helper and the writer it creates.
        summary = self.summary = TensorboardSummary(saver.experiment_dir)
        self.writer = summary.create_summary()
Пример #2
0
class buildModel(object):
    """Wire up a DeepLab model together with its loaders and training helpers.

    All work happens in ``__init__``; afterwards the instance exposes
    ``model``, ``optimizer``, ``criterion``, ``evaluator``, ``scheduler``,
    the three data loaders, ``nclass`` and ``best_pred``.
    """

    def __init__(self, para):
        """Build every training component from the configuration ``para``.

        Args:
            para: Configuration object. Attributes read here: ``backbone``,
                ``out_stride``, ``sync_bn``, ``freeze_bn``, ``lr``,
                ``momentum``, ``weight_decay``, ``nesterov``, ``loss_type``,
                ``lr_scheduler`` and ``epochs``.
        """
        self.args = para

        # Experiment bookkeeping; the summary writer lives in the saver's
        # experiment directory, so the saver has to exist first.
        self.saver = Saver(para)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Data loaders plus the number of classes.
        loaders = dataloader(para)
        (self.train_loader, self.val_loader, self.test_loader,
         self.nclass) = loaders

        # Network.
        net = DeepLab(
            num_classes=self.nclass,
            backbone=para.backbone,
            output_stride=para.out_stride,
            sync_bn=para.sync_bn,
            freeze_bn=para.freeze_bn,
        )

        # Two parameter groups: the second one trains at ten times the
        # learning rate of the first.
        param_groups = [
            dict(params=net.get_1x_lr_params(), lr=para.lr),
            dict(params=net.get_10x_lr_params(), lr=para.lr * 10),
        ]
        sgd = torch.optim.SGD(
            param_groups,
            momentum=para.momentum,
            weight_decay=para.weight_decay,
            nesterov=para.nesterov,
        )

        # Loss; the factory is always asked for the CUDA variant here.
        loss_factory = SegmentationLosses(weight=None, cuda=True)
        self.criterion = loss_factory.build_loss(mode=para.loss_type)
        self.optimizer = sgd

        # Metrics and learning-rate schedule (one scheduler step per batch
        # of the training loader).
        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(
            para.lr_scheduler, para.lr, para.epochs, len(self.train_loader))

        # Unconditionally wrap for multi-GPU use and move to the GPU.
        parallel_net = torch.nn.DataParallel(net)
        patch_replication_callback(parallel_net)
        self.model = parallel_net.cuda()

        # Best validation score seen so far.
        self.best_pred = 0.0
Пример #3
0
class Trainer(object):
    def __init__(self, args):
        """Set up data, model, optimizer and losses, and optionally resume.

        Args:
            args: Run configuration. Attributes read here: ``backbone``,
                ``out_stride``, ``sync_bn``, ``freeze_bn``, ``lr``,
                ``momentum``, ``weight_decay``, ``nesterov``, ``cuda``,
                ``lr_scheduler``, ``epochs``, ``resume`` and ``ft``.
                ``args.start_epoch`` is overwritten when resuming and reset
                to 0 when fine-tuning.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args

        # The row counts of the CSV lists determine how many iterations
        # training() and validation() run per epoch.
        self.train_dir = './data_list/train_lite.csv'
        self.train_list = pd.read_csv(self.train_dir)
        self.val_dir = './data_list/val_lite.csv'
        self.val_list = pd.read_csv(self.val_dir)
        self.train_length = len(self.train_list)
        self.val_length = len(self.val_list)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Option 2: generator-style loaders consumed with next().
        self.train_gen, self.val_gen, self.test_gen, self.nclass = make_data_loader2(args)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Second parameter group trains at ten times the base learning rate.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion: the training loss is the sum of these two.
        self.criterion1 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='ce')
        self.criterion2 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='dice')

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        # NOTE(review): the last argument is the number of CSV rows, while
        # training() runs train_length / batch_size iterations per epoch --
        # confirm which of the two LR_Scheduler expects.
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, self.train_length)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            # Map to CPU when CUDA is disabled so that a checkpoint written
            # on a GPU machine can still be loaded.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            # The model is never wrapped in DataParallel in this class, so
            # the weights are loaded into it directly in every case.
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch and report loss and segmentation metrics.

        Pulls ``train_length // batch_size`` batches from ``self.train_gen``,
        optimises the sum of ``criterion1`` and ``criterion2``, feeds the
        argmax predictions to the evaluator and writes scalars to the
        summary writer. When ``args.no_val`` is set a checkpoint is saved at
        the end of the epoch.

        Args:
            epoch: Index of the current epoch; used by the scheduler and to
                compute summary steps.
        """
        train_loss = 0.0
        prev_time = time.time()
        self.model.train()
        self.evaluator.reset()

        # Whole batches per epoch. Kept as an int so that it can be used
        # directly as a summary step and as a modulus.
        num_img_tr = self.train_length // self.args.batch_size
        # Visualise about ten times per epoch; clamp to 1 so that epochs
        # shorter than ten iterations do not divide by zero.
        vis_interval = max(1, num_img_tr // 10)
        num_images = 0
        image = None

        for iteration in range(num_img_tr):
            samples = next(self.train_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, iteration, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            train_loss += loss_value
            num_images += image.shape[0]
            global_step = iteration + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss_value, global_step)

            # Print a log line every 4 iterations (default log_iters = 4).
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: train loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss_value, loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            if iteration % vis_interval == 0:
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction is the argmax over axis 1 of the output.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        if image is not None:
            print("input image shape/iter:", image.shape)

        # Metrics on the training data seen in this epoch.
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc_tr:{}, Acc_class_tr:{}, IoU_tr:{}, mIoU_tr:{}, fwIoU_tr: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # Without validation there is no "best" model, so save a plain
            # checkpoint every epoch. The model is not wrapped in
            # DataParallel in this class, hence no ``.module`` indirection
            # (matching validation() and the resume logic).
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)





    def validation(self, epoch):
        """Evaluate on the validation generator and checkpoint on improvement.

        Runs ``val_length // batch_size`` batches from ``self.val_gen``
        through the model without gradients, accumulates the summed
        ``criterion1`` + ``criterion2`` loss, logs metrics, and saves a
        checkpoint when the mean IoU exceeds ``self.best_pred``.

        Args:
            epoch: Index of the current epoch; used for summary steps and
                stored (plus one) in the checkpoint.
        """
        self.model.eval()
        self.evaluator.reset()
        val_loss = 0.0
        prev_time = time.time()
        # Whole batches per epoch, as an int so it is a valid summary step.
        num_img_val = self.val_length // self.args.batch_size
        num_images = 0
        image = None
        print("Validation:","epoch ", epoch)
        print(num_img_val)
        for iteration in range(num_img_val):
            samples = next(self.val_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            loss_value = loss.item()
            # Accumulate exactly once per iteration.
            val_loss += loss_value
            num_images += image.shape[0]
            self.writer.add_scalar('val/total_loss_iter', loss_value, iteration + num_img_val * epoch)

            # Print a log line every 4 iterations (default log_iters = 4).
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: validation loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s" \
                      % (iteration, loss_value, loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction is the argmax over axis 1 of the output.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        if image is not None:
            print(image.shape)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', val_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc_val:{}, Acc_class_val:{}, IoU:val:{}, mIoU_val:{}, fwIoU_val: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))
        print('Loss: %.3f' % val_loss)

        # Only keep a checkpoint when the mean IoU improves.
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #4
0
class trainNew(object):
    def __init__(self, args):
        """Build loaders, the searched network, optimizer and loss; resume.

        Args:
            args: Run configuration. Attributes read here: ``workers``,
                ``saved_arch_path``, ``lr``, ``momentum``, ``weight_decay``,
                ``nesterov``, ``use_balanced_weights``, ``dataset``,
                ``cuda``, ``loss_type``, ``lr_scheduler``, ``epochs``,
                ``load_parallel``, ``resume``, ``clean_module`` and ``ft``.
                ``args.start_epoch`` is overwritten when resuming and reset
                to 0 when fine-tuning.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Architecture description saved as .npy files under saved_arch_path.
        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        # NOTE(review): the decoder is neither added to the optimizer's
        # parameters nor moved to CUDA in this constructor -- confirm that
        # this is intended.
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: look into separate 1x / 10x learning-rate parameter groups
        # as done for DeepLab.
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  # TODO: use min_lr ?

        # TODO: Figure out if len(self.train_loader) should be divided by
        # two; the same question applies to the other module.
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            if torch.cuda.device_count() > 1 or args.load_parallel:
                self.model = torch.nn.DataParallel(self.model)
                patch_replication_callback(self.model)
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Map to CPU when CUDA is disabled so that a checkpoint written
            # on a GPU machine can still be loaded.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']

            # Always load into the underlying network: decide by whether the
            # model was actually wrapped, not by the number of visible GPUs.
            bare_model = self.model
            if isinstance(bare_model, torch.nn.DataParallel):
                bare_model = bare_model.module

            state_dict = checkpoint['state_dict']
            if args.clean_module:
                # The weights were saved from a DataParallel wrapper; strip
                # the 'module.' prefix (only where present) before loading.
                prefix = 'module.'
                state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state_dict.items())
            bare_model.load_state_dict(state_dict)

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Each batch goes through ``self.model`` and then ``self.decoder``;
        the criterion loss is back-propagated and the optimizer stepped.
        Per-iteration and per-epoch losses are written to the summary
        writer. When ``args.no_val`` is set a checkpoint is saved at the end
        of the epoch.

        Args:
            epoch: Index of the current epoch; used by the scheduler and to
                compute summary steps.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise about ten times per epoch; clamp to 1 so that loaders
        # with fewer than ten batches do not divide by zero.
        vis_interval = max(1, num_img_tr // 10)
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            encoder_output, low_level_feature = self.model(image)
            output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            global_step = i + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   global_step)

            if i % vis_interval == 0:
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # The model is only wrapped in DataParallel under certain
            # settings, so unwrap it only when it really is wrapped.
            net = self.model
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and checkpoint on improvement.

        Runs model and decoder without gradients, accumulates the criterion
        loss, feeds argmax predictions to the evaluator, logs the metrics,
        and saves a checkpoint when the evaluator's mean IoU exceeds
        ``self.best_pred``.

        Args:
            epoch: Index of the current epoch; used for summary steps and
                stored (plus one) in the checkpoint.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                encoder_output, low_level_feature = self.model(image)
                output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction is the argmax over axis 1 of the output.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # Only keep a checkpoint when the mean IoU improves.
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The model is only wrapped in DataParallel under certain
            # settings, so unwrap it only when it really is wrapped.
            net = self.model
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Пример #5
0
class Trainer(object):
    def __init__(self, args):
        """Build loaders, the search network and both optimizers; resume.

        Two optimizers are created: SGD over ``model.weight_parameters()``
        and Adam over ``model.arch_parameters()``. Mixed precision via
        ``amp`` is enabled when ``APEX_AVAILABLE`` and ``args.use_amp`` are
        both truthy and CUDA is in use.

        Args:
            args: Run configuration. Attributes read here: ``use_amp``,
                ``opt_level``, ``workers``, ``use_balanced_weights``,
                ``dataset``, ``cuda``, ``loss_type``, ``filter_multiplier``,
                ``block_multiplier``, ``step``, ``lr``, ``momentum``,
                ``weight_decay``, ``arch_lr``, ``arch_weight_decay``,
                ``lr_scheduler``, ``epochs``, ``min_lr``, ``gpu_ids``,
                ``resume``, ``clean_module`` and ``ft``.
                ``args.start_epoch`` is overwritten when resuming and reset
                to 0 when fine-tuning.

        Raises:
            NotImplementedError: If balanced class weights are requested but
                the weights file does not exist.
            RuntimeError: If ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # Computing the weights on the fly is not supported here: it
                # is undecided which of the two train loaders should be used.
                raise NotImplementedError
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # TODO: Figure out if len(self.train_loader) should be divided by
        # two; the same question applies to the other module.
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if self.opt_level in ('O2', 'O3') else None

            # Workaround for older PyTorch releases with opt_level 'O1'.
            # Compare numerically: as plain strings '1.10' sorts before
            # '1.3', which would wrongly enable the workaround.
            if self.opt_level == 'O1':
                torch_version = tuple(
                    int(part) for part in torch.__version__.split('.')[:2])
                if torch_version < (1, 3):
                    for module in self.model.modules():
                        if isinstance(module,
                                      torch.nn.modules.batchnorm._BatchNorm):
                            # Hack to fix BN fprop without affine
                            # transformation: give such layers a frozen
                            # all-ones weight and all-zeros bias.
                            if module.weight is None:
                                module.weight = torch.nn.Parameter(
                                    torch.ones(
                                        module.running_var.shape,
                                        dtype=module.running_var.dtype,
                                        device=module.running_var.device),
                                    requires_grad=False)
                            if module.bias is None:
                                module.bias = torch.nn.Parameter(
                                    torch.zeros(
                                        module.running_var.shape,
                                        dtype=module.running_var.dtype,
                                        device=module.running_var.device),
                                    requires_grad=False)

            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Map to CPU when CUDA is disabled so that a checkpoint written
            # on a GPU machine can still be loaded.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']

            # Always copy into the underlying network: decide by whether the
            # model was actually wrapped, not by the number of visible GPUs.
            bare_model = self.model
            if isinstance(bare_model, torch.nn.DataParallel):
                bare_model = bare_model.module

            state_dict = checkpoint['state_dict']
            if args.clean_module:
                # The weights were saved from a DataParallel wrapper; strip
                # the 'module.' prefix (only where present) before copying.
                prefix = 'module.'
                state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state_dict.items())
            copy_state_dict(bare_model.state_dict(), state_dict)

            if not args.ft:
                # NOTE(review): this copies into the dict returned by
                # optimizer.state_dict() rather than calling
                # optimizer.load_state_dict() -- confirm that
                # copy_state_dict really updates the optimizer.
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one search epoch: weight updates plus architecture updates.

        Every batch of ``train_loaderA`` drives one step of the weight
        optimizer. From ``args.alpha_epoch`` onwards, each iteration also
        draws one batch from ``train_loaderB`` and steps the architecture
        optimizer on it. When ``args.no_val`` is set a checkpoint is saved
        at the end of the epoch.

        Args:
            epoch: Index of the current epoch; used by the scheduler, for
                the ``alpha_epoch`` comparison and for summary steps.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        # Visualise about ten times per epoch; clamp to 1 so that loaders
        # with fewer than ten batches do not divide by zero.
        vis_interval = max(1, num_img_tr // 10)
        # One iterator over loader B for the whole epoch; building a fresh
        # iterator for every step restarts the loader each time.
        search_iter = None
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                if search_iter is None:
                    search_iter = iter(self.train_loaderB)
                try:
                    search = next(search_iter)
                except StopIteration:
                    # Loader B is exhausted before loader A: start over.
                    search_iter = iter(self.train_loaderB)
                    search = next(search_iter)
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search = image_search.cuda()
                    target_search = target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                if self.use_amp:
                    with amp.scale_loss(
                            arch_loss,
                            self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()

            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # Unwrap only when the model really is a DataParallel wrapper;
            # wrapping depends on args.gpu_ids, not on the GPU count.
            net = self.model
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and checkpoint on improvement.

        Runs the model without gradients, accumulates the criterion loss,
        feeds argmax predictions to the evaluator, logs the metrics, and
        saves a checkpoint when the evaluator's mean IoU exceeds
        ``self.best_pred``.

        Args:
            epoch: Index of the current epoch; used for summary steps and
                stored (plus one) in the checkpoint.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction is the argmax over axis 1 of the output.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # Only keep a checkpoint when the mean IoU improves.
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # Unwrap only when the model really is a DataParallel wrapper;
            # wrapping depends on args.gpu_ids, not on the GPU count.
            net = self.model
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
class Trainer(object):
    def __init__(self, args):
        """Assemble everything needed to train and evaluate the FPN model.

        Creates the checkpoint saver and TensorBoard writer, builds the data
        loaders, network, optimizer, loss and evaluator, moves the model to
        the requested devices and, when ``args.resume`` is set, restores a
        checkpoint from a previous experiment directory.

        Args:
            args: Parsed options. Reads ``dataset``, ``batch_size``,
                ``num_workers``, ``net``, ``optimizer``, ``lr``,
                ``weight_decay``, ``cuda``, ``mGPUs``, ``gpu_ids``,
                ``resume``, ``save_dir`` and ``checkname``; ``start_epoch``
                is written when resuming.

        Raises:
            NotImplementedError: If ``args.dataset``, ``args.net`` or
                ``args.optimizer`` is not one of the supported values.
            RuntimeError: If resuming is requested but the checkpoint file
                does not exist.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            # Join the path from components so it is valid on every OS
            # (the separators used to be hard-coded Windows backslashes).
            data_root = os.path.join(os.getcwd(), 'data', 'CamVid')
            train_file = os.path.join(data_root, 'train.csv')
            val_file = os.path.join(data_root, 'val.csv')
            print('=>loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            (self.train_loader, self.val_loader,
             self.test_loader, self.num_class) = make_data_loader(args, **kwargs)
        else:
            # Fail here with a clear message instead of a later AttributeError
            # on the missing self.num_class.
            raise NotImplementedError(
                "unsupported dataset: {!r}".format(args.dataset))

        # Define network
        if args.net == 'resnet101':
            blocks = [2, 4, 23, 3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)
        else:
            raise NotImplementedError(
                "unsupported backbone: {!r}".format(args.net))

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # Adam has no ``momentum`` argument (passing one raises TypeError).
            # Use the scaled rate so the value logged from self.lr is the one
            # the optimizer really applies.
            optimizer = torch.optim.Adam(fpn.parameters(), lr=self.lr,
                                         weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr,
                                        momentum=0,
                                        weight_decay=args.weight_decay)
        else:
            raise NotImplementedError(
                "unsupported optimizer: {!r}".format(args.optimizer))

        # Define Criterion
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        elif args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(
                weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset, args.checkname)
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')))
            # NOTE(review): the "- 1" presumably skips the experiment directory
            # that Saver(args) created a few lines above -- confirm against
            # Saver. The sort is lexical, so 'experiment_10' orders before
            # 'experiment_9'.
            run_id = int(runs[-1].split('_')[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir,
                                          'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir, 'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(load_name))
            # Without CUDA, remap tensors saved on a GPU onto the CPU.
            checkpoint = torch.load(
                load_name, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']

            # Checkpoints may have been written from either the bare network
            # or its DataParallel wrapper (whose keys carry a 'module.'
            # prefix). Normalize to un-prefixed keys and load into the bare
            # network so both kinds can be restored.
            state_dict = checkpoint['state_dict']
            prefix = 'module.'
            if state_dict and all(key.startswith(prefix) for key in state_dict):
                state_dict = {key[len(prefix):]: value
                              for key, value in state_dict.items()}
            if isinstance(self.model, torch.nn.DataParallel):
                bare_model = self.model.module
            else:
                bare_model = self.model
            bare_model.load_state_dict(state_dict)

            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(
                load_name, checkpoint['epoch']))

        # Epochs at which training() decays the learning rate, and the index
        # of the next milestone still to be applied.
        self.lr_stage = [68, 93]
        self.lr_staget_ind = 0


    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        # tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        if self.lr_staget_ind > 1 and epoch % (self.lr_stage[self.lr_staget_ind]) == 0:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_staget_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            inputs = Variable(image)
            labels = Variable(target)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels.long())
            loss_val = loss.item()
            loss.backward(torch.ones_like(loss))
            # loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # tbar.set_description('\rTrain loss:%.3f' % (train_loss / (iteration + 1)))

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, len(self.train_loader), loss.data, self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)

            #if iteration % (num_img_tr // 10) == 0:
            #    global_step = iteration + num_img_tr * epoch
            #    self.summary.visualize_image(self.witer, self.args.dataset, image, target, outputs, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)


    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on a new best mIoU.

        Accumulates the loss and the evaluator statistics over the whole
        validation loader, logs pixel accuracy, class accuracy, mIoU and
        frequency-weighted IoU, and saves a checkpoint flagged as best when
        the mIoU exceeds ``self.best_pred``.

        Args:
            epoch: Index of the epoch being evaluated; used as logging step.

        Raises:
            NotImplementedError: If ``args.dataset`` is not a supported name.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_seen = 0
        # Iterate the progress bar itself (not the raw loader) so that it
        # actually advances.
        for step, batch in enumerate(tbar):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                # Cast like training() does so both paths feed the criterion
                # the same target dtype.
                loss = self.criterion(output, target.long())
            test_loss += loss.item()
            num_seen += image.shape[0]
            tbar.set_description('Test loss: %.3f ' % (test_loss / (step + 1)))
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_seen))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The state dict is taken from self.model as-is; when the model is
            # wrapped in DataParallel its keys carry a 'module.' prefix.
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #7
0
class Trainer(object):
    def __init__(self,
                 batch_size=32,
                 optimizer_name="Adam",
                 lr=1e-3,
                 weight_decay=1e-5,
                 epochs=200,
                 model_name="model01",
                 gpu_ids=None,
                 resume=None,
                 tqdm=None,
                 is_develop=False):
        """Set up the model, data, optimizer, loss and logging for training.

        Args:
            batch_size: (int) Batch size passed to ``make_data_loader``.
            optimizer_name: (str) Optimizer name forwarded to ``Optimizer``.
            lr: (float) Learning rate of the optimizer.
            weight_decay: (float) Weight decay of the optimizer.
            epochs: (int) Number of training epochs.
            model_name: (str) Name of the model, forwarded to ``Saver``.
            gpu_ids: (list) GPU ids, e.g. ``[0, 1]``. The CPU is used when
                this is None or when CUDA is not available.
            resume: (dict) ``{"checkpoint_path": PATH, "fine_tuning": bool}``.
                Training starts from scratch when this is None.
            tqdm: Progress-bar callable (e.g. ``tqdm.tqdm``). No progress bar
                is shown when this is None.
            is_develop: Forwarded to ``make_data_loader``.

        Raises:
            RuntimeError: If ``resume`` points to a checkpoint file that does
                not exist.
        """
        # Set params
        self.batch_size = batch_size
        self.epochs = epochs
        self.start_epoch = 0
        # is_available must be *called*; the bare function object is always
        # truthy, which enabled CUDA even on machines without a GPU.
        self.use_cuda = (gpu_ids is not None) and torch.cuda.is_available()
        self.tqdm = tqdm
        self.use_tqdm = tqdm is not None

        # ------------------------- #
        # Project utilities (usually no need to change):
        #   Saver: saves model weights. / <utils.saver.Saver()>
        #   TensorboardSummary: writes the tensorboard file.
        #       / <utils.summaries.TensorboardSummary()>

        ## ***Define Saver***
        self.saver = Saver(model_name, lr, epochs)
        self.saver.save_experiment_config()

        ## ***Define Tensorboard Summary***
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # ------------------------- #
        # Training components (these are the parts to adapt per project):
        #   make_data_loader: creates the data loaders. / <dataloader.__init__>
        #   Modeling: the model. / <modeling.modeling.Modeling()>
        #   Evaluator: metric computation. / <utils.metrics.Evaluator()>
        #   Optimizer: the optimizer. / <utils.optimizer.Optimizer()>
        #   Loss: the loss function. / <utils.loss.Loss()>

        ## ***Define Dataloader***
        (self.train_loader, self.val_loader,
         self.test_loader, self.num_classes) = make_data_loader(
             batch_size, is_develop=is_develop)

        ## ***Define Your Model***
        self.model = Modeling(self.num_classes)

        ## ***Define Evaluator***
        self.evaluator = Evaluator(self.num_classes)

        ## ***Define Optimizer***
        self.optimizer = Optimizer(self.model.parameters(),
                                   optimizer_name=optimizer_name,
                                   lr=lr,
                                   weight_decay=weight_decay)

        ## ***Define Loss***
        # Fixed per-class weights for the cross-entropy loss (two entries, so
        # presumably a two-class task -- confirm). They are moved to the GPU
        # only when CUDA is in use; the unconditional .cuda() crashed on
        # CPU-only machines.
        class_weights = torch.tensor([1.0, 1594.0])
        if self.use_cuda:
            class_weights = class_weights.cuda()
        self.criterion = SegmentationLosses(
            weight=class_weights).build_loss('ce')
        # self.criterion = SegmentationLosses().build_loss('focal')
        #  self.criterion = BCEDiceLoss()

        # ------------------------- #
        # Some settings (no need to touch the code below).

        ## ***Using cuda***
        if self.use_cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=gpu_ids).cuda()

        ## ***Resuming checkpoint***
        self.best_pred = 0.0
        if resume is not None:
            checkpoint_path = resume["checkpoint_path"]
            if not os.path.isfile(checkpoint_path):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    checkpoint_path))
            # Without CUDA, remap tensors saved on a GPU onto the CPU.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=None if self.use_cuda else "cpu")
            self.start_epoch = checkpoint['epoch']

            # Checkpoints written from a DataParallel-wrapped model carry a
            # 'module.' prefix on every key; strip it so the weights load
            # into the bare module below.
            state_dict = checkpoint['state_dict']
            prefix = 'module.'
            if state_dict and all(key.startswith(prefix) for key in state_dict):
                state_dict = {key[len(prefix):]: value
                              for key, value in state_dict.items()}
            if self.use_cuda:
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)

            # NOTE(review): the optimizer state is restored (and the epoch
            # counter reset) only in fine-tuning mode; a plain resume keeps
            # the epoch but starts with a fresh optimizer -- confirm that
            # this is intended.
            if resume["fine_tuning"]:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.start_epoch = 0
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_path, checkpoint['epoch']))

    def _run_epoch(self,
                   epoch,
                   mode="train",
                   leave_progress=True,
                   use_optuna=False):
        """
        run training or validation 1 epoch.
        You don't have to change almost of this method.
        
        args:
            epoch = (int) How many epochs this time.
            mode = {"train" or "val"}
            leave_progress = {True or False} Can choose whether leave progress bar or not.
            use_optuna = {True or False} Can choose whether use optuna or not.
        
        Change point (if you need):
        - Evaluation: You can change metrics of monitoring.
        - writer.add_scalar: You can change metrics to be saved in tensorboard.
        """
        # ------------------------- #
        leave_progress = leave_progress and not use_optuna
        # Initializing
        epoch_loss = 0.0
        ## Set model mode & tqdm (progress bar; it wrap dataloader)
        assert (mode == "train") or (
            mode == "val"
        ), "argument 'mode' can be 'train' or 'val.' Not {}.".format(mode)
        if mode == "train":
            data_loader = self.tqdm(
                self.train_loader,
                leave=leave_progress) if self.use_tqdm else self.train_loader
            self.model.train()
            num_dataset = len(self.train_loader)
        elif mode == "val":
            data_loader = self.tqdm(
                self.val_loader,
                leave=leave_progress) if self.use_tqdm else self.val_loader
            self.model.eval()
            num_dataset = len(self.val_loader)
        ## Reset confusion matrix of evaluator
        self.evaluator.reset()

        # ------------------------- #
        # Run 1 epoch
        for i, sample in enumerate(data_loader):
            ## ***Get Input data***
            inputs, target = sample["input"], sample["label"]
            if self.use_cuda:
                inputs, target = inputs.cuda(), target.cuda()

            ## ***Calculate Loss <Train>***
            if mode == "train":
                self.optimizer.zero_grad()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            ## ***Calculate Loss <Validation>***
            elif mode == "val":
                with torch.no_grad():
                    output = self.model(inputs)
                loss = self.criterion(output, target)
            epoch_loss += loss.item()
            ## ***Report results***
            if self.use_tqdm:
                data_loader.set_description('{} loss: {:.3f}'.format(
                    mode, epoch_loss / (i + 1)))
            ## ***Add batch results into evaluator***
            target = target.cpu().numpy()
            output = torch.argmax(output, axis=1).data.cpu().numpy()
            self.evaluator.add_batch(target, output)

        ## **********Evaluate Score**********
        """You can add new metrics! <utils.metrics.Evaluator()>"""
        # Acc = self.evaluator.Accuracy()
        miou = self.evaluator.Mean_Intersection_over_Union()

        if not use_optuna:
            ## ***Save eval into Tensorboard***
            self.writer.add_scalar('{}/loss_epoch'.format(mode),
                                   epoch_loss / (i + 1), epoch)
            # self.writer.add_scalar('{}/Acc'.format(mode), Acc, epoch)
            self.writer.add_scalar('{}/miou'.format(mode), miou, epoch)
            print('Total {} loss: {:.3f}'.format(mode,
                                                 epoch_loss / num_dataset))
            print("{0} mIoU:{1:.2f}".format(mode, miou))

        # Return score to watch. (update checkpoint or optuna's objective)
        return miou

    def run(self, leave_progress=True, use_optuna=False):
        """Run every remaining epoch of training followed by validation.

        After each epoch the validation score is compared with
        ``self.best_pred``; on improvement the best score is updated and a
        checkpoint is saved. The tensorboard writer is closed at the end.

        Args:
            leave_progress: (bool) Forwarded to ``_run_epoch``.
            use_optuna: (bool) Forwarded to ``_run_epoch``.

        Returns:
            The best validation score seen (``self.best_pred``).
        """
        epochs = range(self.start_epoch, self.epochs)
        # Use the progress bar injected through the constructor rather than a
        # module-level ``tqdm`` name, so that passing tqdm=None really
        # disables the bar.
        if self.use_tqdm:
            epochs = self.tqdm(epochs)

        for epoch in epochs:
            print(pycolor.GREEN + "[Epoch: {}]".format(epoch) + pycolor.END)

            ## ***Train***
            print(pycolor.YELLOW + "Training:" + pycolor.END)
            self._run_epoch(epoch,
                            mode="train",
                            leave_progress=leave_progress,
                            use_optuna=use_optuna)
            ## ***Validation***
            print(pycolor.YELLOW + "Validation:" + pycolor.END)
            score = self._run_epoch(epoch,
                                    mode="val",
                                    leave_progress=leave_progress,
                                    use_optuna=use_optuna)
            print("---------------------")
            if score > self.best_pred:
                print("model improve best score from {:.4f} to {:.4f}.".format(
                    self.best_pred, score))
                self.best_pred = score
                # Save the bare network's weights: the resume code loads the
                # checkpoint into the un-wrapped module, which expects keys
                # without DataParallel's 'module.' prefix.
                if isinstance(self.model, torch.nn.DataParallel):
                    bare_model = self.model.module
                else:
                    bare_model = self.model
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': bare_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                })
        self.writer.close()
        return self.best_pred
class Trainer(object):
    def __init__(self, args):
        """Build the DeepLab training setup described by ``args``.

        Creates the saver and TensorBoard writer, the data loaders, the
        DeepLab network with an SGD optimizer over two parameter groups, the
        segmentation loss (optionally class-balanced), the optional DenseCRF
        loss layer, the evaluator and the learning-rate scheduler. With
        ``args.cuda`` the model is wrapped in ``DataParallel`` and moved to
        the GPU; with ``args.resume`` a checkpoint is restored.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Checkpoint / config persistence
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # TensorBoard logging
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Data loaders and number of classes
        loaders = make_data_loader(args,
                                   num_workers=args.workers,
                                   pin_memory=True)
        (self.train_loader, self.val_loader,
         self.test_loader, self.nclass) = loaders

        # Network
        net = DeepLab(num_classes=self.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn)

        # Optimizer: the second parameter group trains at ten times the
        # base learning rate.
        param_groups = [
            {'params': net.get_1x_lr_params(), 'lr': args.lr},
            {'params': net.get_10x_lr_params(), 'lr': args.lr * 10},
        ]
        sgd = torch.optim.SGD(param_groups,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=args.nesterov)

        # Criterion, optionally with class-balanced weights that are read
        # from a cached .npy file or computed from the training loader.
        class_weights = None
        if args.use_balanced_weights:
            weights_file = os.path.join(Path.db_root_dir(args.dataset),
                                        args.dataset+'_classes_weights.npy')
            if os.path.isfile(weights_file):
                raw_weights = np.load(weights_file)
            else:
                raw_weights = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass)
            class_weights = torch.from_numpy(raw_weights.astype(np.float32))
        loss_factory = SegmentationLosses(weight=class_weights, cuda=args.cuda)
        self.criterion = loss_factory.build_loss(mode=args.loss_type)
        self.model = net
        self.optimizer = sgd

        # Optional DenseCRF regularization layer
        if args.densecrfloss > 0:
            self.densecrflosslayer = DenseCRFLoss(weight=args.densecrfloss,
                                                  sigma_rgb=args.sigma_rgb,
                                                  sigma_xy=args.sigma_xy,
                                                  scale_factor=args.rloss_scale)
            print(self.densecrflosslayer)

        # Evaluator and learning-rate scheduler
        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # CUDA: wrap for multi-GPU execution, then move to the device
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Restore a checkpoint if one was requested
        self.best_pred = 0.0
        resume_path = args.resume
        if resume_path is not None:
            if not os.path.isfile(resume_path):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'" .format(resume_path))
            checkpoint = torch.load(resume_path)
            args.start_epoch = checkpoint['epoch']
            # With CUDA the network sits inside the DataParallel wrapper.
            weights_target = self.model.module if args.cuda else self.model
            weights_target.load_state_dict(checkpoint['state_dict'])
            # Fine-tuning starts from a fresh optimizer state.
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume_path, checkpoint['epoch']))

        # Fine-tuning always restarts the epoch counter
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Train for one epoch with cross-entropy plus optional DenseCRF loss.

        When ``args.densecrfloss`` is non-zero the DenseCRF loss is added to
        the cross-entropy loss, and gradient maps, an entropy map, histograms
        and extra scalars are written to TensorBoard. Checkpoints are written
        every ``args.save_interval`` epochs when that option is set.

        Args:
            epoch: Index of the epoch being trained; used by the scheduler
                and as the base of the logging step.
        """
        train_loss = 0.0
        train_celoss = 0.0
        train_crfloss = 0.0
        num_seen = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualize about ten times per epoch. Clamp to 1 so the modulo below
        # cannot divide by zero when the loader has fewer than ten batches.
        vis_interval = max(num_img_tr // 10, 1)
        softmax = nn.Softmax(dim=1)

        def log_colorized_maps(tag, maps, step):
            """Colorize per-image 2-D maps and log the first three as a grid."""
            color_imgs = []
            for single_map in maps:
                # NOTE(review): the top-left pixel is zeroed before
                # colorizing, presumably to anchor the color scale -- confirm
                # against colorize().
                single_map[0, 0] = 0
                color_imgs.append(colorize(single_map)[:, :, :3])
            color_imgs = torch.from_numpy(
                np.array(color_imgs).transpose([0, 3, 1, 2]))
            # The former ``range=(0, 255)`` argument is omitted: assuming this
            # is torchvision's make_grid (confirm against the file's imports),
            # it is only consulted when normalize=True and recent releases no
            # longer accept it.
            grid_image = make_grid(color_imgs[:3], 3, normalize=False)
            self.writer.add_image(tag, grid_image, step)

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            croppings = (target != 254).float()
            target[target == 254] = 255
            # Pixels labeled 255 are those unlabeled pixels. Padded region are labeled 254.
            # see function RandomScaleCrop in dataloaders/custom_transforms.py for the detail in data preprocessing
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            global_step = i + num_img_tr * epoch
            visualize_now = i % vis_interval == 0

            celoss = self.criterion(output, target)

            if self.args.densecrfloss == 0:
                loss = celoss
            else:
                max_output = (max(torch.abs(torch.max(output)),
                                  torch.abs(torch.min(output))))
                mean_output = torch.mean(torch.abs(output)).item()
                probs = softmax(output)
                denormalized_image = denormalizeimage(
                    sample['image'],
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225))
                densecrfloss = self.densecrflosslayer(
                    denormalized_image, probs, croppings)
                if self.args.cuda:
                    densecrfloss = densecrfloss.cuda()
                loss = celoss + densecrfloss
                train_crfloss += densecrfloss.item()

                # Second, detached evaluation of the DenseCRF loss. Its
                # backward pass exists only to expose that loss's gradient
                # w.r.t. the logits to the visualization hook below; the
                # detached copy keeps it away from the model parameters.
                logits_copy = output.detach().clone().requires_grad_(True)
                probs_copy = softmax(logits_copy)
                denormalized_image_copy = denormalized_image.detach().clone()
                densecrfloss_copy = self.densecrflosslayer(
                    denormalized_image_copy, probs_copy, croppings)

                @torch.no_grad()
                def add_grad_map(grad, plot_name):
                    # Hooks fire during backward() of this same iteration, so
                    # the enclosing loop variables are still current.
                    if visualize_now:
                        batch_grads = torch.max(
                            torch.abs(grad), dim=1)[0].detach().cpu().numpy()
                        log_colorized_maps(plot_name, batch_grads, global_step)

                output.register_hook(
                    lambda grad: add_grad_map(grad, 'Grad Logits'))
                probs.register_hook(
                    lambda grad: add_grad_map(grad, 'Grad Probs'))

                logits_copy.register_hook(
                    lambda grad: add_grad_map(grad, 'Grad Logits Rloss'))
                densecrfloss_copy.backward()

                if visualize_now:
                    img_entropy = torch.sum(
                        -probs*torch.log(probs+1e-9),
                        dim=1).detach().cpu().numpy()
                    log_colorized_maps('Entropy', img_entropy, global_step)

                    self.writer.add_histogram(
                        'train/total_loss_iter/logit_histogram', output,
                        global_step)
                    self.writer.add_histogram(
                        'train/total_loss_iter/probs_histogram', probs,
                        global_step)

                self.writer.add_scalar('train/total_loss_iter/rloss',
                                       densecrfloss.item(), global_step)
                self.writer.add_scalar('train/total_loss_iter/max_output',
                                       max_output.item(), global_step)
                self.writer.add_scalar('train/total_loss_iter/mean_output',
                                       mean_output, global_step)

            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()
            train_celoss += celoss.item()
            num_seen += image.shape[0]

            tbar.set_description('Train loss: %.3f = CE loss %.3f + CRF loss: %.3f'
                             % (train_loss / (i + 1),train_celoss / (i + 1),train_crfloss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   global_step)
            self.writer.add_scalar('train/total_loss_iter/ce', celoss.item(),
                                   global_step)

            # Show 10 * 3 inference results each epoch
            if visualize_now:
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_seen))
        print('Loss: %.3f' % train_loss)

        # save checkpoint every interval epoch
        if self.args.save_interval and (epoch + 1) % self.args.save_interval == 0:
            is_best = False
            # The model is wrapped in DataParallel only when args.cuda is
            # set, so ``.module`` cannot be assumed to exist.
            if isinstance(self.model, torch.nn.DataParallel):
                bare_model = self.model.module
            else:
                bare_model = self.model
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': bare_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, filename='checkpoint_epoch_{}.pth.tar'.format(str(epoch+1)))


    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on a new best mIoU.

        Accumulates the loss and the evaluator statistics over the whole
        validation loader, logs pixel accuracy, class accuracy, mIoU and
        frequency-weighted IoU, and saves a checkpoint flagged as best when
        the mIoU exceeds ``self.best_pred``.

        Args:
            epoch: Index of the epoch being evaluated; used as logging step.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_seen = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            # Labels equal to 254 are remapped to 255 before the loss and the
            # evaluator see them.
            target[target == 254] = 255
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            num_seen += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_seen))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The model is wrapped in DataParallel only when args.cuda is
            # set, so ``.module`` cannot be assumed to exist.
            if isinstance(self.model, torch.nn.DataParallel):
                bare_model = self.model.module
            else:
                bare_model = self.model
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': bare_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
def main():

    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset',
                        type=str,
                        default='active_cityscapes_image',
                        choices=[
                            'active_cityscapes_image',
                            'active_cityscapes_region', 'active_pascal_image',
                            'active_pascal_region'
                        ],
                        help='dataset name (default: active_cityscapes)')
    parser.add_argument('--use-sbd',
                        action='store_true',
                        default=False,
                        help='whether to use SBD dataset (default: False)')
    parser.add_argument('--base-size',
                        type=int,
                        default=513,
                        help='base image size')
    parser.add_argument('--crop-size',
                        type=int,
                        default=513,
                        help='crop image size')
    parser.add_argument('--sync-bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')
    parser.add_argument('--loss-type',
                        type=str,
                        default='ce',
                        choices=['ce', 'focal'],
                        help='loss func type (default: ce)')
    parser.add_argument('--workers', type=int, default=4, help='num workers')
    # training hyper params
    parser.add_argument('--epochs',
                        type=int,
                        default=None,
                        metavar='N',
                        help='number of epochs to train (default: auto)')
    parser.add_argument('--start_epoch',
                        type=int,
                        default=0,
                        metavar='N',
                        help='start epochs (default:0)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=None,
                        metavar='N',
                        help='input batch size for \
								training (default: auto)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=None,
                        metavar='N',
                        help='input batch size for \
								testing (default: auto)')
    parser.add_argument(
        '--use-balanced-weights',
        action='store_true',
        default=False,
        help='whether to use balanced weights (default: False)')
    # optimizer params
    parser.add_argument('--lr',
                        type=float,
                        default=None,
                        metavar='LR',
                        help='learning rate (default: auto)')
    parser.add_argument('--lr-scheduler',
                        type=str,
                        default='poly',
                        choices=['poly', 'step', 'cos'],
                        help='lr scheduler mode: (default: poly)')
    parser.add_argument('--use-lr-scheduler',
                        default=False,
                        help='use learning rate scheduler',
                        action='store_true')
    parser.add_argument('--optimizer',
                        type=str,
                        default='SGD',
                        choices=['SGD', 'Adam'])
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=5e-4,
                        metavar='M',
                        help='w-decay (default: 5e-4)')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=False,
                        help='whether use nesterov (default: False)')
    # cuda, seed and logging
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpu-ids',
                        type=str,
                        default='0',
                        help='use which gpu to train, must be a \
						comma-separated list of integers only (default=0)')
    parser.add_argument('--seed',
                        type=int,
                        default=-1,
                        metavar='S',
                        help='random seed (default: 1)')
    # checking point
    parser.add_argument('--resume',
                        type=int,
                        default=0,
                        help='iteration to resume from')
    parser.add_argument('--checkname',
                        type=str,
                        default=None,
                        help='set the checkpoint name')
    parser.add_argument('--resume-selections',
                        type=str,
                        default=None,
                        help='resume selections file')
    # finetuning pre-trained models
    parser.add_argument('--ft',
                        action='store_true',
                        default=False,
                        help='finetuning on a different dataset')
    # evaluation option
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='evaluuation interval (default: 1)')
    parser.add_argument('--no-val',
                        action='store_true',
                        default=False,
                        help='skip validation during training')
    parser.add_argument('--overfit',
                        action='store_true',
                        default=False,
                        help='overfit to one sample')
    parser.add_argument('--seed_set',
                        action='store_true',
                        default='set_0.txt',
                        help='initial labeled set')
    parser.add_argument('--active-batch-size',
                        type=int,
                        default=50,
                        help='batch size queried from oracle')
    parser.add_argument('--active-selection-mode',
                        type=str,
                        default='random',
                        choices=[
                            'random', 'variance', 'coreset', 'ceal_confidence',
                            'ceal_margin', 'ceal_entropy', 'ceal_fusion',
                            'ceal_entropy_weakly_labeled',
                            'variance_representative', 'noise_image',
                            'noise_feature', 'noise_variance',
                            'accuracy_labels', 'accuracy_eval'
                        ],
                        help='method to select new samples')
    parser.add_argument('--active-region-size',
                        type=int,
                        default=129,
                        help='size of regions in case region dataset is used')
    parser.add_argument('--max-iterations',
                        type=int,
                        default=1000,
                        help='maximum active selection iterations')
    parser.add_argument(
        '--min-improvement',
        type=float,
        default=0.01,
        help='min improvement evaluation interval (default: 1)')
    parser.add_argument('--weak-label-entropy-threshold',
                        type=float,
                        default=0.80,
                        help='initial threshold for entropy for weak labels')
    parser.add_argument('--weak-label-threshold-decay',
                        type=float,
                        default=0.015,
                        help='decay for threshold on weak labels')
    parser.add_argument('--monitor-directory', type=str, default=None)
    parser.add_argument('--memory-hog',
                        action='store_true',
                        default=False,
                        help='memory_hog mode')
    parser.add_argument('--no-early-stop',
                        action='store_true',
                        default=False,
                        help='no early stopping')
    parser.add_argument('--architecture',
                        type=str,
                        default='deeplab',
                        choices=['deeplab', 'enet', 'fastscnn'])

    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError(
                'Argument --gpu_ids must be a comma-separated list of integers only'
            )

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {
            'coco': 30,
            'cityscapes': 200,
            'active_cityscapes': 200,
            'pascal': 50,
        }
        args.epochs = epoches[args.dataset.lower()]

    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    if args.lr is None:
        lrs = {
            'coco': 0.1,
            'cityscapes': 0.01,
            'active_cityscapes': 0.01,
            'pascal': 0.007,
        }
        args.lr = lrs[args.dataset.lower()] / (
            4 * len(args.gpu_ids)) * args.batch_size

    if args.checkname is None:
        args.checkname = 'deeplab-' + str(args.backbone)

    mc_dropout = args.active_selection_mode == 'variance' or args.active_selection_mode == 'variance_representative' or args.active_selection_mode == 'noise_variance'
    args.active_batch_size = args.active_batch_size * 2 if args.active_selection_mode == 'variance_representative' else args.active_batch_size

    print()
    print(args)

    # manual seeding
    if args.seed == -1:
        args.seed = int(random.random() * 2000)
    print('Using random seed = ', args.seed)
    torch.manual_seed(args.seed)

    kwargs = {
        'pin_memory': False,
        'init_set': args.seed_set,
        'memory_hog': args.memory_hog
    }
    dataloaders = make_dataloader(args.dataset, args.base_size, args.crop_size,
                                  args.batch_size, args.workers, args.overfit,
                                  **kwargs)

    training_set = dataloaders[0]
    dataloaders = dataloaders[1:]

    saver = Saver(args, remove_existing=False)
    saver.save_experiment_config()

    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    print()

    active_selector = get_active_selection_class(args.active_selection_mode,
                                                 training_set.NUM_CLASSES,
                                                 training_set.env,
                                                 args.crop_size,
                                                 args.batch_size)
    max_subset_selector = get_max_subset_active_selector(
        training_set.env, args.crop_size,
        args.batch_size)  # used only for representativeness cases

    total_active_selection_iterations = min(
        len(training_set.image_paths) // args.active_batch_size - 1,
        args.max_iterations)

    if args.resume != 0 and args.resume_selections != None:
        seed_size = len(training_set)
        with open(os.path.join(saver.experiment_dir, args.resume_selections),
                  "r") as fptr:
            paths = [
                u'{}'.format(x.strip()).encode('ascii')
                for x in fptr.readlines() if x is not ''
            ]
        training_set.expand_training_set(paths[seed_size:])
        assert len(training_set) == (args.resume * args.active_batch_size +
                                     seed_size)

    assert args.eval_interval <= args.epochs and args.epochs % args.eval_interval == 0

    trainer = Trainer(args, dataloaders, mc_dropout)
    trainer.initialize()

    for selection_iter in range(args.resume,
                                total_active_selection_iterations):

        print(
            f'ActiveIteration-{selection_iter:03d}/{total_active_selection_iterations:03d}'
        )

        fraction_of_data_labeled = round(
            training_set.get_fraction_of_labeled_data() * 100)

        if args.dataset.endswith('_image'):
            trainer.setup_saver_and_summary(fraction_of_data_labeled,
                                            training_set.current_image_paths)
        elif args.dataset.endswith('_region'):
            trainer.setup_saver_and_summary(
                fraction_of_data_labeled,
                training_set.current_image_paths,
                regions=[
                    training_set.current_paths_to_regions_map[x]
                    for x in training_set.current_image_paths
                ])
        else:
            raise NotImplementedError

        len_dataset_before = len(training_set)
        training_set.make_dataset_multiple_of_batchsize(args.batch_size)
        print(
            f'\nExpanding training set with {len_dataset_before}  images to {len(training_set)} images'
        )

        trainer.initialize()

        if not args.no_early_stop:
            early_stop = EarlyStopChecker(patience=5,
                                          min_improvement=args.min_improvement)

        best_mIoU = 0
        best_Acc = 0
        best_Acc_class = 0
        best_FWIoU = 0

        for outer_epoch in range(args.epochs // args.eval_interval):
            train_loss = 0
            for inner_epoch in range(args.eval_interval):
                train_loss += trainer.training(outer_epoch *
                                               args.eval_interval +
                                               inner_epoch)
            test_loss, mIoU, Acc, Acc_class, FWIoU, visualizations = trainer.validation(
                outer_epoch * args.eval_interval + inner_epoch)
            if mIoU > best_mIoU:
                best_mIoU = mIoU
            if Acc > best_Acc:
                best_Acc = Acc
            if Acc_class > best_Acc_class:
                best_Acc_class = Acc_class
            if FWIoU > best_FWIoU:
                best_FWIoU = FWIoU

            if not args.no_early_stop:
                # check for early stopping
                if early_stop(mIoU):
                    print(
                        f'Early stopping triggered after {outer_epoch * args.eval_interval + inner_epoch} epochs'
                    )
                    break

        training_set.reset_dataset()

        writer.add_scalar('active_loop/train_loss',
                          train_loss / len(training_set),
                          fraction_of_data_labeled)
        writer.add_scalar('active_loop/val_loss', test_loss,
                          fraction_of_data_labeled)
        writer.add_scalar('active_loop/mIoU', best_mIoU,
                          fraction_of_data_labeled)
        writer.add_scalar('active_loop/Acc', best_Acc,
                          fraction_of_data_labeled)
        writer.add_scalar('active_loop/Acc_class', best_Acc_class,
                          fraction_of_data_labeled)
        writer.add_scalar('active_loop/fwIoU', best_FWIoU,
                          fraction_of_data_labeled)

        summary.visualize_image(writer, args.dataset, visualizations[0],
                                visualizations[1], visualizations[2],
                                len(training_set.current_image_paths))

        trainer.writer.close()

        if selection_iter == (total_active_selection_iterations - 1):
            break

        checkpoint = torch.load(
            os.path.join(trainer.saver.experiment_dir, 'best.pth.tar'))
        trainer.model.module.load_state_dict(checkpoint['state_dict'])

        trainer.model.eval()

        if args.active_selection_mode == 'random':
            training_set.expand_training_set(
                active_selector.get_random_uncertainity(
                    training_set.remaining_image_paths,
                    args.active_batch_size))
        elif args.active_selection_mode == 'variance' or args.active_selection_mode == 'variance_representative':
            if args.dataset.endswith('_image'):
                print('Calculating entropies..')
                selected_images = active_selector.get_vote_entropy_for_images(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size)
                if args.active_selection_mode == 'variance_representative':
                    selected_images = max_subset_selector.get_representative_images(
                        trainer.model, training_set.image_paths,
                        selected_images)
                training_set.expand_training_set(selected_images)
            elif args.dataset.endswith('_region'):
                print('Creating region maps..')
                regions, counts = active_selector.create_region_maps(
                    trainer.model, training_set.image_paths,
                    training_set.get_existing_region_maps(),
                    args.active_region_size, args.active_batch_size)

                if args.active_selection_mode == 'variance_representative':
                    regions, counts = max_subset_selector.get_representative_regions(
                        trainer.model, training_set.image_paths, regions,
                        args.active_region_size)
                print(
                    f'Got {counts}/{math.ceil((args.active_batch_size) * args.crop_size * args.crop_size / (args.active_region_size * args.active_region_size))} regions'
                )
                training_set.expand_training_set(
                    regions,
                    counts * args.active_region_size * args.active_region_size)
            else:
                raise NotImplementedError
        elif args.active_selection_mode == 'coreset':
            assert args.dataset.endswith(
                '_image'), 'only images supported for coreset approach'
            training_set.expand_training_set(
                active_selector.get_k_center_greedy_selections(
                    args.active_batch_size, trainer.model,
                    training_set.remaining_image_paths,
                    training_set.current_image_paths))
        elif args.active_selection_mode == 'ceal_confidence':
            training_set.expand_training_set(
                active_selector.get_least_confident_samples(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size))
        elif args.active_selection_mode == 'ceal_margin':
            training_set.expand_training_set(
                active_selector.get_least_margin_samples(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size))
        elif args.active_selection_mode == 'ceal_entropy':
            training_set.expand_training_set(
                active_selector.get_maximum_entropy_samples(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size)[0])
        elif args.active_selection_mode == 'ceal_fusion':
            training_set.expand_training_set(
                active_selector.
                get_fusion_of_confidence_margin_entropy_samples(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size))
        elif args.active_selection_mode == 'ceal_entropy_weakly_labeled':
            selected_samples, entropies = active_selector.get_maximum_entropy_samples(
                trainer.model, training_set.remaining_image_paths,
                args.active_batch_size)
            training_set.clear_weak_labels()
            weak_labels = active_selector.get_weakly_labeled_data(
                trainer.model, training_set.remaining_image_paths,
                args.weak_label_entropy_threshold -
                selection_iter * args.weak_label_threshold_decay, entropies)
            for sample in selected_samples:
                if sample in weak_labels:
                    del weak_labels[sample]

            training_set.expand_training_set(selected_samples)
            training_set.add_weak_labels(weak_labels)
        elif args.active_selection_mode == 'noise_image':
            print('Calculating entropies..')
            selected_images = active_selector.get_vote_entropy_for_images_with_input_noise(
                trainer.model, training_set.remaining_image_paths,
                args.active_batch_size)
            training_set.expand_training_set(selected_images)
        elif args.active_selection_mode == 'noise_feature':
            print('Calculating entropies..')
            selected_images = active_selector.get_vote_entropy_for_images_with_feature_noise(
                trainer.model, training_set.remaining_image_paths,
                args.active_batch_size)
            training_set.expand_training_set(selected_images)
        elif args.active_selection_mode == 'noise_variance':
            if args.dataset.endswith('_image'):
                print('Calculating entropies..')
                selected_images = active_selector.get_vote_entropy_for_batch_with_noise_and_vote_entropy(
                    trainer.model, training_set.remaining_image_paths,
                    args.active_batch_size)
                training_set.expand_training_set(selected_images)
            elif args.dataset.endswith('_region'):
                print('Creating region maps..')
                regions, counts = active_selector.create_region_maps(
                    trainer.model, training_set.image_paths,
                    training_set.get_existing_region_maps(),
                    args.active_region_size, args.active_batch_size)
                print(
                    f'Got {counts}/{math.ceil((args.active_batch_size) * args.crop_size * args.crop_size / (args.active_region_size * args.active_region_size))} regions'
                )
                training_set.expand_training_set(
                    regions,
                    counts * args.active_region_size * args.active_region_size)
        elif args.active_selection_mode == 'accuracy_labels':
            print('Evaluating accuracies..')
            selected_images = active_selector.get_least_accurate_sample_using_labels(
                trainer.model, training_set.remaining_image_paths,
                args.active_batch_size)
            training_set.expand_training_set(selected_images)
        elif args.active_selection_mode == 'accuracy_eval':
            full_monitor_directory = os.path.join(constants.RUNS, args.dataset,
                                                  args.monitor_directory)
            selections_file = os.path.join(
                full_monitor_directory,
                f'run_{round(training_set.get_next_est_fraction_of_labeled_data(args.active_batch_size) * 100):04d}',
                "selections.txt")
            print('Waiting for the next folder to be available..',
                  selections_file)
            selected_images = active_selector.wait_for_selected_samples(
                selections_file, training_set.remaining_image_paths)
            training_set.expand_training_set(selected_images)
        else:
            raise NotImplementedError

    writer.close()
Пример #10
0
class Trainer(object):
    """Train and validate a DeepLab segmentation model.

    The constructor builds the saver, tensorboard writer, dataloaders,
    network, SGD optimizer, loss, evaluator and lr scheduler from ``args``,
    optionally wraps the model in ``DataParallel`` for CUDA and restores a
    checkpoint when ``args.resume`` is set.
    """

    def __init__(self, args):
        """Build every training component from the parsed ``args``.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: the second one trains with a 10x learning rate.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints hold the bare network's weights, so load them into
            # the unwrapped model regardless of DataParallel wrapping.
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the underlying network, without any DataParallel wrapper.

        The model is only wrapped when CUDA is used, so accessing ``.module``
        unconditionally fails on CPU runs.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _save_checkpoint(self, epoch, is_best):
        """Persist model/optimizer state for the epoch that just finished."""
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self._unwrapped_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

    def training(self, epoch):
        """Run one training epoch and log the loss to tensorboard.

        When ``args.no_val`` is set a (non-best) checkpoint is saved after
        the epoch.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualize roughly 10 times per epoch; clamp to 1 so loaders with
        # fewer than 10 batches do not cause a modulo-by-zero.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, is_best=False)


    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on a new best mIoU.

        Logs loss, pixel accuracy, class accuracy, mIoU and fwIoU to
        tensorboard and prints them.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Reduce axis 1 of the network output to a predicted class index.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # mIoU is the model-selection metric.
        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, is_best=True)
Пример #11
0
class Trainer(object):
    """Train a DeepLab network on images augmented with an approximate depth channel.

    A second DeepLab with one output channel (``model_aprox_depth``, weights
    loaded from ``args.ckpt``) is run on every image. Its output is passed
    through ``self.infer.sigmoid``, concatenated to the image along dim 1,
    and fed to ``nn.Sequential(input_conv, model)``, where ``input_conv`` maps
    four channels to three.
    """

    def __init__(self, args):
        """Build loaders, both networks, optimizer, loss, evaluator and logging.

        Args:
            args: Configuration namespace. Reads, among others, ``cuda``,
                ``gpu_ids``, ``ckpt``, ``resume``, ``ft``, ``loss_type`` and the
                depth-loss settings (``min_depth``, ``max_depth``,
                ``num_class``, ``num_class2``, ``cut_point``).

        Raises:
            RuntimeError: If ``args.resume`` is set but is not an existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # These two loss types override the number of network outputs.
        if args.loss_type == 'depth_loss_two_distributions':
            self.nclass = args.num_class + args.num_class2 + 1
        elif args.loss_type == 'depth_avg_sigmoid_class':
            self.nclass = args.num_class + args.num_class2

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print("\nDefine models...\n")
        self.model_aprox_depth = DeepLab(num_classes=1,
                                         backbone=args.backbone,
                                         output_stride=args.out_stride,
                                         sync_bn=args.sync_bn,
                                         freeze_bn=args.freeze_bn)

        # Reduces the 4-channel (image + depth) input to the 3 channels
        # the main network is given.
        self.input_conv = nn.Conv2d(4, 3, 3, padding=1)
        # Using cuda
        if args.cuda:
            self.model_aprox_depth = torch.nn.DataParallel(self.model_aprox_depth, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model_aprox_depth)
            self.model_aprox_depth = self.model_aprox_depth.cuda()
            self.input_conv = self.input_conv.cuda()

        print("\nLoad checkpoints...\n")
        # map_location=None is torch.load's default, so the CUDA path is unchanged.
        map_location = None if args.cuda else 'cpu'
        ckpt_aprox_depth = torch.load(args.ckpt, map_location=map_location)
        self._unwrap(self.model_aprox_depth).load_state_dict(ckpt_aprox_depth['state_dict'])

        # NOTE(review): only the DeepLab parameters are handed to the optimizer;
        # input_conv's parameters are never updated. Confirm this is intended
        # before adding a third parameter group (it would also change the
        # optimizer state layout that resuming expects).
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.Adam(train_params, args.lr)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        depth_kwargs = dict(weight=weight,
                            cuda=args.cuda,
                            min_depth=args.min_depth,
                            max_depth=args.max_depth,
                            num_class=args.num_class,
                            cut_point=args.cut_point,
                            num_class2=args.num_class2)
        # training() and validation() call self.infer.sigmoid for every loss
        # type, so the helper has to exist for segmentation losses as well.
        self.infer = DepthLosses(**depth_kwargs)
        if 'depth' in args.loss_type:
            self.criterion = DepthLosses(**depth_kwargs).build_loss(mode=args.loss_type)
            self.evaluator_depth = EvaluatorDepth(args.batch_size)
        else:
            self.criterion = SegmentationLosses(cuda=args.cuda, weight=weight).build_loss(mode=args.loss_type)
            self.evaluator = Evaluator(self.nclass)

        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Always start with a defined best score; previously it was missing
        # for non-depth losses unless a checkpoint was resumed.
        self.best_pred = self._initial_best_pred(args.loss_type)

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']
            # Drop the last two entries of the (ordered) state dict before a
            # non-strict load. NOTE(review): presumably the final layer's
            # weight and bias, whose shape depends on nclass - confirm.
            state_dict.popitem(last=True)
            state_dict.popitem(last=True)
            self._unwrap(self.model).load_state_dict(state_dict, strict=False)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            if 'depth' in args.loss_type:
                # Depth runs restart their best-score tracking from scratch.
                self.best_pred = 1e6
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # add input layer to the model
        self.model = nn.Sequential(
            self.input_conv,
            self.model
        )
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

    @staticmethod
    def _unwrap(net):
        """Return the module wrapped by ``DataParallel``, or *net* itself."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    @staticmethod
    def _initial_best_pred(loss_type):
        """Return the starting best score: 1e6 (lower wins) for depth losses, else 0.0."""
        return 1e6 if 'depth' in loss_type else 0.0

    @staticmethod
    def _log_interval(num_batches):
        """Return the batch spacing giving about ten visualisations per epoch (at least 1)."""
        return max(1, num_batches // 10)

    def _augment_with_depth(self, image):
        """Concatenate the approximate-depth prediction to *image* along dim 1.

        The depth network is frozen, so it runs without autograd tracking.
        """
        with torch.no_grad():
            aprox_depth = self.infer.sigmoid(self.model_aprox_depth(image))
        return torch.cat([image, aprox_depth], dim=1)

    def _depth_prediction(self, output):
        """Convert the raw network *output* to a depth map for the active loss type.

        Raises:
            ValueError: If no conversion exists for ``args.loss_type``.
        """
        loss_type = self.args.loss_type
        if loss_type == 'depth_loss':
            return self.infer.pred_to_continous_depth(output)
        if loss_type == 'depth_avg_sigmoid_class':
            return self.infer.pred_to_continous_depth_avg(output)
        if loss_type == 'depth_loss_combination':
            return self.infer.pred_to_continous_combination(output)
        if loss_type == 'depth_loss_two_distributions':
            return self.infer.pred_to_continous_depth_two_distributions(output)
        if 'depth_sigmoid_loss' in loss_type:
            return self.infer.depth01_to_depth(self.infer.sigmoid(output.squeeze(1)))
        raise ValueError("no depth conversion for loss type '{}'".format(loss_type))

    def training(self, epoch):
        """Run one training epoch, log the loss and optionally save a checkpoint.

        Args:
            epoch: Index of the current epoch; passed to the LR scheduler and
                used to compute the global step of the logged scalars.
        """
        train_loss = 0.0
        self.model.train()
        self.model_aprox_depth.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Guarded against loaders with fewer than ten batches, where the
        # unguarded `num_img_tr // 10` is zero and the modulo below would fail.
        vis_interval = self._log_interval(num_img_tr)
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            num_images += image.shape[0]
            if self.args.dataset == 'apollo_seg' or self.args.dataset == 'farsight_seg':
                # Binarise the labels around the cut point.
                target[target <= self.args.cut_point] = 0
                target[target > self.args.cut_point] = 1
            if image.shape[0] == 1:
                # Duplicate a lone sample so the batch has two items.
                # NOTE(review): presumably to keep batch norm usable in train mode.
                target = torch.cat([target, target], dim=0)
                image = torch.cat([image, image], dim=0)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(self._augment_with_depth(image))
            if self.args.loss_type == 'depth_sigmoid_loss_inverse':
                loss = self.criterion(output, target, inverse=True)
            else:
                loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                # change nan values to zero for display (handle warning from tensorboard)
                target[torch.isnan(target)] = 0
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step,
                                             n_class=self.args.num_class)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            # NOTE(review): this stores the state dict of the (possibly
            # DataParallel-wrapped) model, unlike the depth validation path.
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, False)

    def validation(self, epoch):
        """Evaluate on the validation loader, log metrics and checkpoint on improvement.

        For depth losses the tracked score is RMSE (lower is better); for
        segmentation losses it is mean IoU (higher is better).

        Args:
            epoch: Index of the epoch being validated; used as the step of the
                logged scalars and stored as ``epoch + 1`` in the checkpoint.

        Raises:
            ValueError: If a depth loss type has no prediction conversion.
        """
        self.model.eval()
        self.model_aprox_depth.eval()
        is_depth = 'depth' in self.args.loss_type
        if is_depth:
            self.evaluator_depth.reset()
        else:
            self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            num_images += image.shape[0]

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(self._augment_with_depth(image))
                # NOTE(review): training passes inverse=True for
                # 'depth_sigmoid_loss_inverse'; validation does not - confirm.
                loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            if is_depth:
                # Add batch sample into evaluator
                self.evaluator_depth.evaluateError(self._depth_prediction(output), target)
            else:
                probabilities = torch.softmax(output, dim=1)
                pred = np.argmax(probabilities.detach().cpu().numpy(), axis=1)
                # Add batch sample into evaluator
                self.evaluator.add_batch(target.cpu().numpy(), pred)

        if is_depth:
            # Fast test during the training
            metric_names = ('MSE', 'RMSE', 'ABS_REL', 'LG10', 'MAE', 'DELTA1', 'DELTA2', 'DELTA3')
            errors = self.evaluator_depth.averageError
            metrics = {name: errors[name] for name in metric_names}

            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            for name in metric_names:
                self.writer.add_scalar('val/' + name, metrics[name], epoch)

            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
            print("MSE:{}, RMSE:{}, ABS_REL:{}, LG10: {}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3: {}".format(
                *[metrics[name] for name in metric_names]))

            new_pred = metrics['RMSE']
            if new_pred < self.best_pred:
                self.best_pred = new_pred
                # _unwrap keeps the DataParallel behaviour and also works on
                # CPU, where self.model is a plain nn.Sequential without .module.
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self._unwrap(self.model).state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, True)
        else:
            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
            self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
            print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))

            new_pred = mIoU
            if new_pred > self.best_pred:
                self.best_pred = new_pred
                # NOTE(review): stores the wrapped model's state dict (unlike
                # the depth branch); kept to preserve the checkpoint layout.
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, True)

        print('Loss: %.3f' % test_loss)
Пример #12
0
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` treats every non-empty string (including ``"False"``) as
    True; this converter reads the text instead.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _lookup_dataset_default(table, dataset):
    """Return the *table* entry for *dataset*, or None when nothing matches.

    Tries the lower-cased name first, then the name without a trailing
    ``_image``/``_region``, then that name without a leading ``active_``.
    """
    key = dataset.lower()
    candidates = [key]
    for suffix in ('_image', '_region'):
        if key.endswith(suffix):
            key = key[:-len(suffix)]
            candidates.append(key)
            break
    if key.startswith('active_'):
        candidates.append(key[len('active_'):])
    for candidate in candidates:
        if candidate in table:
            return table[candidate]
    return None


def main():
    """Parse the command line and run the active-learning training loop.

    Each selection iteration re-initialises the trainer, trains with periodic
    validation (optionally stopping early), logs the best metrics against the
    fraction of labeled data, and then expands the training set with samples
    chosen by the active selector.

    Raises:
        ValueError: If ``--gpu-ids`` is not a comma-separated list of integers.
        RuntimeError: If the resumed selection file does not yield the
            expected training-set size.
        NotImplementedError: If the dataset name ends in neither ``_image``
            nor ``_region``.
    """

    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset', type=str, default='active_cityscapes_image',
                        choices=['active_cityscapes_image', 'active_cityscapes_region', 'active_pascal_image', 'active_pascal_region'],
                        help='dataset name (default: active_cityscapes)')
    parser.add_argument('--base-size', type=int, default=513,
                        help='base image size')
    parser.add_argument('--crop-size', type=int, default=513,
                        help='crop image size')
    # _str2bool instead of type=bool, so that "--sync-bn False" means False.
    parser.add_argument('--sync-bn', type=_str2bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze-bn', type=_str2bool, default=False,
                        help='whether to freeze bn parameters (default: False)')
    parser.add_argument('--loss-type', type=str, default='ce',
                        choices=['ce', 'focal'],
                        help='loss func type (default: ce)')
    parser.add_argument('--workers', type=int, default=4,
                        help='num workers')
    # training hyper params
    parser.add_argument('--epochs', type=int, default=None, metavar='N',
                        help='number of epochs to train (default: auto)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epochs (default:0)')
    parser.add_argument('--batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                                training (default: auto)')
    parser.add_argument('--use-balanced-weights', action='store_true', default=False,
                        help='whether to use balanced weights (default: False)')
    # optimizer params
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: auto)')
    parser.add_argument('--lr-scheduler', type=str, default='poly',
                        choices=['poly', 'step', 'cos'],
                        help='lr scheduler mode: (default: poly)')
    parser.add_argument('--use-lr-scheduler', default=False, help='use learning rate scheduler', action='store_true')
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        metavar='M', help='w-decay (default: 5e-4)')
    parser.add_argument('--nesterov', action='store_true', default=False,
                        help='whether use nesterov (default: False)')
    # cuda, seed and logging
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # checking point
    parser.add_argument('--resume', type=int, default=0,
                        help='iteration to resume from')
    parser.add_argument('--checkname', type=str, default=None,
                        help='set the checkpoint name')
    parser.add_argument('--resume-selections', type=str, default=None,
                        help='resume selections file')
    # finetuning pre-trained models
    parser.add_argument('--ft', action='store_true', default=False,
                        help='finetuning on a different dataset')
    # evaluation option
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='evaluuation interval (default: 1)')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    parser.add_argument('--overfit', action='store_true', default=False,
                                            help='overfit to one sample')
    # The value is a file name, so it takes an argument; the previous
    # store_true action replaced the name with True when the flag was given.
    parser.add_argument('--seed_set', type=str, default='set_0.txt',
                        help='initial labeled set')
    parser.add_argument('--active-batch-size', type=int, default=50,
                        help='batch size queried from oracle')
    parser.add_argument('--active-region-size', type=int, default=129, help='size of regions in case region dataset is used')
    parser.add_argument('--max-iterations', type=int, default=1000, help='maximum active selection iterations')
    parser.add_argument('--min-improvement', type=float, default=0.01, help='min improvement evaluation interval (default: 1)')
    parser.add_argument('--weight-unet', type=float, default=0.30, help='unet loss weight')
    parser.add_argument('--weight-wrong-label-unet', type=float, default=0.75, help='unet loss weight')
    parser.add_argument('--accuracy-selection', type=str, default='softmax', choices=['softmax', 'argmax'], help='selection based on soft or hard scores')
    parser.add_argument('--memory-hog', action='store_true', default=False, help='memory_hog mode')
    parser.add_argument('--active-selection-mode', type=str, default='accuracy',
                        choices=['accuracy', 'gradient', 'uncertain', 'uncertain_gradient'], help='method to select new samples')
    parser.add_argument('--no-early-stop', action='store_true', default=False, help='no early stopping')
    parser.add_argument('--architecture', type=str, default='deeplab', choices=['deeplab', 'enet', 'fastscnn'])
    parser.add_argument('--no-end-to-end', action='store_true', default=False, help='no end to end training')
    parser.add_argument('--symmetry', action='store_true', default=False, help='both deeplabs')

    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError as err:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') from err
    # gpu_ids is still the raw string when CUDA is off, so its length would
    # count characters rather than devices.
    num_devices = len(args.gpu_ids) if args.cuda else 1

    if args.sync_bn is None:
        args.sync_bn = bool(args.cuda and num_devices > 1)

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {
            'coco': 30,
            'cityscapes': 200,
            'active_cityscapes': 50,
            'pascal': 50,
        }
        # The allowed dataset names carry an _image/_region suffix that the
        # table keys do not have, so a direct lookup always raised KeyError.
        args.epochs = _lookup_dataset_default(epoches, args.dataset)
        if args.epochs is None:
            parser.error('no default --epochs for dataset {!r}; pass --epochs'.format(args.dataset))

    if args.batch_size is None:
        args.batch_size = 4 * num_devices

    if args.lr is None:
        lrs = {
            'coco': 0.1,
            'cityscapes': 0.01,
            'active_cityscapes': 0.01,
            'pascal': 0.007,
        }
        base_lr = _lookup_dataset_default(lrs, args.dataset)
        if base_lr is None:
            parser.error('no default --lr for dataset {!r}; pass --lr'.format(args.dataset))
        # Scale the reference rate linearly with the batch size.
        args.lr = base_lr / (4 * num_devices) * args.batch_size

    # Validate before any expensive set-up; a plain assert would be stripped
    # under `python -O`.
    if (args.eval_interval < 1 or args.eval_interval > args.epochs
            or args.epochs % args.eval_interval != 0):
        parser.error('--eval-interval must be a positive divisor of --epochs')

    print()
    print(args)

    # Per-epoch loss weights for the two parts of the model.
    w_dl = [1 - args.weight_unet] * args.epochs
    w_un = [args.weight_unet] * args.epochs

    if args.architecture == 'enet' or args.no_end_to_end:
        # Two-stage schedule: first two thirds on one loss only, last third
        # on the other.
        for i in range(0, args.epochs * 2 // 3):
            w_dl[i] = 1.0
            w_un[i] = 0.0

        for i in range(2 * args.epochs // 3, args.epochs):
            w_dl[i] = 0.0
            w_un[i] = 1.0

    kwargs = {'pin_memory': False, 'init_set': args.seed_set, 'memory_hog': args.memory_hog}
    dataloaders = make_dataloader(args.dataset, args.base_size, args.crop_size, args.batch_size, args.workers, args.overfit, **kwargs)

    training_set = dataloaders[0]
    dataloaders = dataloaders[1:]

    saver = Saver(args, remove_existing=False)
    saver.save_experiment_config()

    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    print()

    active_selector = get_active_selection_class('accuracy_labels', training_set.NUM_CLASSES, training_set.env, args.crop_size, args.batch_size)

    total_active_selection_iterations = min(len(training_set.image_paths) // args.active_batch_size - 1, args.max_iterations)

    if args.resume != 0 and args.resume_selections is not None:
        seed_size = len(training_set)
        with open(os.path.join(saver.experiment_dir, args.resume_selections), "r") as fptr:
            # Skip blank lines; every remaining path is stored as ASCII bytes.
            paths = [line.strip().encode('ascii') for line in fptr if line.strip()]
        training_set.expand_training_set(paths[seed_size:])
        expected_size = args.resume * args.active_batch_size + seed_size
        if len(training_set) != expected_size:
            raise RuntimeError('resumed training set has {} samples, expected {}'.format(
                len(training_set), expected_size))

    trainer = Trainer(args, dataloaders)
    trainer.initialize()

    for selection_iter in range(args.resume, total_active_selection_iterations):

        print(f'ActiveIteration-{selection_iter:03d}/{total_active_selection_iterations:03d}')

        fraction_of_data_labeled = round(training_set.get_fraction_of_labeled_data() * 100)

        if args.dataset.endswith('_image'):
            trainer.setup_saver_and_summary(fraction_of_data_labeled, training_set.current_image_paths)
        elif args.dataset.endswith('_region'):
            trainer.setup_saver_and_summary(fraction_of_data_labeled, training_set.current_image_paths, regions=[
                                            training_set.current_paths_to_regions_map[x] for x in training_set.current_image_paths])
        else:
            raise NotImplementedError

        len_dataset_before = len(training_set)
        training_set.make_dataset_multiple_of_batchsize(args.batch_size)
        print(f'\nExpanding training set with {len_dataset_before}  images to {len(training_set)} images')

        # Re-create the model for every selection iteration.
        trainer.initialize()

        if not args.no_early_stop:
            early_stop = EarlyStopChecker(patience=5, min_improvement=args.min_improvement)

        best_mIoU = 0
        best_Acc = 0
        best_Acc_class = 0
        best_FWIoU = 0

        for outer_epoch in range(args.epochs // args.eval_interval):
            train_loss = 0
            for inner_epoch in range(args.eval_interval):
                epoch = outer_epoch * args.eval_interval + inner_epoch
                train_loss += trainer.training(epoch, w_dl[epoch], w_un[epoch])
            test_loss, mIoU, Acc, Acc_class, FWIoU, visualizations = trainer.validation(epoch, w_dl[epoch], w_un[epoch])
            best_mIoU = max(best_mIoU, mIoU)
            best_Acc = max(best_Acc, Acc)
            best_Acc_class = max(best_Acc_class, Acc_class)
            best_FWIoU = max(best_FWIoU, FWIoU)
            # check for early stopping
            if not args.no_early_stop and early_stop(mIoU):
                print(f'Early stopping triggered after {epoch} epochs')
                break

        training_set.reset_dataset()

        # train_loss / test_loss are those of the last evaluated interval.
        writer.add_scalar('active_loop/train_loss', train_loss / len(training_set), fraction_of_data_labeled)
        writer.add_scalar('active_loop/val_loss', test_loss, fraction_of_data_labeled)
        writer.add_scalar('active_loop/mIoU', best_mIoU, fraction_of_data_labeled)
        writer.add_scalar('active_loop/Acc', best_Acc, fraction_of_data_labeled)
        writer.add_scalar('active_loop/Acc_class', best_Acc_class, fraction_of_data_labeled)
        writer.add_scalar('active_loop/fwIoU', best_FWIoU, fraction_of_data_labeled)

        summary.create_single_visualization(writer, 'active_loop', args.dataset, visualizations[0], visualizations[1], visualizations[
            2], visualizations[3], visualizations[4], len(training_set.current_image_paths))

        trainer.writer.close()
        trainer.model.eval()

        if args.active_selection_mode == 'accuracy':
            if args.dataset.endswith('_image'):
                print('Estimating accuracies..')
                selected_images = active_selector.get_least_accurate_samples(
                    trainer.model, training_set.remaining_image_paths, args.active_batch_size, args.accuracy_selection)
                training_set.expand_training_set(selected_images)
            elif args.dataset.endswith('_region'):
                print('Estimating accuracy regions..')
                regions, counts = active_selector.get_least_accurate_region_maps(
                    trainer.model, training_set.image_paths, training_set.get_existing_region_maps(), args.active_region_size, args.active_batch_size)
                print(f'Got {counts}/{math.ceil((args.active_batch_size) * args.crop_size * args.crop_size / (args.active_region_size * args.active_region_size))} regions')
                training_set.expand_training_set(regions, counts * args.active_region_size * args.active_region_size)
        elif args.active_selection_mode == 'gradient':
            print('Estimating gradients..')
            selected_images = active_selector.get_adversarially_vulnarable_samples(
                trainer.model, training_set.remaining_image_paths, args.active_batch_size)
            training_set.expand_training_set(selected_images)
        elif args.active_selection_mode == 'uncertain':
            print('Estimating uncertainities..')
            selected_images = active_selector.get_unsure_samples(
                trainer.model, training_set.remaining_image_paths, args.active_batch_size)
            training_set.expand_training_set(selected_images)
        elif args.active_selection_mode == 'uncertain_gradient':
            # Shortlist twice the batch by uncertainty, then keep the
            # requested number by the gradient criterion.
            print('Estimating uncertainities..')
            selected_images = active_selector.get_unsure_samples(
                trainer.model, training_set.remaining_image_paths, args.active_batch_size * 2)
            print('Estimating gradients..')
            selected_images = active_selector.get_adversarially_vulnarable_samples(
                trainer.model, selected_images, args.active_batch_size)
            training_set.expand_training_set(selected_images)
        torch.cuda.empty_cache()
    writer.close()
Пример #13
0
class Trainer(object):

    def __init__(self, args, dataloaders):
        self.args = args
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloaders

    def setup_saver_and_summary(self, num_current_labeled_samples, samples, experiment_group=None, regions=None):
        """Create the saver and the TensorBoard writer, and persist the selections.

        Args:
            num_current_labeled_samples: forwarded to ``ActiveSaver`` and stored
                on the trainer (later used to tag visualizations).
            samples: passed to ``ActiveSaver.save_active_selections``.
            experiment_group: forwarded to ``ActiveSaver``.
            regions: passed to ``ActiveSaver.save_active_selections``.
        """
        saver = ActiveSaver(self.args, num_current_labeled_samples, experiment_group=experiment_group)
        self.saver = saver
        saver.save_experiment_config()
        saver.save_active_selections(samples, regions)

        # The summary writes into the experiment directory chosen by the saver.
        summary = TensorboardSummary(saver.experiment_dir)
        self.summary = summary
        self.writer = summary.create_summary()

        self.num_current_labeled_samples = num_current_labeled_samples

    def initialize(self):
        """Build model, optimizer, losses, evaluators and (optionally) scheduler.

        All settings are read from ``self.args``.  On exit the trainer has
        ``model``, ``optimizer``, ``criterion_deeplab``, ``criterion_unet``,
        ``deeplab_evaluator``, ``unet_evaluator``, ``scheduler`` and
        ``best_pred`` (reset to 0.0) populated.

        Raises:
            NotImplementedError: if ``args.optimizer`` is neither ``'SGD'``
                nor ``'Adam'``.
        """
        cfg = self.args
        uses_enet = cfg.architecture == 'enet'

        net = DeepLabAccuracyPredictor(
            num_classes=self.nclass,
            backbone=cfg.backbone,
            output_stride=cfg.out_stride,
            sync_bn=cfg.sync_bn,
            freeze_bn=cfg.freeze_bn,
            mc_dropout=False,
            enet=uses_enet,
            symmetry=cfg.symmetry,
        )
        param_groups = net.get_param_list(cfg.lr, uses_enet, cfg.symmetry)

        # Only these two optimizers are supported.
        if cfg.optimizer not in ('SGD', 'Adam'):
            raise NotImplementedError
        if cfg.optimizer == 'SGD':
            optim = torch.optim.SGD(param_groups, momentum=cfg.momentum,
                                    weight_decay=cfg.weight_decay, nesterov=cfg.nesterov)
        else:
            optim = torch.optim.Adam(param_groups, weight_decay=cfg.weight_decay)

        # Optional class-balancing weights for the segmentation loss.
        class_weight = None
        if cfg.use_balanced_weights:
            balanced = calculate_weights_labels(cfg.dataset, self.train_loader, self.nclass)
            class_weight = torch.from_numpy(balanced.astype(np.float32))

        self.criterion_deeplab = SegmentationLosses(
            weight=class_weight, cuda=cfg.cuda).build_loss(mode=cfg.loss_type)
        # Two-class loss for the second head, weighted by
        # ``weight_wrong_label_unet`` and its complement.
        unet_weight = torch.FloatTensor(
            [cfg.weight_wrong_label_unet, 1 - cfg.weight_wrong_label_unet])
        self.criterion_unet = SegmentationLosses(
            weight=unet_weight, cuda=cfg.cuda).build_loss(mode=cfg.loss_type)

        self.model, self.optimizer = net, optim

        self.deeplab_evaluator = Evaluator(self.nclass)
        self.unet_evaluator = Evaluator(2)

        self.scheduler = (
            LR_Scheduler(cfg.lr_scheduler, cfg.lr, cfg.epochs, len(self.train_loader))
            if cfg.use_lr_scheduler else None
        )

        if cfg.cuda:
            # Wrap for multi-GPU use, then move the wrapper to the GPU.
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

    def training(self, epoch, w_dl, w_un):
        """Run one training epoch over ``self.train_loader``.

        The model returns two outputs per batch.  The first
        (``deeplab_output``) is trained against the segmentation label; the
        second (``unet_output``) is trained against a per-pixel map of whether
        the arg-max of the first output equals the label.  Pixels whose label
        is 255 are marked 255 in that map as well (presumably the loss's
        ignore index -- confirm against ``SegmentationLosses``).

        Args:
            epoch: zero-based epoch index, used for global-step computation.
            w_dl: weight of the DeepLab loss in the total loss.
            w_un: weight of the UNet loss in the total loss.

        Returns:
            The sum over all batches of the weighted total loss.
        """
        train_loss = 0.0
        train_loss_unet = 0.0
        train_loss_deeplab = 0.0

        self.model.train()
        num_img_tr = len(self.train_loader)
        tbar = tqdm(self.train_loader, desc='\r')

        # One randomly chosen batch per epoch is kept for visualization.
        visualization_index = int(random.random() * len(self.train_loader))
        vis_img = None
        vis_tgt_dl = None
        vis_tgt_un = None
        vis_out_dl = None
        vis_out_un = None

        for i, sample in enumerate(tbar):
            image, deeplab_target = sample['image'], sample['label']

            if self.args.cuda:
                image, deeplab_target = image.cuda(), deeplab_target.cuda()
            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_pred)
                self.writer.add_scalar('train/learning_rate', self.scheduler.current_lr, i + num_img_tr * epoch)

            self.optimizer.zero_grad()
            deeplab_output, unet_output = self.model(image)
            # Cast to long before writing 255: a comparison yields a bool
            # tensor on current PyTorch, where 255 would collapse to True and
            # the ignore marker would be lost.
            unet_target = (deeplab_output.argmax(1).squeeze() == deeplab_target.long()).long()
            unet_target[deeplab_target == 255] = 255

            if i == visualization_index:
                vis_img = image.cpu()
                vis_tgt_dl = deeplab_target.cpu()
                vis_out_dl = deeplab_output.cpu()
                vis_tgt_un = unet_target.cpu()
                vis_out_un = unet_output.cpu()

            loss_deeplab = self.criterion_deeplab(deeplab_output, deeplab_target)
            loss_unet = self.criterion_unet(unet_output, unet_target)
            loss = w_dl * loss_deeplab + w_un * loss_unet
            loss.backward()
            self.optimizer.step()
            train_loss_deeplab += loss_deeplab.item()
            train_loss_unet += loss_unet.item()
            train_loss += loss.item()
            tbar.set_description('Train losses: %.2f(dl) + %.2f(un) = %.3f' %
                                 (train_loss_deeplab / (i + 1), train_loss_unet / (i + 1), train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter_dl', loss_deeplab.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_loss_iter_un', loss_unet.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

        self.summary.create_single_visualization(self.writer, f'train/run_{self.num_current_labeled_samples:04d}', self.args.dataset, vis_img, vis_tgt_dl, vis_out_dl, vis_tgt_un, vis_out_un, epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        # The "_dl" tag gets the DeepLab loss and "_un" the UNet loss
        # (the two values were previously swapped).
        self.writer.add_scalar('train/total_loss_epoch_dl', train_loss_deeplab, epoch)
        self.writer.add_scalar('train/total_loss_epoch_un', train_loss_unet, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f (DeepLab) + %.3f (UNet) = %.3f' % (train_loss_deeplab, train_loss_unet, train_loss))
        print('BestPred: %.3f' % self.best_pred)

        self.writer.add_scalar('train/w_dl', w_dl, i + num_img_tr * epoch)
        self.writer.add_scalar('train/w_un', w_un, i + num_img_tr * epoch)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # The model is only wrapped in DataParallel when CUDA is enabled,
            # so fall back to the bare model when there is no ``.module``.
            net = getattr(self.model, 'module', self.model)
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        return train_loss

    def validation(self, epoch, w_dl, w_un):
        """Evaluate the model on ``self.val_loader`` and save a checkpoint.

        Both model outputs are scored: the segmentation output with
        ``self.deeplab_evaluator`` and the second output with
        ``self.unet_evaluator`` against a per-pixel "prediction equals label"
        map (pixels labelled 255 are marked 255 there too).  A checkpoint is
        saved after every call; it is flagged as best when the mIoU exceeds
        ``self.best_pred``.

        Args:
            epoch: zero-based epoch index used as the TensorBoard step.
            w_dl: weight of the DeepLab loss in the total loss.
            w_un: weight of the UNet loss in the total loss.

        Returns:
            ``(test_loss, mIoU, Acc, Acc_class, FWIoU, vis)`` where ``vis`` is
            ``[image, deeplab_target, deeplab_output, unet_target,
            unet_output]`` for one randomly chosen batch (CPU tensors).
        """
        self.model.eval()
        self.deeplab_evaluator.reset()
        self.unet_evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        test_loss_unet = 0.0
        test_loss_deeplab = 0.0

        # One randomly chosen batch is returned to the caller for visualization.
        visualization_index = int(random.random() * len(self.val_loader))
        vis_img = None
        vis_tgt_dl = None
        vis_tgt_un = None
        vis_out_dl = None
        vis_out_un = None

        for i, sample in enumerate(tbar):
            image, deeplab_target = sample['image'], sample['label']

            if self.args.cuda:
                image, deeplab_target = image.cuda(), deeplab_target.cuda()

            with torch.no_grad():
                deeplab_output, unet_output = self.model(image)

            # Cast to long before writing 255: a comparison yields a bool
            # tensor on current PyTorch, where 255 would collapse to True and
            # those pixels would be scored as correct instead of being marked.
            unet_target = (deeplab_output.argmax(1).squeeze() == deeplab_target.long()).long()
            unet_target[deeplab_target == 255] = 255

            if i == visualization_index:
                vis_img = image.cpu()
                vis_tgt_dl = deeplab_target.cpu()
                vis_out_dl = deeplab_output.cpu()
                vis_tgt_un = unet_target.cpu()
                vis_out_un = unet_output.cpu()

            loss_deeplab = self.criterion_deeplab(deeplab_output, deeplab_target)
            loss_unet = self.criterion_unet(unet_output, unet_target)
            loss = w_dl * loss_deeplab + w_un * loss_unet

            test_loss += loss.item()
            test_loss_unet += loss_unet.item()
            test_loss_deeplab += loss_deeplab.item()

            tbar.set_description('Test losses: %.2f(dl) + %.2f(un) = %.3f' %
                                 (test_loss_deeplab / (i + 1), test_loss_unet / (i + 1), test_loss / (i + 1)))
            pred = deeplab_output.data.cpu().numpy()
            deeplab_target = deeplab_target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.deeplab_evaluator.add_batch(deeplab_target, pred)
            self.unet_evaluator.add_batch(unet_target.cpu().numpy(), np.argmax(unet_output.cpu().numpy(), axis=1))

        # Fast test during the training
        Acc = self.deeplab_evaluator.Pixel_Accuracy()
        Acc_class = self.deeplab_evaluator.Pixel_Accuracy_Class()
        mIoU = self.deeplab_evaluator.Mean_Intersection_over_Union()
        FWIoU = self.deeplab_evaluator.Frequency_Weighted_Intersection_over_Union()
        UNetAcc = self.unet_evaluator.Pixel_Accuracy()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('val/UNetAcc', UNetAcc, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, UNetAcc: {}".format(Acc, Acc_class, mIoU, FWIoU, UNetAcc))
        print('Loss: %.3f (DeepLab) + %.3f (UNet) = %.3f' % (test_loss_deeplab, test_loss_unet, test_loss))

        new_pred = mIoU
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

        # save every validation model (overwrites)
        # The model is only wrapped in DataParallel when CUDA is enabled, so
        # fall back to the bare model when there is no ``.module``.
        net = getattr(self.model, 'module', self.model)
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

        return test_loss, mIoU, Acc, Acc_class, FWIoU, [vis_img, vis_tgt_dl, vis_out_dl, vis_tgt_un, vis_out_un]
class Trainer(object):
    def __init__(self, args):
        """Set up everything needed to train the pairwise ``ConsistentDeepLab``.

        Creates the saver, replaces ``sys.stdout`` with a ``Logger`` pointed
        at a timestamped file in the experiment directory, and builds the
        TensorBoard writer, data loaders, network, Adam optimizer, loss,
        evaluator and LR scheduler.  When ``args.resume`` is set, model (and,
        unless fine-tuning, optimizer) state is restored from that checkpoint.

        Args:
            args: run configuration; ``args.start_epoch`` is overwritten when
                resuming or fine-tuning.

        Raises:
            NotImplementedError: if ``args.dataset`` is neither
                ``'pairwise_lits'`` nor ``'pairwise_chaos'``.
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver and record the running process
        self.saver = Saver(args)
        sys.stdout = Logger(
            os.path.join(
                self.saver.experiment_dir,
                'log_train-%s.txt' % time.strftime("%Y-%m-%d-%H-%M-%S")))
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)
        if args.dataset == 'pairwise_lits':
            # Note: this also overrides the class count reported by the loader.
            proxy_nclasses = self.nclass = 3
        elif args.dataset == 'pairwise_chaos':
            proxy_nclasses = 2 * self.nclass
        else:
            raise NotImplementedError(
                "unsupported dataset: {}".format(args.dataset))

        # Define network
        model = ConsistentDeepLab(in_channels=3,
                                  num_classes=proxy_nclasses,
                                  pretrained=args.pretrained,
                                  backbone=args.backbone,
                                  output_stride=args.out_stride,
                                  sync_bn=args.sync_bn,
                                  freeze_bn=args.freeze_bn)

        # Second parameter group is trained with 10x the base learning rate.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.Adam(train_params,
                                     weight_decay=args.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            weights = calculate_weigths_labels(args.dataset, self.train_loader,
                                               proxy_nclasses)
        else:
            weights = None

        # Initializing loss
        print("Initializing loss: {}".format(args.loss_type))
        self.criterion = losses.init_loss(args.loss_type, weights=weights)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Without CUDA, remap stored tensors to the CPU so a checkpoint
            # written on a GPU machine can still be loaded.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch over the paired-image training loader.

        Each batch yields ``(sample1, sample2, proxy_label, sample_indices)``;
        the model is called on both images and the loss is computed against
        ``proxy_label``.  Per-iteration and per-epoch losses are written to
        TensorBoard, and predictions are visualized about ten times per epoch.
        When ``args.no_val`` is set a checkpoint is saved at the end.

        Args:
            epoch: zero-based epoch index, used for global-step computation.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # max(1, ...) avoids a modulo-by-zero when the loader has fewer than
        # ten batches.
        vis_interval = max(1, num_img_tr // 10)
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1 = sample1['image']
            image2 = sample2['image']
            if self.args.cuda:
                image1 = image1.cuda()
                image2 = image2.cuda()
                proxy_label = proxy_label.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image1, image2)
            loss = self.criterion(output, proxy_label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show inference results about 10 times each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                # Stack the two input images along the second-to-last axis.
                image = torch.cat((image1, image2), dim=-2)
                vis_output, vis_label = output, proxy_label
                if len(proxy_label.shape) > 3:
                    # Label carries a channel axis: keep the first ``nclass``
                    # channels and reduce the label to class indices.
                    vis_output = output[:, 0:self.nclass]
                    vis_label = torch.argmax(proxy_label[:, 0:self.nclass],
                                             dim=1)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, vis_label, vis_output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image1.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # The model is only wrapped in DataParallel when CUDA is enabled,
            # so fall back to the bare model when there is no ``.module``.
            net = getattr(self.model, 'module', self.model)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        """Evaluate on the paired-image validation loader.

        Accumulates the loss and the evaluator statistics, optionally saves
        predicted masks and a gzip-pickled list of sample names (when
        ``args.save_predict`` is set), logs metrics to TensorBoard and saves a
        checkpoint whenever the mIoU exceeds ``self.best_pred``.

        Args:
            epoch: zero-based epoch index used as the TensorBoard step.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        val_time = 0
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1 = sample1['image']
            image2 = sample2['image']
            if self.args.cuda:
                image1 = image1.cuda()
                image2 = image2.cuda()
                proxy_label = proxy_label.cuda()

            # Only the forward pass is timed.
            with torch.no_grad():
                start = time.time()
                output = self.model(image1, image2, is_val=True)
                end = time.time()
            val_time += end - start
            loss = self.criterion(output, proxy_label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            proxy_label = proxy_label.cpu().numpy()

            # Add batch sample into evaluator
            if len(proxy_label.shape) > 3:
                # Label carries a channel axis: compare class indices taken
                # over the first ``nclass`` channels of both arrays.
                pred = np.argmax(pred[:, 0:self.nclass], axis=1)
                proxy_label = np.argmax(proxy_label[:, 0:self.nclass], axis=1)
            else:
                pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(proxy_label, pred)

            if self.args.save_predict:
                self.saver.save_predict_mask(
                    pred, sample_indices, self.val_loader.dataset.data1_files)

        print("Val time: {}".format(val_time))
        print("Total paramerters: {}".format(
            sum(x.numel() for x in self.model.parameters())))
        if self.args.save_predict:
            # File names without directory and without anything after the
            # first dot.
            namelist = [
                os.path.split(fname)[1].split('.')[0]
                for fname in self.val_loader.dataset.data1_files
            ]
            # ``with`` guarantees the archive is closed even if dumping fails.
            with gzip.open(
                    os.path.join(self.saver.save_dir, 'namelist.pkl.gz'),
                    'wb') as namelist_file:
                pickle.dump(namelist, namelist_file, protocol=-1)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image1.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        dice = self.evaluator.Dice()
        # Only the entry at index 2 is logged.
        self.writer.add_scalar('val/Dice_2', dice[2], epoch)
        print("Dice:{}".format(dice))

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The model is only wrapped in DataParallel when CUDA is enabled,
            # so fall back to the bare model when there is no ``.module``.
            net = getattr(self.model, 'module', self.model)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Пример #15
0
class Trainer(object):
    def __init__(self, args):
        """Set up a single-output DeepLab trainer on ``RemoteSensingImg`` patches.

        Builds the saver and TensorBoard writer, reads patch settings from the
        parameter file ``args.para``, splits the dataset 90/10 at random into
        training and validation subsets, and creates the network, SGD
        optimizer, ``BCELoss`` criterion, evaluator and LR scheduler.  When
        ``args.resume`` is set, state is restored from that checkpoint.

        Args:
            args: run configuration; ``args.start_epoch`` is overwritten when
                resuming or fine-tuning.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        parameters.set_saved_parafile_path(args.para)
        patch_w = parameters.get_digit_parameters("", "train_patch_width", None, 'int')
        patch_h = parameters.get_digit_parameters("", "train_patch_height", None, 'int')
        overlay_x = parameters.get_digit_parameters("", "train_pixel_overlay_x", None, 'int')
        overlay_y = parameters.get_digit_parameters("", "train_pixel_overlay_y", None, 'int')
        # NOTE(review): the crop sizes are read but not used below; the calls
        # are kept in case get_digit_parameters validates the parameter file.
        crop_height = parameters.get_digit_parameters("", "crop_height", None, 'int')
        crop_width = parameters.get_digit_parameters("", "crop_width", None, 'int')

        dataset = RemoteSensingImg(args.dataroot, args.list, patch_w, patch_h, overlay_x, overlay_y)

        # Random 90/10 split into training and validation subsets.
        train_length = int(len(dataset) * 0.9)
        validation_length = len(dataset) - train_length
        [self.train_dataset, self.val_dataset] = torch.utils.data.random_split(dataset, (train_length, validation_length))
        print("len of train dataset is %d and val dataset is %d and total datalen is %d"%(len(self.train_dataset),len(self.val_dataset),len(dataset)))
        self.train_loader=torch.utils.data.DataLoader(self.train_dataset, batch_size=args.batch_size,num_workers=args.workers, shuffle=True,drop_last=True)
        self.val_loader=torch.utils.data.DataLoader(self.val_dataset, batch_size=args.batch_size,num_workers=args.workers, shuffle=True,drop_last=True)
        print("len of train loader is %d and val loader is %d"%(len(self.train_loader),len(self.val_loader)))

        # Define network
        model = DeepLab(num_classes=1,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Second parameter group is trained with 10x the base learning rate.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # NOTE(review): BCELoss expects inputs in [0, 1]; presumably the
        # network output is already passed through a sigmoid -- confirm.
        self.criterion=nn.BCELoss()

        if args.cuda:
            self.criterion=self.criterion.cuda()

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(2)
        # Define lr scheduler
        print("lenght of train_loader is %d"%(len(self.train_loader)))
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            # No explicit device_ids: DataParallel picks its default devices.
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            # Without CUDA, remap stored tensors to the CPU so a checkpoint
            # written on a GPU machine can still be loaded.
            checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}) with best mIoU {}"
                  .format(args.resume, checkpoint['epoch'], checkpoint['best_pred']))

        # Clear start epoch and best score if fine-tuning
        if args.ft:
            args.start_epoch = 0
            self.best_pred=0

    def training(self, epoch):
        train_start_time=time.time()
        train_loss = 0.0
        self.model.train()
        #tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        print("start training at epoch %d, with the training length of %d"%(epoch,num_img_tr))
        for i, (x, y) in enumerate(self.train_loader):
            start_time=time.time()
            image, target = x, y
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            end_time=time.time()
            #tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            print('[The loss for iteration %d is %.3f and the time used is %.3f]'%(i+num_img_tr*epoch,loss.item(),end_time-start_time))
            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 10) == 0:
            #     global_step = i + num_img_tr * epoch
            #     self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        train_end_time=time.time()
        print('[Epoch: %d, numImages: %5d, time used : %.3f hour]' % (epoch, i * self.args.batch_size + image.data.shape[0],(train_end_time-train_start_time)/3600))
        print('Loss: %.3f' % (train_loss/len(self.train_loader)))
	
        with open(self.args.checkname+".train_out.txt", 'a') as log:
            out_massage='[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])
            log.writelines(out_massage+'\n')
            out_massage='Loss: %.3f' % (train_loss/len(self.train_loader))
            log.writelines(out_massage+'\n')
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        time_val_start=time.time()
        self.model.eval()
        self.evaluator.reset()
        #tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (x, y) in enumerate(self.val_loader):
            image, target = x,y
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            #tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            #pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            print("validate on the %d patch of total %d patch"%(i,len(self.val_loader)))
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        time_val_end=time.time()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d, time used: %.3f hour]' % (epoch, len(self.val_loader), (time_val_end-time_val_start)/3600))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Validation Loss: %.3f' % (test_loss/len((self.val_loader))))

        with open(self.args.checkname+".train_out.txt", 'a') as log:
            out_message='Validation:'
            log.writelines(out_message+'\n')
            out_message="Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)
            log.writelines(out_message+'\n')
            out_message='Validation Loss: %.3f' % (test_loss/len((self.val_loader)))
            log.writelines(out_message+'\n')
        new_pred = mIoU

        if new_pred > self.best_pred:
            print("saveing model")
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best,self.args.checkname)
            return False
        else:
            return True
Пример #16
0
class Trainer(object):
    def __init__(self, args):
        """Set up logging, data loaders, the AutoStereo network and its optimizers.

        Weight parameters of the feature and matching sub-networks each get an
        SGD optimizer; their architecture parameters each get an Adam optimizer.
        When ``args.resume`` is set, model weights (and, unless ``args.ft``,
        the SGD optimizer states) are restored from that checkpoint.

        Args:
            args: option namespace. ``args.start_epoch`` is overwritten when a
                checkpoint is resumed and reset to 0 when ``args.ft`` is set.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}

        # Loader A feeds the weight updates, loader B the architecture updates
        # (see training()).
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader = make_data_loader(args, **kwargs)

        # Define network
        model = AutoStereo(maxdisp=self.args.max_disp,
                           Fea_Layers=self.args.fea_num_layers, Fea_Filter=self.args.fea_filter_multiplier,
                           Fea_Block=self.args.fea_block_multiplier, Fea_Step=self.args.fea_step,
                           Mat_Layers=self.args.mat_num_layers, Mat_Filter=self.args.mat_filter_multiplier,
                           Mat_Block=self.args.mat_block_multiplier, Mat_Step=self.args.mat_step)

        # Separate SGD optimizers for the weights of each sub-network.
        optimizer_F = torch.optim.SGD(
            model.feature.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        optimizer_M = torch.optim.SGD(
            model.matching.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )

        self.model, self.optimizer_F, self.optimizer_M = model, optimizer_F, optimizer_M

        # Separate Adam optimizers for the architecture parameters.
        self.architect_optimizer_F = torch.optim.Adam(self.model.feature.arch_parameters(),
                                                      lr=args.arch_lr, betas=(0.9, 0.999),
                                                      weight_decay=args.arch_weight_decay)

        self.architect_optimizer_M = torch.optim.Adam(self.model.matching.arch_parameters(),
                                                      lr=args.arch_lr, betas=(0.9, 0.999),
                                                      weight_decay=args.arch_weight_decay)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model).cuda()

        # The bare network, whether or not it was wrapped in DataParallel above;
        # avoids an AttributeError on ``.module`` when args.cuda is False.
        net = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model

        # Resuming checkpoint
        self.best_pred = 100.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.clean_module:
                # NOTE(review): this strict load of the raw checkpoint was in the
                # original and is kept as-is; confirm it is still wanted.
                self.model.load_state_dict(checkpoint['state_dict'])
                # Strip a leading 'module.' (added by DataParallel) from each key.
                # Keys without the prefix are passed through unchanged.
                prefix = 'module.'
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = k[len(prefix):] if k.startswith(prefix) else k
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)
            else:
                copy_state_dict(net.state_dict(), checkpoint['state_dict'])

            if not args.ft:
                copy_state_dict(self.optimizer_M.state_dict(), checkpoint['optimizer_M'])
                copy_state_dict(self.optimizer_F.state_dict(), checkpoint['optimizer_F'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        print('Total number of model parameters : {}'.format(sum([p.data.nelement() for p in self.model.parameters()])))
        print('Number of Feature Net parameters: {}'.format(sum([p.data.nelement() for p in net.feature.parameters()])))
        print('Number of Matching Net parameters: {}'.format(sum([p.data.nelement() for p in net.matching.parameters()])))


    def training(self, epoch):
        """Run one training epoch and save a per-epoch checkpoint.

        For every batch of ``train_loaderA`` that contains at least one pixel
        with disparity below ``args.max_disp``, the network weights are updated
        with a smooth-L1 loss. From ``args.alpha_epoch`` onwards, each such step
        is followed by an architecture-parameter update on a batch drawn from
        ``train_loaderB``.

        Args:
            epoch: index of the current epoch; used for the scheduler, the
                tensorboard step and the checkpoint file name.
        """
        train_loss = 0.0
        valid_iteration = 0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        # Visualize about ten times per epoch. Clamp to 1 so that loaders with
        # fewer than ten batches do not cause a modulo-by-zero.
        vis_interval = max(num_img_tr // 10, 1)

        for i, batch in enumerate(tbar):
            input1, input2, target = Variable(batch[0], requires_grad=True), Variable(batch[1], requires_grad=True), (batch[2])
            if self.args.cuda:
                input1 = input1.cuda()
                input2 = input2.cuda()
                target = target.cuda()

            target = torch.squeeze(target, 1)
            mask = target < self.args.max_disp
            mask.detach_()
            valid = target[mask].size()[0]
            if valid <= 0:
                # No supervised pixel in this batch: no update and nothing to
                # visualize (there is no fresh ``output`` for it).
                continue

            self.scheduler(self.optimizer_F, i, epoch, self.best_pred)
            self.scheduler(self.optimizer_M, i, epoch, self.best_pred)
            self.optimizer_F.zero_grad()
            self.optimizer_M.zero_grad()

            output = self.model(input1, input2)
            loss = F.smooth_l1_loss(output[mask], target[mask], reduction='mean')
            loss.backward()
            self.optimizer_F.step()
            self.optimizer_M.step()

            if epoch >= self.args.alpha_epoch:
                print("Start searching architecture!...........")
                # NOTE(review): a fresh iterator is created on every step, as in
                # the original; which batch this yields depends on how
                # train_loaderB is configured.
                search = next(iter(self.train_loaderB))
                input1_search, input2_search, target_search = Variable(search[0], requires_grad=True), Variable(search[1], requires_grad=True), (search[2])
                if self.args.cuda:
                    input1_search = input1_search.cuda()
                    input2_search = input2_search.cuda()
                    target_search = target_search.cuda()

                target_search = torch.squeeze(target_search, 1)
                mask_search = target_search < self.args.max_disp
                mask_search.detach_()

                # Skip the architecture step when the search batch has no valid
                # pixel: a mean over an empty selection would be NaN.
                if mask_search.any():
                    self.architect_optimizer_F.zero_grad()
                    self.architect_optimizer_M.zero_grad()
                    output_search = self.model(input1_search, input2_search)
                    arch_loss = F.smooth_l1_loss(output_search[mask_search], target_search[mask_search], reduction='mean')

                    arch_loss.backward()
                    self.architect_optimizer_F.step()
                    self.architect_optimizer_M.step()

            train_loss += loss.item()
            valid_iteration += 1
            # Average over the batches that actually contributed to train_loss.
            tbar.set_description('Train loss: %.3f' % (train_loss / valid_iteration))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image_stereo(self.writer, input1, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("=== Train ===> Epoch :{} Error: {:.4f}".format(epoch, train_loss / max(valid_iteration, 1)))

        # The bare network, whether or not self.model is a DataParallel wrapper.
        net = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
        print(net.feature.alphas)

        # save checkpoint every epoch
        is_best = False
        if torch.cuda.device_count() > 1:
            state_dict = net.state_dict()
        else:
            state_dict = self.model.state_dict()
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': state_dict,
            'optimizer_F': self.optimizer_F.state_dict(),
            'optimizer_M': self.optimizer_M.state_dict(),
            'best_pred': self.best_pred,
        }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch))

    def validation(self, epoch):
        """Evaluate on ``val_loader`` and checkpoint when the mean EPE improves.

        Per batch, two numbers are computed over pixels whose ground-truth
        disparity is below ``args.max_disp``: the mean absolute error (EPE) and
        the fraction of pixels that are NOT within 1 px or 5% of the true
        disparity. Their per-batch averages are logged as ``val/EPE`` and
        ``val/D1_all``.

        Args:
            epoch: index of the current epoch, used as the tensorboard step and
                stored (plus one) in the checkpoint.
        """
        self.model.eval()

        epoch_error = 0
        three_px_acc_all = 0
        valid_iteration = 0

        tbar = tqdm(self.val_loader, desc='\r')

        for i, batch in enumerate(tbar):
            input1, input2, target = Variable(batch[0], requires_grad=False), Variable(batch[1], requires_grad=False), Variable(batch[2], requires_grad=False)
            if self.args.cuda:
                input1 = input1.cuda()
                input2 = input2.cuda()
                target = target.cuda()

            target = torch.squeeze(target, 1)
            mask = target < self.args.max_disp
            mask.detach_()
            valid = target[mask].size()[0]

            if valid <= 0:
                continue

            with torch.no_grad():
                output = self.model(input1, input2)

                true_disp = target[mask]
                abs_err = torch.abs(output[mask] - true_disp)
                error = torch.mean(abs_err)
                epoch_error += error.item()

                valid_iteration += 1

                # A pixel is correct when its error is below 1 px or below 5%
                # of the true disparity. The comparison uses the untouched
                # ground truth; the earlier code overwrote it in place through
                # an alias, so the relative test compared the error to itself.
                # NOTE(review): the names say "3-px" but the threshold is 1,
                # as in the original — confirm which is intended.
                correct = (abs_err < 1) | (abs_err < true_disp * 0.05)
                three_px_acc = 1 - (float(torch.sum(correct)) / float(valid))

                three_px_acc_all += three_px_acc
                print("===> Test({}/{}): Error(EPE): ({:.4f} {:.4f})".format(i, len(self.val_loader), error.item(), three_px_acc))

        if valid_iteration == 0:
            # Nothing was evaluated; averaging would divide by zero.
            print("===> Test: no batch contained valid pixels, skipping metrics")
            return

        self.writer.add_scalar('val/EPE', epoch_error / valid_iteration, epoch)
        self.writer.add_scalar('val/D1_all', three_px_acc_all / valid_iteration, epoch)

        print("===> Test: Avg. Error: ({:.4f} {:.4f})".format(epoch_error / valid_iteration, three_px_acc_all / valid_iteration))

        # save model (lower EPE is better)
        new_pred = epoch_error / valid_iteration  # three_px_acc_all/valid_iteration
        if new_pred < self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The bare network, whether or not self.model is a DataParallel wrapper.
            net = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
            if torch.cuda.device_count() > 1:
                state_dict = net.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer_F': self.optimizer_F.state_dict(),
                'optimizer_M': self.optimizer_M.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #17
0
class Tester(object):
    def __init__(self, args):
        """Build loaders, the waspnet model, loss, evaluators and restore a checkpoint.

        Args:
            args: option namespace. ``args.start_epoch`` is overwritten when
                ``args.resume`` points at a checkpoint.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # NOTE(review): 'Dinning Table' and 'Pottled Plant' look like typos but
        # are runtime strings and are left unchanged here.
        self.classNames = ['Background', 'Aeroplane', 'Bicycle', 'Bird', 'Boat', 'Bottle', 'Bus',
                           'Car', 'Cat', 'Chair', 'Cow', 'Dinning Table', 'Dog', 'Horse',
                           'Motorbike', 'Person', 'Pottled Plant', 'Sheep', 'Sofa', 'Train', 'TV monitor']
        # One normalized RGB triple per entry of classNames (21 in total).
        # Entry 15 ('Person') previously repeated entry 14's colour, so the two
        # classes were indistinguishable; it now follows the standard Pascal VOC
        # palette value (192, 128, 128) / 256.
        colorMap = [
            (0, 0, 0), (0.50, 0, 0), (0, 0.50, 0), (0.50, 0.50, 0), (0, 0, 0.5), (0.50, 0, 0.5), (0, 0.50, 0.5),
            (0.50, 0.5, 0.5), (0.25, 0, 0), (0.75, 0, 0), (0.25, 0.50, 0), (0.75, 0.50, 0), (0.25, 0, 0.5), (0.75, 0, 0.5),
            (0.25, 0.5, 0.5), (0.75, 0.5, 0.5), (0, 0.25, 0), (0.50, 0.25, 0), (0, 0.75, 0), (0.50, 0.75, 0), (0, 0.25, 0.5),
        ]

        # Palette passed to PIL's putpalette() in test_save(); distinct from the
        # matplotlib colormap built from the local list above.
        self.colorMap = color_map(256, normalized=True)

        self.cmap = colors.ListedColormap(colorMap)
        # 22 boundaries delimit 21 bins, one per class index 0..20.
        bounds = list(range(22))
        self.norm = colors.BoundaryNorm(bounds, self.cmap.N)

        model = waspnet(num_classes=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                        sync_bn=args.sync_bn, freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer (only needed so a resumed optimizer state can be loaded)
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator: one for raw network output, one for CRF-refined output
        self.evaluator = Evaluator(self.nclass)
        self.evaluatorCRF = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # On a CPU-only run, remap tensors saved from a GPU onto the CPU so
            # the checkpoint can still be deserialized.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            p = checkpoint['state_dict']
            if args.cuda:
                # Copy only the checkpoint entries whose key exists in the
                # current model and does not start with `prefix`.
                prefix = 'invalid'
                state_dict = self.model.module.state_dict()
                model_dict = {
                    k: v for k, v in p.items()
                    if k in state_dict and not k.startswith(prefix)
                }
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                if self.args.dataset != 'cityscapes':
                    self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            print("Best mIOU = " + str(self.best_pred))

    def test_save(self):
        """Predict on ``test_loader``, refine with DenseCRF and save palette PNGs.

        Images are written only when ``args.dataset`` is ``'pascal'`` or
        ``'cityscapes'``; the destination directory is created if missing.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')

        # DenseCRF hyper-parameters
        w1 = 3
        w2 = 3
        Sa = 30
        Sb = 3
        Sg = 3

        postprocess = DenseCRF(iter_max=10,
                               pos_w=w1,
                               bi_w=w2,
                               bi_xy_std=Sa,
                               bi_rgb_std=Sb,
                               pos_xy_std=Sg)

        for sample in tbar:
            image, image_path = sample['image'], sample['path']

            if self.args.cuda:
                image = image.cuda()

            with torch.no_grad():
                output = self.model(image)

            for img, logit, imPath in zip(image, output, image_path):
                # Resize the logits to the *image* resolution. The size used to
                # be re-read from the logits themselves, which made the
                # interpolation below a no-op.
                H, W = img.shape[-2:]

                img = img.cpu().numpy()

                logit = torch.FloatTensor(logit.cpu().numpy())[None, ...]
                logit = F.interpolate(logit,
                                      size=(H, W),
                                      mode="bilinear",
                                      align_corners=False)
                prob = F.softmax(logit, dim=1)[0].cpu().numpy()
                # NOTE(review): the cast assumes pixel values are already in the
                # 0-255 range — confirm against the loader's normalization.
                img = img.astype(np.uint8).transpose(1, 2, 0)

                prob = postprocess(img, prob)

                label = np.argmax(prob, axis=0)

                label[label == 21] = 255
                out_image = Image.fromarray(label.squeeze().astype('uint8'))
                out_image.putpalette(self.colorMap)

                # NOTE(review): the fixed slice offsets depend on the length of
                # the dataset root path — verify against the loader.
                if self.args.dataset == 'pascal':
                    out_path = ("output/test/results/VOC2012/Segmentation/comp5_test_cls/"
                                + imPath[36:-4] + ".png")
                elif self.args.dataset == 'cityscapes':
                    out_path = "output/Cityscapes/test/" + imPath[34:-4] + ".png"
                else:
                    continue

                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                out_image.save(out_path)

    def test(self):
        """Evaluate on ``val_loader`` and print mIoU before and after DenseCRF.

        Both evaluators are reset on entry, so repeated calls do not accumulate
        statistics from earlier runs.
        """
        self.model.eval()
        self.evaluator.reset()
        # Reset as well; otherwise the CRF confusion matrix would keep growing
        # across calls while the raw-output one starts fresh.
        self.evaluatorCRF.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        overallScore = 0.0

        # w1 = POS_W      =    [3:6]
        # w2 = BI_W       =        3
        # Sa = BI_XY_STD  = [30:100]
        # Sb = BI_RGB_STD =    [3:6]
        # Sg = POS_XY_STD =        3

        w1 = 3
        w2 = 3
        Sa = 30
        Sb = 3
        Sg = 3

        postprocess = DenseCRF(iter_max=10,
                               pos_w=w1,
                               bi_w=w2,
                               bi_xy_std=Sa,
                               bi_rgb_std=Sb,
                               pos_xy_std=Sg)

        for sample in tbar:
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            for img, logit, gt_label in zip(image, output, target):
                # Resize the logits to the *image* resolution. The size used to
                # be re-read from the logits themselves, which made the
                # interpolation below a no-op.
                H, W = img.shape[-2:]

                img = img.cpu().numpy()
                gt_label = gt_label.cpu().numpy()

                logit = torch.FloatTensor(logit.cpu().numpy())[None, ...]
                logit = F.interpolate(logit,
                                      size=(H, W),
                                      mode="bilinear",
                                      align_corners=False)
                prob = F.softmax(logit, dim=1)[0].cpu().numpy()
                img = img.astype(np.uint8).transpose(1, 2, 0)

                prob = postprocess(img, prob)

                label = np.argmax(prob, axis=0)

                self.evaluatorCRF.add_batch(gt_label, label)

                # NOTE(review): the per-image mean IoU is accumulated but never
                # reported; kept from the original.
                score = scores(gt_label, label, n_class=self.nclass)
                overallScore += score['Mean IoU']

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)

        # Compute the metrics once, after all batches. This also keeps mIoU_CRF
        # defined when the loader yields no batch.
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU_CRF = self.evaluatorCRF.Mean_Intersection_over_Union()

        print("w1 =  ", w1)
        print("w2 =  ", w2)
        print("Sa = ", Sa)
        print("Sb =  ", Sb)
        print("Sg =  ", Sg)

        print("Final-PRE -CRF  =" + str(mIoU))
        print("Final-POST-CRF  =" + str(mIoU_CRF))
class Trainer(object):
    def __init__(self, args, dataloaders, mc_dropout):
        self.args = args
        self.mc_dropout = mc_dropout
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloaders

    def setup_saver_and_summary(self,
                                num_current_labeled_samples,
                                samples,
                                experiment_group=None,
                                regions=None):
        """Create the saver and tensorboard writer for the current run.

        An ``ActiveSaver`` is built from ``self.args`` and the given sample
        count / experiment group, the experiment config and the active
        selections are saved through it, and a summary writer is opened in the
        saver's experiment directory.

        Args:
            num_current_labeled_samples: forwarded to ``ActiveSaver``.
            samples: forwarded to ``save_active_selections``.
            experiment_group: forwarded to ``ActiveSaver`` (default ``None``).
            regions: forwarded to ``save_active_selections`` (default ``None``).
        """
        saver = ActiveSaver(
            self.args,
            num_current_labeled_samples,
            experiment_group=experiment_group,
        )
        self.saver = saver
        saver.save_experiment_config()
        saver.save_active_selections(samples, regions)

        summary = TensorboardSummary(saver.experiment_dir)
        self.summary = summary
        self.writer = summary.create_summary()

    def initialize(self):

        args = self.args

        if args.architecture == 'deeplab':
            print('Using Deeplab')
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]
        elif args.architecture == 'enet':
            print('Using ENet')
            model = ENet(num_classes=self.nclass,
                         encoder_relu=True,
                         decoder_relu=True)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        elif args.architecture == 'fastscnn':
            print('Using FastSCNN')
            model = FastSCNN(3, self.nclass)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(train_params,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            optimizer = torch.optim.Adam(train_params,
                                         weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        if args.use_balanced_weights:
            weight = calculate_weights_labels(args.dataset, self.train_loader,
                                              self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.evaluator = Evaluator(self.nclass)

        if args.use_lr_scheduler:
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                          args.epochs, len(self.train_loader))
        else:
            self.scheduler = None

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

    def training(self, epoch):
        """Train for one epoch over ``train_loader``.

        Args:
            epoch: index of the current epoch; used for the scheduler, the
                tensorboard step and the saved checkpoint.

        Returns:
            The sum of the per-batch loss values for this epoch.
        """
        train_loss = 0.0
        # Counted inside the loop so the summary below does not depend on loop
        # variables that are unbound when the loader is empty.
        num_images = 0
        self.model.train()
        num_img_tr = len(self.train_loader)
        tbar = tqdm(self.train_loader, desc='\r')

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_pred)
                self.writer.add_scalar('train/learning_rate',
                                       self.scheduler.current_lr,
                                       i + num_img_tr * epoch)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)
        print('BestPred: %.3f' % self.best_pred)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # initialize() wraps the model in DataParallel only when args.cuda
            # is set, so unwrap conditionally instead of assuming `.module`.
            net = self.model.module if isinstance(
                self.model, torch.nn.DataParallel) else self.model
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        return train_loss

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics and save a checkpoint.

        One batch, chosen at random, is kept for visualization. A checkpoint is
        saved on every call and flagged as best when the mIoU exceeds
        ``self.best_pred``.

        Args:
            epoch: index of the current epoch; used as the tensorboard step and
                stored (plus one) in the checkpoint.

        Returns:
            A tuple ``(test_loss, mIoU, Acc, Acc_class, FWIoU, vis)`` where
            ``test_loss`` is the summed per-batch loss and ``vis`` is the list
            ``[image, target, output]`` of the sampled batch (all ``None`` when
            the loader yields no batch).
        """
        self.model.eval()
        self.evaluator.reset()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        # Counted inside the loop so the summary below does not depend on loop
        # variables that are unbound when the loader is empty.
        num_images = 0

        visualization_index = int(random.random() * len(self.val_loader))
        vis_img = None
        vis_tgt = None
        vis_out = None

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            if i == visualization_index:
                vis_img = image
                vis_tgt = target
                vis_out = output

            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Collapse the class axis to a predicted label per pixel.
            pred = np.argmax(pred, axis=1)

            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

        # initialize() wraps the model in DataParallel only when args.cuda is
        # set, so unwrap conditionally instead of assuming `.module`.
        net = self.model.module if isinstance(
            self.model, torch.nn.DataParallel) else self.model

        # save every validation model (overwrites)
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        return test_loss, mIoU, Acc, Acc_class, FWIoU, [
            vis_img, vis_tgt, vis_out
        ]
Пример #19
0
class Trainer(object):
    def __init__(self, args):
        """Build the logging, data, model, loss and scheduler objects for a run.

        Args:
            args: Run configuration. This constructor reads ``distributed``,
                ``local_rank``, ``workers``, ``backbone``, ``out_stride``,
                ``cuda``, ``ext``, ``lr``, ``weight_decay``, ``lr_scheduler``,
                ``epochs`` and ``resume``. When resuming it overwrites
                ``args.resume`` and ``args.start_epoch``.

        Raises:
            RuntimeError: If ``args.resume`` is set but the experiment
                directory contains no ``checkpoint.pth.tar``.
        """
        self.args = args

        # Only a non-distributed run, or rank 0 of a distributed one, owns the
        # saver and the TensorBoard writer. Other ranks never get these
        # attributes, so every use of them must be guarded the same way.
        if not args.distributed or args.local_rank == 0:
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            # Define Tensorboard Summary
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                     cuda=args.cuda, extension=args.ext)

        # Define Optimizer
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # Define Criterion: training() applies criterion1 to the 'bin_label'
        # target and criterion2 to the 'label' (instance) target.
        self.criterion1 = FocalLoss(gamma=5, alpha=[0.2, 0.98], img_size=512 * 512)
        self.criterion2 = disc_loss(delta_v=0.5, delta_d=3.0, param_var=1.0, param_dist=1.0,
                                    param_reg=0.001, EMBEDDING_FEATS_DIMS=21, image_shape=[512, 512])

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler. It is stepped once per *training* batch (see
        # training()), so the iterations-per-epoch argument must come from the
        # training loader, not the validation loader.
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader), local_rank=args.local_rank)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            if args.distributed:
                self.model = DistributedDataParallel(self.model)

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            # The checkpoint is always looked up inside the current experiment
            # directory; the value passed in args.resume only acts as a flag.
            # NOTE(review): self.saver does not exist on non-zero ranks of a
            # distributed run, so resuming there raises AttributeError.
            filename = 'checkpoint.pth.tar'
            args.resume = os.path.join(self.saver.experiment_dir, filename)
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # validation() saves self.model.state_dict() (wrapper included),
            # so the weights are loaded back into the same wrapped object.
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))


    def training(self, epoch):
        """Run one training epoch and log the loss and sample predictions.

        Each batch is scored with ``criterion1`` against ``sample['bin_label']``
        and with ``criterion2`` against ``sample['label']``; the optimised loss
        is ``loss1 + 10 * loss2``.

        Args:
            epoch: Index of the current epoch; passed to the LR scheduler and
                used to compute the global logging step.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # The TensorBoard writer only exists on a non-distributed run or on
        # rank 0 of a distributed one.
        logs_here = not self.args.distributed or self.args.local_rank == 0
        # Log sample images about ten times per epoch. Integer division keeps
        # the modulo test exact (a float interval such as 123.4 almost never
        # divides an integer step), and the clamp avoids a zero interval when
        # the loader has fewer than ten batches.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target, ins_target = sample['image'], sample['bin_label'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            loss1 = self.criterion1(output, target)
            # NOTE(review): ins_target is not moved to the GPU here; confirm
            # that criterion2 handles the device placement itself.
            loss2 = self.criterion2(output, ins_target)
            loss = loss1 + 10 * loss2

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if logs_here:
                global_step = i + num_img_tr * epoch
                self.writer.add_scalar('train/total_loss_iter', loss.item(), global_step)
                if i % vis_interval == 0:
                    # The second element of the model output is the one that
                    # is visualised (validation() takes its argmax as the
                    # prediction).
                    self.summary.visualize_image(self.writer, self.args.dataset, image, target,
                                                 output[1], global_step)

        if logs_here:
            self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        if self.args.local_rank == 0:
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print('Loss: %.3f' % train_loss)

        # if self.args.distributed:
        #     if self.args.local_rank == 0:
        #         if self.args.no_val:
        #             # save checkpoint every epoch
        #             is_best = False
        #             self.saver.save_checkpoint({
        #                 'epoch': epoch + 1,
        #                 'state_dict': self.model.module.state_dict(),
        #                 'optimizer': self.optimizer.state_dict(),
        #                 'best_pred': self.best_pred,
        #             }, is_best)
        #     else:
        #         if self.args.no_val:
        #             # save checkpoint every epoch
        #             is_best = False
        #             self.saver.save_checkpoint({
        #                 'epoch': epoch + 1,
        #                 'state_dict': self.model.module.state_dict(),
        #                 'optimizer': self.optimizer.state_dict(),
        #                 'best_pred': self.best_pred,
        #             }, is_best)



    def validation(self, epoch):
        """Evaluate on the validation loader, dump debug images and checkpoint.

        The loss is ``criterion1`` only. Predictions are the argmax over axis 1
        of the second model output. Every 30th batch, an overlay of the first
        image with its prediction and a 3-channel rendering of the first model
        output are written to a hard-coded directory. Metrics are logged, and a
        checkpoint is saved when the evaluator's ``F0`` score beats
        ``self.best_pred``; in a distributed run only rank 0 does this.

        Args:
            epoch: Index of the current epoch; used for logging, file names
                and the checkpoint's ``epoch`` field.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['bin_label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion1(output, target)
            test_loss += loss.item()

            pred = np.argmax(output[1].data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            if i % 30 == 0:
                # NOTE(review): the dump directory is hard-coded and must
                # already exist; every process writes to it.
                tag = str(self.args.distributed)
                image_hwc = np.transpose(image[0].cpu().numpy(), (1, 2, 0))

                overlay_path = '/mfc/user/1623600/.temp6/{:s}_val_epoch:{}_i:{}.png'.format(tag, epoch, i)
                misc.imsave(overlay_path,
                            image_hwc + 3 * np.asarray(np.stack((pred[0], pred[0], pred[0]), axis=-1),
                                                       dtype=np.uint8))
                os.chmod(overlay_path, 0o777)

                # The first model output is only needed for this dump, so it
                # is copied to the CPU here rather than on every batch.
                instance_seg = np.squeeze(output[0].data.cpu().numpy()[0])
                instance_seg = np.transpose(instance_seg, (1, 2, 0))

                # Fold the 21 channels into three groups of seven, one group
                # per output colour channel, then rescale each to 0..255.
                temp_instance_seg = np.zeros_like(image_hwc)
                for j in range(21):
                    temp_instance_seg[:, :, j // 7] += instance_seg[:, :, j]
                for k in range(3):
                    temp_instance_seg[:, :, k] = self.minmax_scale(temp_instance_seg[:, :, k])
                instance_seg = np.array(temp_instance_seg, np.uint8)

                emb_path = '/mfc/user/1623600/.temp6/emb_{:s}_val_epoch:{}_i:{}.png'.format(tag, epoch, i)
                misc.imsave(emb_path, instance_seg[..., :3])
                os.chmod(emb_path, 0o777)

        # Non-zero ranks of a distributed run have no writer or saver.
        if self.args.distributed and self.args.local_rank != 0:
            return

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        F0 = self.evaluator.F0()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)

        if self.args.local_rank == 0:
            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
            print('Loss: %.3f' % test_loss)

        # F0, not mIoU, is the model-selection metric; a checkpoint is written
        # only when it improves.
        new_pred = F0
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def minmax_scale(self, input_arr):
        """Linearly rescale an array so its minimum maps to 0 and maximum to 255.

        :param input_arr: NumPy array to rescale; must be non-empty.
        :return: Array of the same shape. When every element is equal there
            is no range to stretch, so an all-zero array is returned instead
            of dividing by zero.
        """
        min_val = np.min(input_arr)
        max_val = np.max(input_arr)

        shifted = input_arr - min_val
        value_range = max_val - min_val
        if value_range == 0:
            # shifted is already all zeros; multiplying keeps the same dtype
            # promotion as the regular path.
            return shifted * 255.0

        return shifted * 255.0 / value_range
Пример #20
0
class Trainer(object):
    def __init__(self, args):
        """Build the saver, data loaders, model, optimizer and freezing state.

        Args:
            args: Run configuration. This constructor reads ``test``,
                ``workers``, ``lr``, ``momentum``, ``weight_decay``,
                ``lr_scheduler``, ``epochs``, ``cuda``, ``gpu_ids``,
                ``resume`` and ``ft``, and may overwrite ``start_epoch``.

        Raises:
            RuntimeError: If ``args.resume`` does not point to a file.
            NotImplementedError: If resuming is requested without CUDA.
        """
        self.args = args

        # Define saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define TensorBoard summary. No writer is created in test mode, so
        # self.writer does not exist when args.test is set.
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Network
        model = Model(args, self.nclass)

        train_params = [{'params': model.parameters(), 'lr': args.lr}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        criterion = nn.CrossEntropyLoss()

        self.model, self.optimizer, self.criterion = model, optimizer, criterion

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using CUDA
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("{}: No such checkpoint exists".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.cuda:
                # Partial load: only checkpoint entries whose key exists in
                # the current model are copied; everything else keeps its
                # freshly initialised value.
                state_dict = self.model.module.state_dict()
                matching = {
                    k: v
                    for k, v in checkpoint['state_dict'].items()
                    if k in state_dict
                }
                state_dict.update(matching)
                self.model.module.load_state_dict(state_dict)
            else:
                print("Please use CUDA")
                raise NotImplementedError

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("Loading {} (epoch {}) successfully done".format(
                args.resume, checkpoint['epoch']))

        # Fine-tuning restarts the epoch counter
        if args.ft:
            args.start_epoch = 0

        # layer wise freezing
        self.histories = []      # (epoch, count) pairs appended by train()
        self.history = {}        # parameter snapshots keyed by str(epoch) + name
        self.isTrained = False   # once True, train() returns without training
        self.freeze_count = 0
        # Number of parameter tensors in the bare (unwrapped) model
        self.total_count = sum(1 for _ in model.parameters())

    def train(self, epoch):
        """Train for one epoch, then optionally track and freeze parameters.

        After the optimisation pass, when ``args.freeze`` or ``args.time`` is
        set, every parameter is snapshotted on the CPU. From epoch 1 onwards,
        parameters whose name contains both ``'conv'`` and ``'layer'`` are
        compared with the previous epoch's snapshot; a mean absolute change
        below ``1e-04`` counts as "small" and, with ``args.freeze``, the
        parameter is frozen (and optionally Tucker-decomposed when
        ``args.decomp`` is set).

        Args:
            epoch: Index of the current epoch.
        """
        self.model.train()
        progress = tqdm(self.train_loader)
        batches_per_epoch = len(self.train_loader)
        if self.isTrained:
            return

        running_loss = 0.0
        for step, batch in enumerate(progress):
            image = batch['image'].cuda()
            target = batch['label'].cuda()
            self.scheduler(self.optimizer, step, epoch, self.best_pred)
            self.optimizer.zero_grad()
            prediction = self.model(image)
            loss = self.criterion(prediction, target)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            progress.set_description('Train loss: %.3f' %
                                     (running_loss / (step + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   step + batches_per_epoch * epoch)
        self.writer.add_scalar('train/total_loss_epoch', running_loss, epoch)
        print("[epoch: %d, loss: %.3f]" % (epoch, running_loss))

        # Number of tracked parameters whose change this epoch was small
        small_change_count = 0

        if self.args.freeze or self.args.time:
            current_tag = str(epoch)
            # NOTE(review): snapshots of every epoch stay in self.history, so
            # memory grows with the number of epochs.
            for param_name, param in self.model.named_parameters():
                self.history[current_tag + param_name] = param.cpu().detach()

            if epoch >= 1:
                previous_tag = str(epoch - 1)
                for param_name, param in self.model.named_parameters():
                    if not ('conv' in param_name and 'layer' in param_name):
                        continue
                    delta = (self.history[current_tag + param_name] -
                             self.history[previous_tag + param_name])
                    mean_change = np.abs(delta).mean()
                    if mean_change < 1e-04:
                        small_change_count += 1
                        if not (param.requires_grad and self.args.freeze):
                            continue
                        param.requires_grad = False
                        self.freeze_count += 1
                        if not self.args.decomp:
                            continue
                        # NOTE(review): when the model is wrapped in
                        # DataParallel the first name component is presumably
                        # 'module', so this lookup would never match; verify.
                        parts = param_name.split('.')
                        stage = parts[0]
                        if stage not in ('layer1', 'layer2', 'layer3',
                                         'layer4'):
                            continue
                        layer = getattr(self.model, stage)
                        slot = int(parts[1])
                        layer._modules[slot] = tucker_decomposition_conv_layer(
                            layer._modules[slot])
                if self.freeze_count == self.total_count:
                    self.isTrained = True

            # In timing-only mode record how many parameters barely moved;
            # otherwise record how many have been frozen so far.
            timing_only = not self.args.freeze and self.args.time
            recorded = small_change_count if timing_only else self.freeze_count
            self.histories.append((epoch, recorded))

        if self.args.no_val and not self.args.time:
            # No validation pass will run, so checkpoint here every epoch
            snapshot = {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }
            self.saver.save_checkpoint(snapshot, False)

    def val(self, epoch):
        """Evaluate top-1 / top-5 accuracy on the validation loader.

        Logs the summed loss and both accuracies to TensorBoard when a writer
        exists (none is created in test mode), prints them, and saves a
        checkpoint when top-1 accuracy beats ``self.best_pred``.

        Args:
            epoch: Index of the current epoch; used for logging and for the
                checkpoint's ``epoch`` field.
        """
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')

        self.model.eval()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()

            # top accuracy record, weighted by batch size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))

            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

        # Fast test during the training
        _top1 = top1.avg
        _top5 = top5.avg
        # The constructor only creates self.writer when args.test is unset, so
        # look it up defensively to keep evaluation usable in test mode.
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            writer.add_scalar('val/top1', _top1, epoch)
            writer.add_scalar('val/top5', _top5, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Top-1: %.3f, Top-5: %.3f" % (_top1, _top5))
        print('Loss: %.3f' % test_loss)

        new_pred = _top1
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = float(new_pred)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
class trainNew(object):
    def __init__(self, args):
        """Build everything needed to retrain a searched device/cloud model.

        Args:
            args: Run configuration. This constructor reads ``use_amp``,
                ``opt_level``, ``workers``, ``saved_arch_path``, ``lr``,
                ``momentum``, ``weight_decay``, ``nesterov``,
                ``use_balanced_weights``, ``dataset``, ``cuda``,
                ``loss_type``, ``lr_scheduler``, ``epochs``, ``gpu_ids``,
                ``resume``, ``clean_module``, ``load_parallel`` and ``ft``,
                and may overwrite ``start_epoch``.

        Raises:
            RuntimeError: If ``args.resume`` does not point to a file.
        """
        import re  # local import: only used for the torch version check below

        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        cell_path_d = os.path.join(args.saved_arch_path, 'genotype_device.npy')
        cell_path_c = os.path.join(args.saved_arch_path, 'genotype_cloud.npy')
        network_path_space = os.path.join(args.saved_arch_path, 'network_path_space.npy')

        new_cell_arch_d = np.load(cell_path_d)
        new_cell_arch_c = np.load(cell_path_c)
        # NOTE(review): the network path loaded from disk is immediately
        # replaced by the hard-coded list below; confirm which one is intended.
        new_network_arch = np.load(network_path_space)
        new_network_arch = [1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2]
        # Define network
        model = new_cloud_Model(network_arch=new_network_arch,
                                cell_arch_d=new_cell_arch_d,
                                cell_arch_c=new_cell_arch_c,
                                num_classes=self.nclass,
                                device_num_layers=6)

        # TODO: look into separate 1x / 10x learning-rate parameter groups
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator: one per model head
        self.evaluator_device = Evaluator(self.nclass)
        self.evaluator_cloud = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))  # TODO: use min_lr ?

        # Using cuda. training()/validation() move every batch to the GPU when
        # args.cuda is set, so the model has to live there as well. Module.cuda()
        # moves parameters in place, so the optimizer built above stays valid.
        if args.cuda:
            self.model = self.model.cuda()

        # Mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            # fix for pytorch < 1.3 with opt_level 'O1'. Compare numerically:
            # as strings, '1.10' would sort before '1.3'.
            version_match = re.match(r'(\d+)\.(\d+)', torch.__version__)
            is_old_torch = (version_match is not None and
                            tuple(int(part) for part in version_match.groups()) < (1, 3))
            if self.opt_level == 'O1' and is_old_torch:
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # This trainer has a single optimizer (there is no architecture
            # optimizer here), so pass and receive it directly.
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.clean_module:
                # The checkpoint keys carry DataParallel's 'module.' prefix:
                # strip it (only where present) and load into the bare model.
                prefix = 'module.'
                new_state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in checkpoint['state_dict'].items())
                if isinstance(self.model, torch.nn.DataParallel):
                    bare_model = self.model.module
                else:
                    bare_model = self.model
                bare_model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch over both model heads.

        The model returns a device output and a cloud output; the optimised
        loss is the sum of the criterion applied to each. When ``args.no_val``
        is set, a checkpoint is written at the end of the epoch.

        Args:
            epoch: Index of the current epoch; passed to the LR scheduler and
                used to compute the global logging step.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            device_output, cloud_output = self.model(image)

            device_loss = self.criterion(device_output, target)
            cloud_loss = self.criterion(cloud_output, target)
            loss = device_loss + cloud_loss
            if self.use_amp:
                # Backpropagate through apex's scaled loss
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            self.optimizer.step()
            train_loss += loss.item()

            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            # Log the per-iteration loss every 50 batches
            if i % 50 == 0:
                self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show inference results every 100 batches, using the average of
            # the two heads' outputs
            if i % 100 == 0:
                output = (device_output + cloud_output) / 2
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # The model is wrapped in DataParallel only for multi-GPU runs;
            # unwrap when needed so single-GPU/CPU runs can checkpoint too and
            # the saved keys never carry a 'module.' prefix.
            if isinstance(self.model, torch.nn.DataParallel):
                bare_model = self.model.module
            else:
                bare_model = self.model
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': bare_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate both model heads and checkpoint when their mean mIoU improves.

        Each head's predictions (argmax over axis 1) go to its own evaluator.
        The model-selection score is the average of the two mIoU values.

        Args:
            epoch: Index of the current epoch; used for logging and for the
                checkpoint's ``epoch`` field.
        """
        self.model.eval()
        self.evaluator_device.reset()
        self.evaluator_cloud.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                device_output, cloud_output = self.model(image)
            device_loss = self.criterion(device_output, target)
            cloud_loss = self.criterion(cloud_output, target)
            loss = device_loss + cloud_loss
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred_d = np.argmax(device_output.data.cpu().numpy(), axis=1)
            pred_c = np.argmax(cloud_output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator_device.add_batch(target, pred_d)
            self.evaluator_cloud.add_batch(target, pred_c)

        mIoU_d = self.evaluator_device.Mean_Intersection_over_Union()
        mIoU_c = self.evaluator_cloud.Mean_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/device/mIoU', mIoU_d, epoch)
        self.writer.add_scalar('val/cloud/mIoU', mIoU_c, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("device_mIoU:{}, cloud_mIoU: {}".format(mIoU_d, mIoU_c))
        print('Loss: %.3f' % test_loss)

        new_pred = (mIoU_d + mIoU_c) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The model is wrapped in DataParallel only for multi-GPU runs;
            # unwrap when needed so single-GPU/CPU runs can checkpoint too and
            # the saved keys never carry a 'module.' prefix.
            if isinstance(self.model, torch.nn.DataParallel):
                bare_model = self.model.module
            else:
                bare_model = self.model
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': bare_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #22
0
class Trainer(object):
    """Inference-only harness around a DeepLab model.

    Builds the experiment saver, the TensorBoard writer, the data loaders
    and the network, optionally restores a checkpoint, and shows the argmax
    prediction of every test image with matplotlib.
    """

    def __init__(self, args):
        """Create the model and restore ``args.resume`` when it is given.

        Args:
            args: option namespace. Reads ``workers``, ``backbone``,
                ``out_stride``, ``sync_bn``, ``freeze_bn``, ``cuda``,
                ``gpu_ids``, ``resume`` and ``ft``; writes ``start_epoch``
                when resuming or fine-tuning.

        Raises:
            RuntimeError: ``args.resume`` is set but is not an existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # The DataParallel wrapper only exists on the CUDA path.
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            # This class never builds an optimizer, so restoring optimizer
            # state unconditionally raised AttributeError. Only restore it
            # when something (e.g. a subclass) has attached one.
            optimizer = getattr(self, 'optimizer', None)
            if not args.ft and optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def testing(self, epoch):
        """Run the model over the test loader and display each prediction.

        Args:
            epoch: accepted for interface compatibility; not used.
        """
        self.model.eval()
        tbar = tqdm(self.test_loader, desc='\r')
        for sample in tbar:
            # Only the image is needed; the label was read and discarded.
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()
            # Pure inference: do not record an autograd graph.
            with torch.no_grad():
                output = self.model(image)
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            for p in pred:
                plt.figure()
                plt.imshow(p)
                plt.show()
Пример #23
0
class Trainer(object):
    def __init__(self, args):
        """Assemble everything needed to train and evaluate DeepLab.

        Creates the experiment saver and TensorBoard writer, the data
        loaders, the network with a two-group SGD optimizer, the loss, the
        evaluator and the LR scheduler. The model is wrapped in
        ``DataParallel`` when ``args.cuda`` is set, and a checkpoint is
        restored when ``args.resume`` is given.

        Args:
            args: option namespace; ``start_epoch`` is written when resuming
                or fine-tuning.

        Raises:
            RuntimeError: ``args.resume`` is set but is not an existing file.
        """
        self.args = args

        # Experiment bookkeeping: run directory and TensorBoard writer.
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Data loaders; num_workers is pinned to 0 here instead of being
        # taken from args.workers.
        loader_options = dict(num_workers=0, pin_memory=True)
        (self.train_loader, self.val_loader, self.test_loader,
         self.nclass) = make_data_loader(args, **loader_options)

        # Four input channels when args.nir is set, otherwise three.
        input_channels = 4 if args.nir else 3

        network = DeepLab(num_classes=self.nclass,
                          backbone=args.backbone,
                          in_channels=input_channels,
                          output_stride=args.out_stride,
                          sync_bn=args.sync_bn,
                          freeze_bn=args.freeze_bn)

        # Two parameter groups: base learning rate and ten times that.
        param_groups = [
            {'params': network.get_1x_lr_params(), 'lr': args.lr},
            {'params': network.get_10x_lr_params(), 'lr': args.lr * 10},
        ]
        sgd = torch.optim.SGD(param_groups,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=args.nesterov)

        # Optional class-balanced weights for the loss.
        class_weight = None
        if args.use_balanced_weights:
            weights_file = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if not os.path.isfile(weights_file):
                class_weight = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass)
            else:
                class_weight = np.load(weights_file)
                # The first three stored weights are overridden by hand.
                class_weight[0] = 1
                class_weight[1] = 4
                class_weight[2] = 2
            class_weight = torch.from_numpy(class_weight.astype(np.float32))

        loss_factory = SegmentationLosses(weight=class_weight, cuda=args.cuda)
        self.criterion = loss_factory.build_loss(mode=args.loss_type)
        self.model = network
        self.optimizer = sgd

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Multi-GPU wrapping happens only on the CUDA path.
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Restore a previous run if requested.
        self.best_pred = 0.0
        resume_path = args.resume
        if resume_path is not None:
            if not os.path.isfile(resume_path):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(resume_path))
            state = torch.load(resume_path)
            args.start_epoch = state['epoch']
            # Weights live under .module once DataParallel is applied.
            weight_owner = self.model.module if args.cuda else self.model
            weight_owner.load_state_dict(state['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(state['optimizer'])
            self.best_pred = state['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, state['epoch']))

        # Fine-tuning always restarts the epoch counter.
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch, log it, and optionally checkpoint.

        Per batch: adjusts the learning rate via the scheduler, does a
        forward/backward/step, logs the iteration loss, and about ten times
        per epoch sends a visualisation to TensorBoard. After the loop the
        summed loss is logged and printed. When ``args.no_val`` is set a
        (non-best) checkpoint is written every epoch.

        Args:
            epoch: index of the current epoch; used by the scheduler, as the
                TensorBoard step base and in the saved checkpoint.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise roughly ten times per epoch. Clamp to 1 so that a loader
        # with fewer than ten batches does not trigger a modulo by zero.
        vis_interval = max(num_img_tr // 10, 1)
        # Counted while iterating so the summary below does not depend on
        # loop variables that are unbound when the loader is empty.
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            global_step = i + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   global_step)

            if i % vis_interval == 0:
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # The DataParallel wrapper (and its .module) only exists when
            # the constructor took the CUDA path.
            net = self.model.module if self.args.cuda else self.model
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def pred_single_image(self, path):
        """Segment one image, record ReLU activations and write result files.

        Writes three files into the experiment directory: ``rgb.png`` (the
        input resized to 513x513), ``lbl.png`` (the label image read with
        OpenCV) and ``prediction.png`` (argmax map with class 1 drawn as 255
        and class 2 as 128). The outputs of every ``nn.ReLU`` module during
        the forward pass are collected and handed to ``Analysis``.

        Args:
            path: path of the input image. The label is looked up under the
                same file name in the ``lbl`` directory that sits next to
                the image's own directory.
        """
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        activations = collections.defaultdict(list)

        def save_activation(name, mod, inp, out):
            # Forward hook: keep a CPU copy of the module output.
            activations[name].append(out.cpu())

        rgb = cv2.imread(path)
        label = cv2.imread(lbl_path)
        rgb = cv2.resize(rgb, (513, 513), interpolation=cv2.INTER_CUBIC)
        image = Image.open(img_path).convert('RGB')  # width x height x 3
        # NOTE(review): the mask passed through the transforms is built from
        # the image file, not from lbl_path (the lbl_path variant was
        # commented out in the original). It is not used after the
        # transforms - confirm this is intended.
        _tmp = np.array(Image.open(img_path), dtype=np.uint8)
        _tmp[_tmp == 255] = 1
        _tmp[_tmp == 128] = 2
        _tmp = Image.fromarray(_tmp)

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=mean, std=std),
            tr.ToTensor()
        ])
        sample = composed_transforms({'image': image, 'label': _tmp})

        image = torch.unsqueeze(sample['image'], dim=0)
        if self.args.cuda:
            image = image.cuda()

        # Register the hooks only around the forward pass and always remove
        # them afterwards. Previously every call left another set of hooks
        # on the model, so later forward passes kept recording activations.
        hooks = [
            m.register_forward_hook(partial(save_activation, name))
            for name, m in self.model.named_modules()
            if isinstance(m, nn.ReLU)
        ]
        try:
            with torch.no_grad():
                output = self.model(image)
        finally:
            for hook in hooks:
                hook.remove()

        output = output.data.cpu().numpy()
        prediction = np.squeeze(np.argmax(output, axis=1), axis=0)
        # Map class ids to grey levels for the saved mask.
        prediction[prediction == 1] = 255
        prediction[prediction == 2] = 128
        print(np.unique(prediction))

        see = Analysis(activations, label=1, path=self.saver.experiment_dir)
        see.backtrace(output)

        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'rgb.png'), rgb)
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'lbl.png'), label)
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'prediction.png'),
                    prediction)

    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on improvement.

        Accumulates the loss and feeds argmax predictions into the
        evaluator, then logs and prints pixel accuracy, per-class accuracy,
        mIoU and frequency-weighted IoU. When mIoU beats ``self.best_pred``
        the value is stored and a checkpoint flagged as best is saved.

        Args:
            epoch: index of the current epoch; used as the TensorBoard step
                and in the saved checkpoint.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        # Counted while iterating so the summary below does not depend on
        # loop variables that are unbound when the loader is empty.
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The DataParallel wrapper (and its .module) only exists when
            # the constructor took the CUDA path.
            net = self.model.module if self.args.cuda else self.model
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def testing(self):
        """Evaluate on the test loader and print the aggregate metrics.

        Feeds argmax predictions for every batch into the evaluator, prints
        mIoU, frequency-weighted IoU, class accuracy and pixel accuracy, and
        finally calls the evaluator's per-class accuracy report.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            # The side-by-side array and the transposed input image that
            # used to be built here fed a commented-out cv2.imshow and were
            # never read, so that per-batch work has been dropped.
            print(pred.shape)

        # Fast test during the testing
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print(
            '[INFO] Network performance measures on the test dataset are as follows: \n '
            'mIOU: {} \n FWIOU: {} \n Class accuracy: {} \n Pixel Accuracy: {}'
            .format(mIoU, FWIoU, Acc_class, Acc))

        self.evaluator.per_class_accuracy()

    def explain_image(self, path, counter):
        """Load and preprocess one image/label pair (explanation is a stub).

        Opens the image at ``path`` and the label with the same file name in
        the neighbouring ``lbl`` directory, remaps label values 255 and 128
        to 1 and 2, and runs both through the resize / normalise / to-tensor
        pipeline. No inference is performed; the method finishes by printing
        ``'hold'``.

        Args:
            path: path of the input image.
            counter: accepted but not used.
        """
        self.model.eval()

        image_dir, file_name = os.path.split(path)
        label_file = os.path.join(os.path.split(image_dir)[0], 'lbl',
                                  file_name)

        rgb_image = Image.open(path).convert('RGB')  # width x height x 3
        mask = np.array(Image.open(label_file), dtype=np.uint8)
        mask[mask == 255] = 1
        mask[mask == 128] = 2
        mask = Image.fromarray(mask)

        pipeline = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        sample = pipeline({'image': rgb_image, 'label': mask})

        image, _ = sample['image'], sample['label']
        # Add the batch dimension expected by the network.
        image = torch.unsqueeze(image, dim=0)

        # The forward pass and the LRP / guided-backprop experiments that
        # used to follow were disabled (commented out) in the original, so
        # the prepared tensor is currently not consumed.
        print('hold')
Пример #24
0
class trainNew(object):
    def __init__(self, args):
        """Build the segmentation network, the EDM and its training data.

        Creates the saver, TensorBoard writer and data loaders, instantiates
        the network selected by ``args.network`` / ``args.C``, creates the
        EDM with an Adam optimizer over the EDM's parameters only, restores
        ``args.resume`` if given, and finally either loads cached
        ``feature.npy`` / ``entropy.npy`` or generates them via
        ``make_data``.

        Args:
            args: option namespace; ``start_epoch`` is written when resuming.

        Raises:
            ValueError: ``args.network`` or ``args.C`` is not supported
                (previously surfaced as an unbound-variable error).
            RuntimeError: ``args.resume`` is set but is not an existing file.
        """
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # Define Dataloader
        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loader, self.val_loader, _, self.nclass = make_data_loader(
            args, **kwargs)

        self.criterion = nn.L1Loss()

        # C_index for each supported args.C; identical for every network.
        c_index_by_c = {2: [5], 3: [3, 7], 4: [2, 5, 8]}
        # Layer layout used by 'searched-dense', again keyed by args.C.
        searched_arch_by_c = {
            2: [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2],
            3: [1, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3],
            4: [1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2],
        }
        supported_networks = ('searched-dense', 'autodeeplab-dense',
                              'autodeeplab-baseline')
        # Fail with a clear message instead of an unbound local later on.
        if args.network not in supported_networks:
            raise ValueError('unsupported network: {!r}'.format(args.network))
        if args.C not in c_index_by_c:
            raise ValueError('unsupported C: {!r}'.format(args.C))

        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab',
                                 'genotype.npy')
        cell_arch = np.load(cell_path)
        C_index = c_index_by_c[args.C]
        if args.network == 'searched-dense':
            network_arch = searched_arch_by_c[args.C]
            low_level_layer = 0
        else:
            network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
            low_level_layer = 2

        if args.network == 'autodeeplab-baseline':
            model = Baselin_Model(network_arch, C_index, cell_arch,
                                  self.nclass, args, low_level_layer)
        else:
            model = ADD(network_arch, C_index, cell_arch, self.nclass, args,
                        low_level_layer)

        # Move the EDM to the GPU only when CUDA is requested; training()
        # moves its inputs under the same condition.
        self.edm = EDM()
        if args.cuda:
            self.edm = self.edm.cuda()
        optimizer = torch.optim.Adam(self.edm.parameters(), lr=args.lr)
        self.model, self.optimizer = model, optimizer

        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # If the weights are wrapped in a module object, strip the prefix.
            if args.clean_module:
                # NOTE(review): this first load uses the un-stripped keys and
                # is followed by a load of the stripped keys; one of the two
                # looks redundant - confirm which key format the checkpoints
                # actually have.
                self.model.load_state_dict(checkpoint['state_dict'])
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                # This constructor never wraps the model in DataParallel, so
                # only go through .module when that attribute really exists.
                target = self.model
                if torch.cuda.device_count() > 1 and hasattr(
                        self.model, 'module'):
                    target = self.model.module
                copy_state_dict(target.state_dict(), checkpoint['state_dict'])

        # Reuse cached EDM training data when present, else generate it.
        if os.path.isfile('feature.npy'):
            train_feature = np.load('feature.npy')
            train_entropy = np.load('entropy.npy')
            train_set = TensorDataset(
                torch.tensor(train_feature),
                torch.tensor(train_entropy, dtype=torch.float))
            train_set = DataLoader(train_set,
                                   batch_size=self.args.train_batch,
                                   shuffle=True,
                                   pin_memory=True)
            self.train_set = train_set
        else:
            self.make_data(self.args.train_batch)

    def make_data(self, batch_size):
        """Generate, cache and wrap the EDM training set.

        Runs the segmentation model over the training loader without
        gradients, collecting one feature array and one normalized Shannon
        entropy value per batch. Both collections are saved to ``feature.npy``
        and ``entropy.npy`` in the working directory and exposed as a
        shuffled ``DataLoader`` in ``self.train_set``.

        Args:
            batch_size: batch size of the resulting ``DataLoader``.
        """
        self.model.eval()
        tbar = tqdm(self.train_loader, desc='\r')
        train_feature = []
        train_entropy = []
        for sample in tbar:
            # Only the image is needed; the label is not used here.
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()

            with torch.no_grad():
                output, feature = self.model.get_feature(image)
                train_entropy.append(normalized_shannon_entropy(output))
                train_feature.append(feature.cpu().numpy())

        # Stack once and reuse the array for both the cache file and the
        # tensor; building a tensor from a list of ndarrays is the slow path.
        feature_array = np.asarray(train_feature)
        np.save('feature', feature_array)
        np.save('entropy', train_entropy)
        train_set = TensorDataset(
            torch.tensor(feature_array, dtype=torch.float),
            torch.tensor(train_entropy, dtype=torch.float))
        train_set = DataLoader(train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True)
        self.train_set = train_set

    def training(self, epoch):
        """Train the EDM for one pass over ``self.train_set``.

        Each batch of (feature, entropy) pairs is pushed through the EDM and
        optimized against the criterion. The summed loss is written to
        TensorBoard under ``train/total_loss_epoch`` and printed.

        Args:
            epoch: index of the current epoch; used as the TensorBoard step.
        """
        self.edm.train()
        running_loss = 0.0
        progress = tqdm(self.train_set)
        for step, (feature, entropy) in enumerate(progress, start=1):
            self.optimizer.zero_grad()
            if self.args.cuda:
                feature = feature.cuda()
                entropy = entropy.cuda()
            batch_loss = self.criterion(self.edm(feature), entropy)
            batch_loss.backward()
            self.optimizer.step()
            running_loss += batch_loss.item()
            # Show the running mean loss on the progress bar.
            progress.set_description('Train loss: %.3f' %
                                     (running_loss / step))
        self.writer.add_scalar('train/total_loss_epoch', running_loss, epoch)
        print('[Epoch: %d' % (epoch))
        print('Loss: %.3f' % running_loss)
class Trainer(object):
    def __init__(self, args):
        """Build the segmentation and domain-classifier models and optimizers.

        Creates the saver, TensorBoard writer and data loaders, then four
        sub-networks (MobileNetV2 backbone, ASPP, decoder, domain
        classifier) and four optimizers over different parameter subsets.
        Also sets up the task and domain losses, the evaluator and the LR
        scheduler, wraps the models in ``DataParallel`` when ``args.cuda``
        is set, and restores ``args.resume`` if given.

        Args:
            args: option namespace; ``start_epoch`` is written when resuming
                or fine-tuning.

        Raises:
            NotImplementedError: ``args.optimizer`` is neither ``'SGD'`` nor
                ``'Adam'``.
            RuntimeError: ``args.resume`` is set but is not an existing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        BN = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        # Parameter subsets: feature extractor (backbone + ASPP), decoder,
        # and domain classifier.
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            def make_optimizer(params):
                return torch.optim.SGD(params, lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay,
                                       nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            def make_optimizer(params):
                return torch.optim.Adam(params, lr=args.lr)
        else:
            raise NotImplementedError

        self.task_optimizer = make_optimizer(f_params + y_params)
        self.d_optimizer = make_optimizer(d_params)
        self.d_inv_optimizer = make_optimizer(f_params)
        self.c_optimizer = make_optimizer(f_params + y_params)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # Build the path with os.path.join; the previous hard-coded
            # backslashes only formed a directory path on Windows.
            # NOTE(review): directory name 'dataloders' is kept exactly as
            # spelled in the original - confirm it matches the repository.
            classes_weights_path = os.path.join(
                'dataloders', 'datasets',
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        # Placeholder; no loss object is created for this attribute here.
        self.ca_loss = ''

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Weights live under .module once DataParallel has been applied.
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            # The stored best score is only carried over for 'gtav'.
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Train all sub-networks for one epoch.

        For the ``'gtav'`` dataset each batch provides only a source image and
        label, and only the task loss is optimised.  For every other dataset
        the batch also provides a target image, and two domain losses computed
        by ``self.domain_loss`` are added to the task loss before the backward
        pass.

        Args:
            epoch: Index of the current epoch; forwarded to the LR scheduler
                and used to derive the global TensorBoard step.
        """
        train_loss = 0.0
        train_task_loss = 0.0
        train_d_loss = 0.0
        train_d_inv_loss = 0.0
        self.backbone_model.train()
        self.assp_model.train()
        self.y_model.train()
        self.d_model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise about ten times per epoch.  Clamp to 1 so that a loader
        # with fewer than ten batches does not cause a modulo-by-zero.
        vis_interval = max(num_img_tr // 10, 1)
        has_target = self.args.dataset != 'gtav'
        for i, sample in enumerate(tbar):
            if has_target:
                src_image = sample['src_image']
                src_label = sample['src_label']
                tgt_image = sample['tgt_image']
            else:
                src_image, src_label = sample['image'], sample['label']
            if self.args.cuda:
                src_image, src_label = src_image.cuda(), src_label.cuda()
                if has_target:
                    tgt_image = tgt_image.cuda()

            # NOTE(review): c_optimizer is scheduled and zeroed below but is
            # never stepped in this method -- confirm that this is intended.
            self.scheduler(self.task_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_inv_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.c_optimizer, i, epoch, self.best_pred)
            self.task_optimizer.zero_grad()
            self.d_optimizer.zero_grad()
            self.d_inv_optimizer.zero_grad()
            self.c_optimizer.zero_grad()

            # Source image features and segmentation output, resized to the
            # spatial size of the input image.
            src_high_feature_0, src_low_feature = self.backbone_model(src_image)
            src_high_feature = self.assp_model(src_high_feature_0)
            src_output = F.interpolate(self.y_model(src_high_feature, src_low_feature),
                                       src_image.size()[2:],
                                       mode='bilinear', align_corners=True)

            src_d_pred = self.d_model(src_high_feature)
            task_loss = self.task_loss(src_output, src_label)

            if has_target:
                # Target image features go through the same networks.
                tgt_high_feature_0, tgt_low_feature = self.backbone_model(tgt_image)
                tgt_high_feature = self.assp_model(tgt_high_feature_0)
                tgt_output = F.interpolate(self.y_model(tgt_high_feature, tgt_low_feature),
                                           tgt_image.size()[2:],
                                           mode='bilinear', align_corners=True)
                tgt_d_pred = self.d_model(tgt_high_feature)

                # The "inverse" loss is the same criterion with the source and
                # target predictions swapped.
                d_loss, d_acc = self.domain_loss(src_d_pred, tgt_d_pred)
                d_inv_loss, _ = self.domain_loss(tgt_d_pred, src_d_pred)
                loss = task_loss + d_loss + d_inv_loss
                loss.backward()
                self.task_optimizer.step()
                self.d_optimizer.step()
                self.d_inv_optimizer.step()
            else:
                # Source-only training: zero placeholders keep the running
                # statistics below uniform across both modes.
                d_acc = 0
                d_loss = torch.tensor(0.0)
                d_inv_loss = torch.tensor(0.0)
                loss = task_loss
                loss.backward()
                self.task_optimizer.step()

            train_task_loss += task_loss.item()
            train_d_loss += d_loss.item()
            train_d_inv_loss += d_inv_loss.item()
            train_loss += task_loss.item() + d_loss.item() + d_inv_loss.item()

            tbar.set_description('Train loss: %.3f t_loss: %.3f d_loss: %.3f , d_inv_loss: %.3f  d_acc: %.2f'
                                 % (train_loss / (i + 1), train_task_loss / (i + 1),
                                    train_d_loss / (i + 1), train_d_inv_loss / (i + 1), d_acc * 100))

            self.writer.add_scalar('train/task_loss_iter', task_loss.item(), i + num_img_tr * epoch)
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                if has_target:
                    image = torch.cat([src_image, tgt_image], dim=0)
                    output = torch.cat([src_output, tgt_output], dim=0)
                else:
                    image = src_image
                    output = src_output
                self.summary.visualize_image(self.writer, self.args.dataset, image, src_label, output, global_step)

        # Epoch-level values are running sums over batches, not means.
        self.writer.add_scalar('train/task_loss_epoch', train_task_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + src_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # Without validation there is no "best" model, so save a regular
            # checkpoint every epoch.
            # NOTE(review): `.module` assumes every sub-model is wrapped
            # (e.g. in DataParallel) -- confirm against __init__.
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'backbone_model_state_dict': self.backbone_model.module.state_dict(),
                'assp_model_state_dict': self.assp_model.module.state_dict(),
                'y_model_state_dict': self.y_model.module.state_dict(),
                'd_model_state_dict': self.d_model.module.state_dict(),
                'task_optimizer': self.task_optimizer.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
                'c_optimizer': self.c_optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and checkpoint on a new best mIoU.

        All sub-networks are put into eval mode, every validation batch is
        fed to ``self.evaluator``, the resulting metrics are written to
        TensorBoard and printed, and a checkpoint flagged as best is saved
        when the mIoU exceeds ``self.best_pred``.

        Args:
            epoch: Index of the current epoch; used as the TensorBoard step
                and stored (plus one) in the checkpoint.
        """
        for sub_model in (self.backbone_model, self.assp_model,
                          self.y_model, self.d_model):
            sub_model.eval()
        self.evaluator.reset()

        progress = tqdm(self.val_loader, desc='\r')
        running_loss = 0.0
        for step, batch in enumerate(progress):
            image = batch['image']
            target = batch['label']
            if self.args.cuda:
                image = image.cuda()
                target = target.cuda()

            # Forward pass only; no graph is needed for evaluation.
            with torch.no_grad():
                backbone_feature, low_feature = self.backbone_model(image)
                aspp_feature = self.assp_model(backbone_feature)
                logits = self.y_model(aspp_feature, low_feature)
                output = F.interpolate(logits, image.size()[2:],
                                       mode='bilinear', align_corners=True)

            running_loss += self.task_loss(output, target).item()
            progress.set_description('Test loss: %.3f' % (running_loss / (step + 1)))

            # The class prediction is the arg-max over axis 1 of the output.
            scores = output.data.cpu().numpy()
            labels = target.cpu().numpy()
            self.evaluator.add_batch(labels, np.argmax(scores, axis=1))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, _ = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        epoch_scalars = (
            ('val/total_loss_epoch', running_loss),
            ('val/mIoU', mIoU),
            ('val/Acc', Acc),
            ('val/Acc_class', Acc_class),
            ('val/fwIoU', FWIoU),
        )
        for tag, value in epoch_scalars:
            self.writer.add_scalar(tag, value, epoch)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, step * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % running_loss)

        # Written as a negated ">" so that a NaN score is treated as
        # "not better", exactly like a plain `if mIoU > best` check.
        if not (mIoU > self.best_pred):
            return

        self.best_pred = mIoU
        best_state = {
            'epoch': epoch + 1,
            'backbone_model_state_dict': self.backbone_model.module.state_dict(),
            'assp_model_state_dict': self.assp_model.module.state_dict(),
            'y_model_state_dict': self.y_model.module.state_dict(),
            'd_model_state_dict': self.d_model.module.state_dict(),
            'task_optimizer': self.task_optimizer.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
            'c_optimizer': self.c_optimizer.state_dict(),
            'best_pred': self.best_pred,
        }
        self.saver.save_checkpoint(best_state, True)
class Trainer(object):
    """Training / evaluation driver for a DeepLab model.

    The model is called with two inputs per batch, ``sample['image']`` and
    ``sample['rp']``, and is trained against ``sample['label']``.
    """

    def __init__(self, args):
        """Build data loaders, model, optimizer and optionally resume.

        Args:
            args: Parsed options.  Attributes read here include ``workers``,
                ``backbone``, ``out_stride``, ``sync_bn``, ``freeze_bn``,
                ``lr``, ``momentum``, ``weight_decay``, ``nesterov``,
                ``cuda``, ``gpu_ids``, ``lr_scheduler``, ``epochs``,
                ``resume``, ``ft`` and ``mode``.  ``args.start_epoch`` is
                overwritten when resuming or fine-tuning.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: the second one trains with a 10x learning rate.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        self.criterion = SegmentationLosses(cuda=args.cuda)
        self.model, self.optimizer = model, optimizer
        self.contexts = TemporalContexts(history_len=5)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints store the bare model's weights, so load them into
            # the bare model whether or not it is wrapped in DataParallel.
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning or in validation/test mode
        if args.ft or args.mode in ("val", "test"):
            args.start_epoch = 0
            self.best_pred = 0.0

    def _unwrapped_model(self):
        """Return the bare network, stripping a DataParallel wrapper if any."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def training(self, epoch):
        """Run one training epoch and optionally save a per-epoch checkpoint.

        Args:
            epoch: Index of the current epoch; forwarded to the LR scheduler
                and used to derive the global TensorBoard step.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Plot about twenty times per epoch.  Clamp to 1 so that loaders with
        # fewer than twenty batches do not cause a modulo-by-zero.
        plot_interval = max(num_img_tr // 20, 1)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image, region_prop)
            # Class weights are recomputed for every batch.
            loss = self.criterion.CrossEntropyLoss(
                output,
                target,
                weight=torch.from_numpy(
                    calculate_weights_batch(sample,
                                            self.nclass).astype(np.float32)))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            if i % plot_interval == 0:
                # The CPU copy, softmax and arg-max are only needed for the
                # plot, so compute them only on plotting iterations.
                scores = output.detach().cpu()
                pred_softmax = F.softmax(scores, dim=1).numpy()
                pred = np.argmax(scores.numpy(), axis=1)

                global_step = i + num_img_tr * epoch
                self.summary.vis_grid(self.writer,
                                      self.args.dataset,
                                      image.data.cpu().numpy()[0],
                                      target.data.cpu().numpy()[0],
                                      pred[0],
                                      region_prop.data.cpu().numpy()[0],
                                      pred_softmax[0],
                                      global_step,
                                      split="Train")

        self.writer.add_scalar('train/total_loss_epoch',
                               train_loss / num_img_tr, epoch)
        print('Loss: {}'.format(train_loss / num_img_tr))

        if self.args.no_val or self.args.save_all:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self._unwrapped_model().state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                is_best,
                filename='checkpoint_' + str(epoch + 1) + '_.pth.tar')

    def validation(self, epoch):
        """Evaluate the model and, in train mode, checkpoint on a new best.

        The validation loader is used when ``args.mode`` is ``"train"`` or
        ``"val"``, the test loader when it is ``"test"``.

        Args:
            epoch: Index of the current epoch; used as the TensorBoard step.

        Raises:
            ValueError: If ``args.mode`` is not one of the three modes above.
        """
        if self.args.mode in ("train", "val"):
            loader = self.val_loader
        elif self.args.mode == "test":
            loader = self.test_loader
        else:
            raise ValueError(
                "unsupported mode '{}'; expected 'train', 'val' or 'test'".format(
                    self.args.mode))

        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(loader, desc='\r')

        idr_thresholds = [0.20, 0.30, 0.40, 0.50, 0.60, 0.65]

        num_itr = len(loader)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            # NOTE: temporal propagation through self.contexts is disabled.
            # orig_region_prop = region_prop.clone()
            # region_prop = self.contexts.temporal_prop(image.numpy(),region_prop.numpy())

            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()
            with torch.no_grad():
                output = self.model(image, region_prop)

            output = output.detach().data.cpu()
            pred_softmax = F.softmax(output, dim=1).numpy()
            pred = np.argmax(pred_softmax, axis=1)
            target = target.cpu().numpy()
            image = image.cpu().numpy()
            region_prop = region_prop.cpu().numpy()
            # orig_region_prop = orig_region_prop.numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # Append buffer with original context(before temporal propagation)
            # self.contexts.append_buffer(image[0],orig_region_prop[0],pred[0])

            # The first sample of every batch is plotted.
            global_step = i + num_itr * epoch
            self.summary.vis_grid(self.writer,
                                  self.args.dataset,
                                  image[0],
                                  target[0],
                                  pred[0],
                                  region_prop[0],
                                  pred_softmax[0],
                                  global_step,
                                  split="Validation")

        # Fast test during the training
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        recall, _precision = self.evaluator.pdr_metric(class_id=2)
        idr_avg = np.array([
            self.evaluator.get_idr(class_value=2, threshold=value)
            for value in idr_thresholds
        ])
        false_idr = self.evaluator.get_false_idr(class_value=2)
        instance_iou = self.evaluator.get_instance_iou(threshold=0.20,
                                                       class_value=2)

        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Recall/per_epoch', recall, epoch)
        self.writer.add_scalar('IDR/per_epoch(0.20)', idr_avg[0], epoch)
        self.writer.add_scalar('IDR/avg_epoch', np.mean(idr_avg), epoch)
        self.writer.add_scalar('False_IDR/epoch', false_idr, epoch)
        self.writer.add_scalar('Instance_IOU/epoch', instance_iou, epoch)
        self.writer.add_histogram(
            'Prediction_hist',
            self.evaluator.pred_labels[self.evaluator.gt_labels == 2], epoch)

        print('Validation:')
        print('IDR:{}'.format(idr_avg[0]))
        print('False Positive Rate: {}'.format(false_idr))
        print('Instance_IOU: {}'.format(instance_iou))

        # Only training runs track and save the best model.
        if self.args.mode == "train":
            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self._unwrapped_model().state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)
Пример #27
0
def main():
    """Run the meta-learning search loop configured by the global ``args``.

    Seeds the RNGs, sets up file and console logging in the experiment
    directory, builds the ``Meta`` learner with its three MiniImagenet
    loaders, then for every epoch runs ``meta_train`` and ``meta_test``,
    logs the current genotype and saves a checkpoint, flagged as best when
    the last reported test accuracy improves.

    Exits with status 1 when no CUDA device is available.
    """
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Select the device only after the availability check above: doing it
    # first would raise on a machine without CUDA before the log-and-exit
    # path could run.
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)

    # Number of scalar elements across all trainable parameters.
    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    # All three datasets are built with mode='train'; the validation and test
    # sets share the split [train_portion, 1] and differ only in batch_size.
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size,
                        split=[0, args.train_portion])
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.batch_size,
                              resize=args.img_size,
                              split=[args.train_portion, 1])
    mini_test = MiniImagenet(args.data_path,
                             mode='train',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size,
                             split=[args.train_portion, 1])
    train_queue = DataLoader(mini,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    valid_queue = DataLoader(mini_valid,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    test_queue = DataLoader(mini_test,
                            args.meta_test_batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)
    architect = Architect(maml.model, args)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_queue, valid_queue, maml, architect,
                                device, criterion, epoch, writer)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
        test_accs = meta_test(test_queue, maml, device, epoch, writer)
        logging.info('[Epoch: {}]\t Test acc: {}'.format(epoch, test_accs))

        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the best meta model, judged by the last reported test accuracy.
        new_pred = test_accs[-1]
        is_best = new_pred > best_pred
        if is_best:
            best_pred = new_pred
        # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
        model_state = (maml.module.state_dict()
                       if isinstance(maml, nn.DataParallel) else
                       maml.state_dict())
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_state,
                'best_pred': best_pred,
            }, is_best)
Пример #28
0
class Trainer(object):
    """Config-driven trainer for a DeepLab segmentation model.

    All settings are read from a nested ``config`` mapping (sections
    ``training``, ``network``, ``image`` and ``dataset``).
    """

    def __init__(self, config):
        """Build loaders, model, optimizer and optionally restore weights.

        Args:
            config: Nested mapping of settings.  When pretrained weights are
                restored, ``config['training']['start_epoch']`` is
                overwritten with the epoch stored in the checkpoint.

        Raises:
            RuntimeError: If pretrained weights are requested but the
                ``restore_from`` path is not a file.
        """
        self.config = config
        self.best_pred = 0.0

        train_cfg = self.config['training']
        net_cfg = self.config['network']
        data_cfg = self.config['dataset']
        use_cuda = net_cfg['use_cuda']

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(train_cfg['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()

        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=net_cfg['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=net_cfg['sync_bn'],
                        freeze_bn=net_cfg['freeze_bn'])

        # Two parameter groups: the second one trains with a 10x learning rate.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': train_cfg['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': train_cfg['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=train_cfg['momentum'],
                                    weight_decay=train_cfg['weight_decay'],
                                    nesterov=train_cfg['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if train_cfg['use_balanced_weights']:
            classes_weights_path = os.path.join(
                data_cfg['base_path'],
                data_cfg['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, data_cfg['dataset_name'],
                                                  self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=use_cuda).build_loss(mode=train_cfg['loss_type'])
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(train_cfg['lr_scheduler'], train_cfg['lr'],
                                      train_cfg['epochs'], len(self.train_loader))

        # Using cuda
        if use_cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        init_cfg = train_cfg['weights_initialization']
        if init_cfg['use_pretrained_weights']:
            restore_path = init_cfg['restore_from']
            if not os.path.isfile(restore_path):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(restore_path))

            if use_cuda:
                checkpoint = torch.load(restore_path)
            else:
                # Remap tensors from any CUDA device, not only cuda:0, so a
                # checkpoint written on another GPU also loads on CPU.
                checkpoint = torch.load(restore_path, map_location='cpu')

            self.config['training']['start_epoch'] = checkpoint['epoch']

            # Checkpoints are written from self.model as-is (see training /
            # validation), so they are loaded into self.model as-is too.
            self.model.load_state_dict(checkpoint['state_dict'])

            # NOTE: optimizer state is always restored; the fine-tuning
            # switch (`config['ft']`) that used to guard this is disabled.
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(restore_path, checkpoint['epoch']))

    @staticmethod
    def _visualization_interval(num_batches, per_epoch=10):
        """Return how many batches lie between two image visualisations.

        The result is clamped to at least 1 so that it is always a valid
        modulus, even when ``num_batches`` is smaller than ``per_epoch``.
        """
        return max(num_batches // per_epoch, 1)

    def training(self, epoch):
        """Run one training epoch and save ``checkpoint_last.pth.tar``.

        Args:
            epoch: Index of the current epoch; forwarded to the LR scheduler
                and used to derive the global TensorBoard step.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._visualization_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show inference results about ten times per epoch.
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'],
                                             image, target, output, global_step)

        # The epoch-level loss is the running sum over batches, not the mean.
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        # save last checkpoint
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best=False, filename='checkpoint_last.pth.tar')

        # if training on a subset reshuffle the data
        if self.config['training']['train_on_subset']['enabled']:
            self.train_loader.dataset.shuffle_dataset()

    def validation(self, epoch):
        """Evaluate on the validation loader; save on a new best mIoU.

        Args:
            epoch: Index of the current epoch; used as the TensorBoard step
                and stored (plus one) in the checkpoint.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best=True, filename='checkpoint_best.pth.tar')
Пример #29
0
class Evaluation(object):
    """Evaluate a multi-classifier segmentation network on the validation set.

    The constructor builds the data loaders, the network selected by
    ``args.network``, the loss, one ``Evaluator`` per classifier
    (``args.C`` of them) and, when requested, restores weights from
    ``args.resume`` and ``args.resume_edm``.
    """

    # ``C_index`` handed to the model constructor for each supported ``args.C``.
    _C_INDEX = {2: [5], 3: [3, 7], 4: [2, 5, 8]}
    # ``network_arch`` used by the 'searched-dense' network, keyed by ``args.C``.
    _SEARCHED_DENSE_ARCH = {
        # 4_15_80e_40a_03-lr_5e-4wd_6e-4alr_1e-3awd 513x513 batch 4
        2: [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2],
        3: [1, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3],
        4: [1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2],
    }
    # ``network_arch`` shared by the 'autodeeplab-*' networks.
    _AUTODEEPLAB_ARCH = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
    # Key prefix that ``torch.nn.DataParallel`` adds to every state-dict entry.
    _MODULE_PREFIX = 'module.'

    def __init__(self, args):
        """Build loaders, model, criterion and evaluators, then load checkpoints.

        Args:
            args: Parsed options. The attributes read here are ``workers``,
                ``network``, ``C``, ``saved_arch_path``,
                ``use_balanced_weights``, ``dataset``, ``cuda``,
                ``confidence``, ``resume``, ``resume_edm`` and
                ``clean_module``. ``args.start_epoch`` is overwritten when a
                checkpoint is resumed.

        Raises:
            ValueError: If ``args.network`` or ``args.C`` is not supported.
            RuntimeError: If a checkpoint path does not point to a file, or
                ``args.resume_edm`` is given while no EDM module is in use.
        """
        self.args = args
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader. The train loader is kept because the balanced
        # class weights below are computed from it when no cached file exists.
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        (self.train_loader, self.val_loader,
         self.test_loader, self.nclass) = make_data_loader(args, **kwargs)

        model = self._build_model(args)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
        self.model = model

        # Define one Evaluator per classifier
        self.evaluator = [Evaluator(self.nclass) for _ in range(self.args.C)]

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
        if args.confidence == 'edm':
            self.edm = EDM()
            self.edm = self.edm.cuda()
        else:
            self.edm = False

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            checkpoint = self._read_checkpoint(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._load_weights(self.model, checkpoint['state_dict'], args.clean_module)
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        if args.resume_edm is not None:
            if self.edm is False:
                raise RuntimeError(
                    "resume_edm was given but no EDM module is in use "
                    "(args.confidence is {!r})".format(args.confidence))
            checkpoint = self._read_checkpoint(args.resume_edm)
            self._load_weights(self.edm, checkpoint['state_dict'], args.clean_module)

    def _build_model(self, args):
        """Instantiate the network named by ``args.network`` for ``args.C``."""
        if args.C not in self._C_INDEX:
            raise ValueError(
                "unsupported C={!r}; expected one of {}".format(
                    args.C, sorted(self._C_INDEX)))

        if args.network == 'searched-dense':
            model_cls = ADD
            network_arch = self._SEARCHED_DENSE_ARCH[args.C]
            low_level_layer = 0
        elif args.network == 'autodeeplab-dense':
            model_cls = ADD
            network_arch = self._AUTODEEPLAB_ARCH
            low_level_layer = 2
        elif args.network == 'autodeeplab-baseline':
            model_cls = Baselin_Model
            network_arch = self._AUTODEEPLAB_ARCH
            low_level_layer = 2
        else:
            raise ValueError("unsupported network {!r}".format(args.network))

        cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
        cell_arch = np.load(cell_path)
        # Hand out copies so the class-level tables can never be mutated.
        return model_cls(list(network_arch),
                         list(self._C_INDEX[args.C]),
                         cell_arch,
                         self.nclass,
                         args,
                         low_level_layer)

    @staticmethod
    def _read_checkpoint(path):
        """Load the checkpoint stored at ``path`` onto the CPU."""
        if not os.path.isfile(path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(path))
        # load_state_dict copies the values into the existing parameters, so
        # the checkpoint itself does not have to live on the original device.
        return torch.load(path, map_location='cpu')

    @classmethod
    def _strip_module_prefix(cls, state_dict):
        """Return a copy of ``state_dict`` without the leading 'module.' on keys."""
        prefix = cls._MODULE_PREFIX
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    @classmethod
    def _load_weights(cls, module, state_dict, clean_module):
        """Load ``state_dict`` into ``module``, stripping 'module.' if requested.

        The prefixed dict must not be loaded before it is cleaned: its keys do
        not match the unwrapped module, so that load would fail.
        """
        if clean_module:
            state_dict = cls._strip_module_prefix(state_dict)
        module.load_state_dict(state_dict)

    def validation(self):
        """Run the validation set through the model and print each classifier's mIoU."""
        self.model.eval()
        for e in self.evaluator:
            e.reset()
        tbar = tqdm(self.val_loader, desc='\r')

        for sample in tbar:
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                outputs = self.model(image)

            # Add batch sample into the evaluator of every classifier
            for classifier_i, evaluator in enumerate(self.evaluator):
                pred = torch.argmax(outputs[classifier_i], dim=1)
                evaluator.add_batch(target, pred)

        mIoU = [e.Mean_Intersection_over_Union() for e in self.evaluator]

        # The text for the first two classifiers is kept as it always was;
        # further classifiers (C > 2) are appended instead of being dropped.
        report = "classifier_1_mIoU:{}, classifier_2_mIoU: {}".format(mIoU[0], mIoU[1])
        for classifier_no, value in enumerate(mIoU[2:], start=3):
            report += ", classifier_{}_mIoU: {}".format(classifier_no, value)
        print(report)

    def dynamic_inference(self, threshold, confidence, num_samples=500):
        """Run the model's ``dynamic_inference`` over the validation set and print statistics.

        Args:
            threshold: Forwarded to ``self.model.dynamic_inference``.
            confidence: Forwarded to ``self.model.dynamic_inference``; the
                value ``'edm'`` additionally puts ``self.edm`` in eval mode.
            num_samples: Denominator for the early-exit percentage and the
                average confidence. Defaults to 500, the value that used to
                be hard-coded.
                NOTE(review): presumably the validation-set size; confirm
                whether it should count images or batches.
        """
        self.model.eval()
        self.evaluator[0].reset()
        if confidence == 'edm':
            self.edm.eval()
        time_meter = AverageMeter()

        tbar = tqdm(self.val_loader, desc='\r')
        total_earlier_exit = 0
        confidence_value_avg = 0.0
        for sample in tbar:
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output, earlier_exit, tic, confidence_value = self.model.dynamic_inference(
                    image, threshold=threshold, confidence=confidence, edm=self.edm)
            total_earlier_exit += earlier_exit
            confidence_value_avg += confidence_value
            time_meter.update(tic)

            pred = torch.argmax(output, dim=1)

            # Add batch sample into evaluator
            self.evaluator[0].add_batch(target, pred)

        mIoU = self.evaluator[0].Mean_Intersection_over_Union()
        mean_time = time_meter.average()

        print('Validation:')
        print("mIoU: {}".format(mIoU))
        print("mean_inference_time: {}".format(mean_time))
        print("fps: {}".format(1.0/mean_time))
        print("num_earlier_exit: {}".format(total_earlier_exit/num_samples*100))
        print("avg_confidence: {}".format(confidence_value_avg/num_samples))

    def mac(self):
        """Print the complexity and parameter count reported for a (3, 1025, 2049) input."""
        self.model.eval()
        with torch.no_grad():
            flops, params = get_model_complexity_info(
                self.model, (3, 1025, 2049), as_strings=True, print_per_layer_stat=False)
            print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
            print('{:<30}  {:<8}'.format('Number of parameters: ', params))
Пример #30
0
class Trainer(object):
    """Train and validate a DeepLab network on samples keyed ``'trace'``/``'label'``.

    The constructor wires up the saver, TensorBoard writer, data loaders,
    model, SGD optimizer, loss, evaluator and learning-rate scheduler, and
    optionally resumes from ``args.resume``.
    """

    def __init__(self, args):
        """Build every training component from ``args``.

        Args:
            args: Parsed options. ``args.start_epoch`` is overwritten when a
                checkpoint is resumed and reset to 0 when ``args.ft`` is set.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)
        # e.g. "2 resnet 16 False False"

        # The second parameter group is trained with a 10x learning rate.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the bare network, looking through a DataParallel wrapper if any.

        The model is only wrapped when ``args.cuda`` is set, so ``.module``
        cannot be accessed unconditionally.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _save_checkpoint(self, epoch, is_best):
        """Hand the current model/optimizer state to the saver."""
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def training(self, epoch):
        """Train for one epoch and log the losses.

        Image visualization is not performed by this trainer. When
        ``args.no_val`` is set a (non-best) checkpoint is saved after the
        epoch.

        Args:
            epoch: Zero-based index of the epoch being run.
        """
        train_loss = 0.0
        num_images = 0  # counted per batch so an empty loader is harmless
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, False)

    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on a new best mIoU.

        Args:
            epoch: Zero-based index of the epoch being evaluated.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Arg-max over axis 1 turns the network output into label maps.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, True)
Пример #31
0
class Trainer(object):
    """Train and validate a network from ``modeling`` on tiles cut from two large images.

    Before anything else the constructor makes sure the tiled ``.npy``
    datasets exist (see ``img_process``); it then wires up the saver,
    TensorBoard writer, data loaders, model, SGD optimizer, loss, evaluator
    and learning-rate scheduler, and optionally resumes from ``args.resume``.
    """

    # Root folder under which the tiled arrays are cached, one sub-folder per
    # tile size. Change this (and the image paths below) to your own location.
    _NPY_ROOT = '/data/dingyifeng/pytorch-jingwei-master/npy_process'
    # (image, label) pairs that are tiled into the dataset.
    _TRAIN_IMAGES = (
        ('/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_1.png',
         '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_1_label.png'),
        ('/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_2.png',
         '/data/dingyifeng/jingwei/jingwei_round1_train_20190619/image_2_label.png'),
    )

    def __init__(self, args):
        """Build every training component from ``args``.

        Args:
            args: Parsed options. ``args.start_epoch`` is overwritten when a
                checkpoint is resumed and reset to 0 when ``args.ft`` is set.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Generate .npy file for dataloader
        self.img_process(args)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = getattr(modeling, args.model_name)(pretrained=args.pretrained)

        # Define Optimizer (one learning rate for all parameters)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Load onto the CPU; load_state_dict copies the values to
            # wherever the model and optimizer state actually live.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the bare network, looking through a DataParallel wrapper if any.

        The model is only wrapped when ``args.cuda`` is set, so ``.module``
        cannot be accessed unconditionally.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _save_checkpoint(self, epoch, is_best):
        """Hand the current model/optimizer state to the saver."""
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    @staticmethod
    def _tile_array(img, anno_map, unit_size, stride):
        """Cut ``img``/``anno_map`` into ``unit_size`` tiles and keep the non-empty ones.

        The window moves by ``stride``; a window that would cross the border
        is shifted back so that it ends exactly on the border. A tile is kept
        only if the last channel of ``img`` contains the value 255 somewhere
        inside it.

        Returns:
            ``(tiles, labels)`` as numpy arrays; ``tiles`` holds the first
            three channels of every kept tile, ``labels`` the matching crops
            of ``anno_map``.
        """
        length, width = img.shape[0], img.shape[1]
        x1, x2, y1, y2 = 0, unit_size, 0, unit_size
        tiles = []   # kept image tiles
        labels = []  # label crops matching the kept tiles
        while x1 < length:
            # Clamp the window when it would leave the image along axis 0.
            if x2 > length:
                x2, x1 = length, length - unit_size

            while y1 < width:
                # Clamp the window when it would leave the image along axis 1.
                if y2 > width:
                    y2, y1 = width, width - unit_size
                im = img[x1:x2, y1:y2, :]
                # Keep the tile only if it contains any valid pixel.
                if 255 in im[:, :, -1]:
                    tiles.append(im[:, :, 0:3])
                    labels.append(anno_map[x1:x2, y1:y2])

                if y2 == width:
                    break

                y1 += stride
                y2 += stride

            if x2 == length:
                break

            y1, y2 = 0, unit_size
            x1 += stride
            x2 += stride
        return np.array(tiles), np.array(labels)

    @classmethod
    def _tile_image(cls, image_path, label_path, unit_size, stride):
        """Open one image/label pair from disk and tile it with ``_tile_array``."""
        img = np.asarray(Image.open(image_path))
        anno_map = np.asarray(Image.open(label_path))
        return cls._tile_array(img, anno_map, unit_size, stride)

    # Crop the large images into unit_size tiles, moving by stride each step.
    # The resulting training and validation sets are stored as numpy arrays in
    # save_dir so they can be reused next time and memory use is reduced.
    # Please change the paths to your own.
    def img_process(self, args):
        """Create the tiled train/val ``.npy`` files unless they already exist.

        Nothing is generated when the cache folder for ``args.base_size``
        already exists.

        Args:
            args: Parsed options; only ``args.base_size`` (tile edge length)
                is read.
        """
        unit_size = args.base_size
        stride = unit_size  # int(unit_size/2)
        save_dir = os.path.join(self._NPY_ROOT, str(unit_size))
        # npy_process
        if not os.path.exists(save_dir):
            # Lift PIL's pixel limit so that the huge source images can be opened.
            Image.MAX_IMAGE_PIXELS = 100000000000

            tile_sets = []
            label_sets = []
            for image_path, label_path in self._TRAIN_IMAGES:
                tiles, labels = self._tile_image(image_path, label_path,
                                                 unit_size, stride)
                tile_sets.append(tiles)
                label_sets.append(labels)

            Img = np.concatenate(tile_sets, axis=0)
            cat = np.concatenate(label_sets, axis=0)
            del tile_sets, label_sets

            # shuffle (fixed seed so that the split is reproducible)
            np.random.seed(1)
            assert (Img.shape[0] == cat.shape[0])
            shuffle_id = np.arange(Img.shape[0])
            np.random.shuffle(shuffle_id)
            Img = Img[shuffle_id]
            cat = cat[shuffle_id]

            # The folder is only created once the data is ready; makedirs also
            # creates missing parent folders.
            os.makedirs(save_dir)
            print("=> generate {}".format(unit_size))
            # split train dataset
            # NOTE(review): the train split is the full set, so it overlaps
            # the validation split below (the 80% slice is disabled).
            images_train = Img  # [:int(Img.shape[0]*0.8)]
            categories_train = cat  # [:int(cat.shape[0]*0.8)]
            assert (len(images_train) == len(categories_train))
            np.save(os.path.join(save_dir, 'train_img.npy'), images_train)
            np.save(os.path.join(save_dir, 'train_label.npy'),
                    categories_train)
            # split val dataset
            images_val = Img[int(Img.shape[0] * 0.8):]
            categories_val = cat[int(cat.shape[0] * 0.8):]
            assert (len(images_val) == len(categories_val))
            np.save(os.path.join(save_dir, 'val_img.npy'), images_val)
            np.save(os.path.join(save_dir, 'val_label.npy'), categories_val)

            # Drop the big arrays before collecting.
            del Img, cat, images_train, categories_train
            del images_val, categories_val

            print("=> img_process finished!")
        else:
            print("{} file already exists!".format(unit_size))
        # Free memory
        import gc
        gc.collect()

    def training(self, epoch):
        """Train for one epoch, log losses and visualize predictions.

        When ``args.no_val`` is set a (non-best) checkpoint is saved after
        the epoch.

        Args:
            epoch: Zero-based index of the epoch being run.
        """
        train_loss = 0.0
        num_images = 0  # counted per batch so an empty loader is harmless
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualize about ten times per epoch; never let the interval be 0,
        # which would break the modulo for loaders with fewer than 10 batches.
        vis_interval = max(1, num_img_tr // 10)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, False)

    def validation(self, epoch):
        """Evaluate on the validation loader and checkpoint on a new best mIoU.

        Args:
            epoch: Zero-based index of the epoch being evaluated.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Arg-max over axis 1 turns the network output into label maps.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, True)