Example #1
import os

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Project-local imports below follow the pytorch-deeplab-xception layout;
# adjust the paths if your fork organises them differently.
from mypath import Path
from dataloaders import make_data_loader
from modeling.deeplab import DeepLab
from modeling.sync_batchnorm.replicate import patch_replication_callback
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.metrics import Evaluator
# DepthLosses and EvaluatorDepth are fork-specific additions; the paths below
# are hypothetical, so point them at wherever your project defines them.
from utils.depth_losses import DepthLosses    # hypothetical path
from utils.metrics import EvaluatorDepth      # hypothetical path


class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
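        # the two-headed depth losses predict both class sets at once, so the
        # output channel count must be overridden accordingly: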
        if args.loss_type == 'depth_loss_two_distributions':
            self.nclass = args.num_class + args.num_class2 + 1
        if args.loss_type == 'depth_avg_sigmoid_class':
            self.nclass = args.num_class + args.num_class2
        # Define network

        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)


        print("\nDefine models...\n")
        self.model_aprox_depth = DeepLab(num_classes=1,
                                         backbone=args.backbone,
                                         output_stride=args.out_stride,
                                         sync_bn=args.sync_bn,
                                         freeze_bn=args.freeze_bn)

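        # 3x3 conv mapping the 4-channel input (RGB + approx. depth) to the
        # 3 channels the DeepLab backbone expects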
        self.input_conv = nn.Conv2d(4, 3, 3, padding=1)
        # Using cuda
        if args.cuda:
            self.model_aprox_depth = torch.nn.DataParallel(self.model_aprox_depth, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model_aprox_depth)
            self.model_aprox_depth = self.model_aprox_depth.cuda()
            self.input_conv = self.input_conv.cuda()


        print("\nLoad checkpoints...\n")
        if not args.cuda:
            ckpt_aprox_depth = torch.load(args.ckpt, map_location='cpu')
            self.model_aprox_depth.load_state_dict(ckpt_aprox_depth['state_dict'])
        else:
            ckpt_aprox_depth = torch.load(args.ckpt)
            self.model_aprox_depth.module.load_state_dict(ckpt_aprox_depth['state_dict'])


        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.Adam(train_params, args.lr)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        if 'depth' in args.loss_type:
            self.criterion = DepthLosses(weight=weight,
                                         cuda=args.cuda,
                                         min_depth=args.min_depth,
                                         max_depth=args.max_depth,
                                         num_class=args.num_class,
                                         cut_point=args.cut_point,
                                         num_class2=args.num_class2).build_loss(mode=args.loss_type)
            self.infer = DepthLosses(weight=weight,
                                     cuda=args.cuda,
                                     min_depth=args.min_depth,
                                     max_depth=args.max_depth,
                                     num_class=args.num_class,
                                     cut_point=args.cut_point,
                                     num_class2=args.num_class2)
            self.evaluator_depth = EvaluatorDepth(args.batch_size)
        else:
            self.criterion = SegmentationLosses(cuda=args.cuda, weight=weight).build_loss(mode=args.loss_type)
            self.evaluator = Evaluator(self.nclass)

        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # track the best validation metric: lower is better for depth (RMSE),
        # higher for segmentation (mIoU)
        self.best_pred = 0.0
        if 'depth' in args.loss_type:
            self.best_pred = 1e6

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            if not args.cuda:
                checkpoint = torch.load(args.resume, map_location='cpu')
            else:
                checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                state_dict = checkpoint['state_dict']
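                # drop the last two entries (the final classifier's weight and
                # bias) so a checkpoint with a different class count still
                # loads; the remaining layers load with strict=False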
                state_dict.popitem(last=True)
                state_dict.popitem(last=True)
                self.model.module.load_state_dict(state_dict, strict=False)
            else:
                state_dict = checkpoint['state_dict']
                state_dict.popitem(last=True)
                state_dict.popitem(last=True)
                self.model.load_state_dict(state_dict, strict=False)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            if 'depth' in args.loss_type:
                self.best_pred = 1e6
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # add input layer to the model
        self.model = nn.Sequential(
            self.input_conv,
            self.model
        )
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.model_aprox_depth.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
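            # for the *_seg datasets, binarise depth labels at cut_point (near=0, far=1)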
            if self.args.dataset == 'apollo_seg' or self.args.dataset == 'farsight_seg':
                target[target <= self.args.cut_point] = 0
                target[target > self.args.cut_point] = 1
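            # duplicate singleton batches (e.g. so batchnorm sees more than one sample)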
            if image.shape[0] == 1:
                target = torch.cat([target, target], dim=0)
                image = torch.cat([image, image], dim=0)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
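            # run the helper net to get approximate depth, then feed it to the
            # main model as a 4th input channel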
            aprox_depth = self.model_aprox_depth(image)
            aprox_depth = self.infer.sigmoid(aprox_depth)
            model_in = torch.cat([image, aprox_depth], dim=1)  # avoid shadowing built-in input()
            output = self.model(model_in)
            if self.args.loss_type == 'depth_sigmoid_loss_inverse':
                loss = self.criterion(output, target, inverse=True)
            else:
                loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(1, num_img_tr // 10) == 0:  # guard against epochs with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                target[torch.isnan(target)] = 0  # zero out NaNs for display (avoids a TensorBoard warning)
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step,
                                             n_class=self.args.num_class)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.model_aprox_depth.eval()
        if 'depth' in self.args.loss_type:
            self.evaluator_depth.reset()
        else:
            softmax = nn.Softmax(1)
            self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                aprox_depth = self.model_aprox_depth(image)
                aprox_depth = self.infer.sigmoid(aprox_depth)
                model_in = torch.cat([image, aprox_depth], dim=1)
                output = self.model(model_in)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            if 'depth' in self.args.loss_type:
                if self.args.loss_type == 'depth_loss':
                    pred = self.infer.pred_to_continous_depth(output)
                elif self.args.loss_type == 'depth_avg_sigmoid_class':
                    pred = self.infer.pred_to_continous_depth_avg(output)
                elif self.args.loss_type == 'depth_loss_combination':
                    pred = self.infer.pred_to_continous_combination(output)
                elif self.args.loss_type == 'depth_loss_two_distributions':
                    pred = self.infer.pred_to_continous_depth_two_distributions(output)
                elif 'depth_sigmoid_loss' in self.args.loss_type:
                    output = self.infer.sigmoid(output.squeeze(1))
                    pred = self.infer.depth01_to_depth(output)
                # Add batch sample into evaluator
                self.evaluator_depth.evaluateError(pred, target)
            else:
                output = softmax(output)
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)
        if 'depth' in self.args.loss_type:
            # Fast test during the training
            MSE = self.evaluator_depth.averageError['MSE']
            RMSE = self.evaluator_depth.averageError['RMSE']
            ABS_REL = self.evaluator_depth.averageError['ABS_REL']
            LG10 = self.evaluator_depth.averageError['LG10']
            MAE = self.evaluator_depth.averageError['MAE']
            DELTA1 = self.evaluator_depth.averageError['DELTA1']
            DELTA2 = self.evaluator_depth.averageError['DELTA2']
            DELTA3 = self.evaluator_depth.averageError['DELTA3']

            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/MSE', MSE, epoch)
            self.writer.add_scalar('val/RMSE', RMSE, epoch)
            self.writer.add_scalar('val/ABS_REL', ABS_REL, epoch)
            self.writer.add_scalar('val/LG10', LG10, epoch)

            self.writer.add_scalar('val/MAE', MAE, epoch)
            self.writer.add_scalar('val/DELTA1', DELTA1, epoch)
            self.writer.add_scalar('val/DELTA2', DELTA2, epoch)
            self.writer.add_scalar('val/DELTA3', DELTA3, epoch)

            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print("MSE:{}, RMSE:{}, ABS_REL:{}, LG10:{}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3:{}".format(
                MSE, RMSE, ABS_REL, LG10, MAE, DELTA1, DELTA2, DELTA3))
            new_pred = RMSE
            if new_pred < self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
        else:
            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
            self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))

            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        print('Loss: %.3f' % test_loss)
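
For reference, a minimal driver for this Trainer might look like the sketch below. The `parse_args` helper and the epoch loop are assumptions modelled on the usual pytorch-deeplab-xception entry point, not part of the original example.

def main():
    # hypothetical entry point; parse_args() is an assumed helper that builds
    # the argparse namespace whose fields Trainer reads (lr, epochs, cuda, ...)
    args = parse_args()
    trainer = Trainer(args)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()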
Example #2
import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# Project-local imports follow the pytorch-deeplab-xception layout; adjust for
# your fork.
from dataloaders import make_data_loader
from dataloaders import custom_transforms as tr
from modeling.deeplab import DeepLab
from modeling.sync_batchnorm.replicate import patch_replication_callback
# DepthLosses and EvaluatorDepth are fork-specific; the paths are hypothetical.
from utils.depth_losses import DepthLosses    # hypothetical path
from utils.metrics import EvaluatorDepth      # hypothetical path


class Eval(object):
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, _, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        if args.loss_type == 'depth_loss_two_distributions':
            self.nclass = args.num_class + args.num_class2 + 1
        if args.loss_type == 'depth_avg_sigmoid_class':
            self.nclass = args.num_class + args.num_class2

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        self.model = model
        if self.args.loss_type == 'depth_multi_dnn':
            model2 = DeepLab(num_classes=args.num_class2,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
            self.model2 = model2
            model_seg = DeepLab(num_classes=2,
                                backbone=args.backbone,
                                output_stride=args.out_stride,
                                sync_bn=args.sync_bn,
                                freeze_bn=args.freeze_bn)
            self.model_seg = model_seg

        if self.args.loss_type == 'depth_with_aprox_depth':
            # 4->3 input conv so the 3-channel backbone can take RGB + approx. depth
            self.input_conv = nn.Conv2d(4, 3, 3, padding=1)
            model2 = DeepLab(num_classes=1,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
            self.model2 = model2  # approximate-depth model

        # Define inference helpers (map class logits back to continuous depth)

        self.infer2 = None
        if self.args.loss_type == 'depth_multi_dnn':
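            # two inferers: self.infer maps the near head to [min_depth, cut_point],
            # self.infer2 maps the far head to [cut_point, max_depth]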
            self.infer = DepthLosses(cuda=args.cuda,
                                     min_depth=args.min_depth,
                                     max_depth=args.cut_point,
                                     num_class=args.num_class,
                                     cut_point=-1,
                                     num_class2=-1)

            self.infer2 = DepthLosses(cuda=args.cuda,
                                      min_depth=args.cut_point,
                                      max_depth=args.max_depth,
                                      num_class=args.num_class2,
                                      cut_point=-1,
                                      num_class2=-1)
        else:
            self.infer = DepthLosses(cuda=args.cuda,
                                     min_depth=args.min_depth,
                                     max_depth=args.max_depth,
                                     num_class=args.num_class,
                                     cut_point=args.cut_point,
                                     num_class2=args.num_class2)

        self.softmax = nn.Softmax(1)

        # Define Evaluator
        self.evaluator_depth = EvaluatorDepth(args.batch_size)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            if self.args.loss_type == 'depth_multi_dnn':
                self.model2 = torch.nn.DataParallel(
                    self.model2, device_ids=self.args.gpu_ids)
                patch_replication_callback(self.model2)
                self.model2 = self.model2.cuda()

                self.model_seg = torch.nn.DataParallel(
                    self.model_seg, device_ids=self.args.gpu_ids)
                patch_replication_callback(self.model_seg)
                self.model_seg = self.model_seg.cuda()

            if self.args.loss_type == 'depth_with_aprox_depth':
                self.input_conv = self.input_conv.cuda()

                self.model = nn.Sequential(self.input_conv, self.model)
                self.model = torch.nn.DataParallel(
                    self.model, device_ids=self.args.gpu_ids)
                patch_replication_callback(self.model)
                self.model = self.model.cuda()

                self.model2 = torch.nn.DataParallel(
                    self.model2, device_ids=self.args.gpu_ids)
                patch_replication_callback(self.model2)
                self.model2 = self.model2.cuda()

        if not args.cuda:
            ckpt = torch.load(args.ckpt, map_location='cpu')
            if self.args.loss_type == 'depth_multi_dnn':
                ckpt2 = torch.load(args.ckpt2, map_location='cpu')
                ckpt_seg = torch.load(args.ckpt_seg, map_location='cpu')
                self.model2.load_state_dict(ckpt2['state_dict'])
                self.model_seg.load_state_dict(ckpt_seg['state_dict'])

            if self.args.loss_type == 'depth_with_aprox_depth':
                ckpt2 = torch.load(args.ckpt2, map_location='cpu')
                self.model2.load_state_dict(ckpt2['state_dict'])

            self.model.load_state_dict(ckpt['state_dict'])
        else:
            ckpt = torch.load(args.ckpt)
            if self.args.loss_type == 'depth_multi_dnn':
                ckpt2 = torch.load(args.ckpt2)
                ckpt_seg = torch.load(args.ckpt_seg)
                self.model2.module.load_state_dict(ckpt2['state_dict'])
                self.model_seg.module.load_state_dict(ckpt_seg['state_dict'])  # .module: wrapped in DataParallel above

            if self.args.loss_type == 'depth_with_aprox_depth':
                ckpt2 = torch.load(args.ckpt2)
                self.model2.module.load_state_dict(ckpt2['state_dict'])

            self.model.module.load_state_dict(ckpt['state_dict'])

        print("\nLoad checkpoints...\n")

    def evaluate(self):
        self.model.eval()
        if self.args.loss_type == 'depth_multi_dnn':
            self.model2.eval()
            self.model_seg.eval()
        if self.args.loss_type == 'depth_with_aprox_depth':
            self.model2.eval()
        self.evaluator_depth.reset()
        tbar = tqdm(self.test_loader, desc='\r')

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                if 'depth_with_aprox_depth' in self.args.loss_type:
                    aprox_depth = self.model2(image)
                    aprox_depth = self.infer.sigmoid(aprox_depth)
                    model_in = torch.cat([image, aprox_depth], dim=1)
                    output = self.model(model_in)
                else:
                    output = self.model(image)

            pred = None
            if 'depth' in self.args.loss_type:
                if self.args.loss_type == 'depth_loss':
                    pred = self.infer.pred_to_continous_depth(output)
                elif self.args.loss_type == 'depth_avg_sigmoid_class':
                    pred = self.infer.pred_to_continous_depth_avg(output)
                elif self.args.loss_type == 'depth_loss_combination':
                    pred = self.infer.pred_to_continous_combination(output)
                elif self.args.loss_type == 'depth_loss_two_distributions':
                    pred = self.infer.pred_to_continous_depth_two_distributions(
                        output)
                elif self.args.loss_type == 'depth_multi_dnn':
                    with torch.no_grad():
                        output2 = self.model2(image)
                        output_seg = self.model_seg(image)
                    pred = self.infer.pred_to_continous_depth(output)
                    pred2 = self.infer2.pred_to_continous_depth(output2)
                    pred_seg = self.softmax(output_seg)
                    # join results
                    pred_seg = torch.argmax(pred_seg, dim=1)
                    pred = torch.where(pred_seg == 0, pred, pred2)

                elif 'depth_sigmoid_loss' in self.args.loss_type:
                    output = self.infer.sigmoid(output.squeeze(1))
                    if 'inverse' in self.args.loss_type:
                        pred = self.infer.depth01_to_depth(output, True)
                    else:
                        pred = self.infer.depth01_to_depth(output)

                elif 'depth_with_aprox_depth' in self.args.loss_type:
                    output = self.infer.sigmoid(output.squeeze(1))
                    pred = self.infer.depth01_to_depth(output)

            # Add batch sample into evaluator
            self.evaluator_depth.evaluateError(pred, target)

        # Fast test during the training
        MSE = self.evaluator_depth.averageError['MSE']
        RMSE = self.evaluator_depth.averageError['RMSE']
        ABS_REL = self.evaluator_depth.averageError['ABS_REL']
        LG10 = self.evaluator_depth.averageError['LG10']
        MAE = self.evaluator_depth.averageError['MAE']
        DELTA1 = self.evaluator_depth.averageError['DELTA1']
        DELTA2 = self.evaluator_depth.averageError['DELTA2']
        DELTA3 = self.evaluator_depth.averageError['DELTA3']

        print('Test:')
        print(
            "MSE:{}, RMSE:{}, ABS_REL:{}, LG10: {}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3: {}"
            .format(MSE, RMSE, ABS_REL, LG10, MAE, DELTA1, DELTA2, DELTA3))

    def evaluate2stage(self):
        self.model.eval()
        self.model2.eval()
        self.model_seg.eval()
        self.evaluator_depth.reset()
        tbar = tqdm(self.test_loader, desc='\r')

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                output2 = self.model2(image)
                output_seg = self.model_seg(image)
            if self.infer.num_class > 1:
                pred = self.infer.pred_to_continous_depth(output)
                if self.infer2 is not None:
                    pred2 = self.infer2.pred_to_continous_depth(output2)
                    pred_seg = self.softmax(output_seg)
                    # join results
                    pred_seg = torch.argmax(pred_seg, dim=1)
                    pred = torch.where(pred_seg == 0, pred, pred2)
            else:
                output = self.infer.sigmoid(output)
                pred = self.infer.depth01_to_depth(
                    output).detach().cpu().numpy().squeeze()

            # Add batch sample into evaluator
            self.evaluator_depth.evaluateError(pred, target)

        # Fast test during the training
        MSE = self.evaluator_depth.averageError['MSE']
        RMSE = self.evaluator_depth.averageError['RMSE']
        ABS_REL = self.evaluator_depth.averageError['ABS_REL']
        LG10 = self.evaluator_depth.averageError['LG10']
        MAE = self.evaluator_depth.averageError['MAE']
        DELTA1 = self.evaluator_depth.averageError['DELTA1']
        DELTA2 = self.evaluator_depth.averageError['DELTA2']
        DELTA3 = self.evaluator_depth.averageError['DELTA3']

        print('Test:')
        print(
            "MSE:{}, RMSE:{}, ABS_REL:{}, LG10: {}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3: {}"
            .format(MSE, RMSE, ABS_REL, LG10, MAE, DELTA1, DELTA2, DELTA3))
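
A plausible way to drive this evaluator (an assumption based on the methods above, not part of the original):

# hypothetical usage; `args` is the argparse namespace Eval expects
evaluator = Eval(args)
if args.loss_type == 'depth_multi_dnn':
    evaluator.evaluate2stage()  # fuse near/far heads via the seg mask
else:
    evaluator.evaluate()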
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--in-path',
                        type=str,
                        required=True,
                        help='image to test')
    parser.add_argument('--out-path',
                        type=str,
                        required=True,
                        help='mask image to save')
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--ckpt',
                        type=str,
                        default='deeplab-resnet.pth',
                        help='saved model')
    parser.add_argument('--out-stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--num_class',
                        type=int,
                        default=50,
                        help='number of classes to predict')
    parser.add_argument('--min_depth',
                        type=float,
                        default=0.0,
                        help='min depth to predict')
    parser.add_argument('--max_depth',
                        type=float,
                        default=655.0,
                        help='max depth to predict')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=True,
                        help='disables CUDA (default: True, i.e. run on CPU)')
    parser.add_argument('--gpu-ids',
                        type=str,
                        default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--dataset',
                        type=str,
                        default='pascal',
                        choices=['pascal', 'coco', 'cityscapes', 'apollo'],
                        help='dataset name (default: pascal)')
    parser.add_argument('--crop-size',
                        type=int,
                        default=513,
                        help='crop image size')
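    # note: argparse's type=bool treats any non-empty string as True, so
    # '--sync-bn False' still yields True; leaving the default None triggers
    # the auto setting computed below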
    parser.add_argument('--sync-bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError(
                'Argument --gpu_ids must be a comma-separated list of integers only'
            )

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    model = DeepLab(num_classes=args.num_class,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)
    if not args.cuda:
        ckpt = torch.load(args.ckpt, map_location='cpu')
    else:
        ckpt = torch.load(args.ckpt)
    model.load_state_dict(ckpt['state_dict'])
    infer = DepthLosses(cuda=args.cuda,
                        min_depth=args.min_depth,
                        max_depth=args.max_depth,
                        num_class=args.num_class)

    composed_transforms = transforms.Compose([
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

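    # the transforms expect an {'image', 'label'} sample, so the input image
    # doubles as a dummy label here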
    image = Image.open(args.in_path).convert('RGB')
    target = Image.open(args.in_path).convert('L')
    sample = {'image': image, 'label': target}
    tensor_in = composed_transforms(sample)['image'].unsqueeze(0)

    model.eval()
    with torch.no_grad():
        output = model(tensor_in)
        if infer.num_class > 1:
            output = infer.pred_to_continous_depth(
                output).detach().cpu().numpy().squeeze()
        else:
            output = infer.sigmoid(output)
            output = infer.depth01_to_depth(
                output).detach().cpu().numpy().squeeze()
    # gt_path = args.out_path.split('.')[0] + "_GT.png"
    plt.imsave(args.out_path, output)
    # plt.imsave(gt_path, target)

    print("\nDone !!! \n")
Example #3

# (uses the same imports as Example #2)
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--in-path',
                        type=str,
                        required=True,
                        help='image to test')
    parser.add_argument('--out-path',
                        type=str,
                        required=True,
                        help='mask image to save')
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')

    parser.add_argument('--ckpt_near',
                        type=str,
                        default='deeplab-resnet.pth',
                        help='saved model')
    parser.add_argument('--ckpt_far',
                        type=str,
                        default='deeplab-resnet.pth',
                        help='saved model')
    parser.add_argument('--ckpt_seg',
                        type=str,
                        default='deeplab-resnet.pth',
                        help='saved model')

    parser.add_argument('--num_class_near',
                        type=int,
                        default=100,
                        help='number of classes to predict')
    parser.add_argument('--num_class_far',
                        type=int,
                        default=50,
                        help='number of classes to predict')

    parser.add_argument('--min_depth_near',
                        type=float,
                        default=0.0,
                        help='min depth to predict')
    parser.add_argument('--min_depth_far',
                        type=float,
                        default=100.0,
                        help='min depth to predict')

    parser.add_argument('--max_depth_near',
                        type=float,
                        default=100.0,
                        help='max depth to predict')
    parser.add_argument('--max_depth_far',
                        type=float,
                        default=655.0,
                        help='max depth to predict')

    parser.add_argument('--out-stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')

    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=True,
                        help='disables CUDA (default: True, i.e. run on CPU)')
    parser.add_argument('--gpu-ids',
                        type=str,
                        default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--crop-size',
                        type=int,
                        default=513,
                        help='crop image size')
    parser.add_argument('--sync-bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError(
                'Argument --gpu_ids must be a comma-separated list of integers only'
            )

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    print("\nDefine models...\n")
    model_near = DeepLab(num_classes=args.num_class_near,
                         backbone=args.backbone,
                         output_stride=args.out_stride,
                         sync_bn=args.sync_bn,
                         freeze_bn=args.freeze_bn)
    model_far = DeepLab(num_classes=args.num_class_far,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
    model_seg = DeepLab(num_classes=2,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
    if not args.cuda:
        ckpt_near = torch.load(args.ckpt_near, map_location='cpu')
        ckpt_far = torch.load(args.ckpt_far, map_location='cpu')
        ckpt_seg = torch.load(args.ckpt_seg, map_location='cpu')
    else:
        ckpt_near = torch.load(args.ckpt_near)
        ckpt_far = torch.load(args.ckpt_far)
        ckpt_seg = torch.load(args.ckpt_seg)

    print("\nLoad checkpoints...\n")
    model_near.load_state_dict(ckpt_near['state_dict'])
    model_far.load_state_dict(ckpt_far['state_dict'])
    model_seg.load_state_dict(ckpt_seg['state_dict'])

    infer_near = DepthLosses(cuda=args.cuda,
                             min_depth=args.min_depth_near,
                             max_depth=args.max_depth_near,
                             num_class=args.num_class_near)

    infer_far = DepthLosses(cuda=args.cuda,
                            min_depth=args.min_depth_far,
                            max_depth=args.max_depth_far,
                            num_class=args.num_class_far)

    softmax = nn.Softmax(1)

    composed_transforms = transforms.Compose([
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    image = Image.open(args.in_path).convert('RGB')
    target = Image.open(args.in_path).convert('L')
    sample = {'image': image, 'label': target}
    tensor_in = composed_transforms(sample)['image'].unsqueeze(0)

    model_near.eval()
    model_far.eval()
    model_seg.eval()
    with torch.no_grad():
        # pass the image in all models
        print("\nPass near model...\n")
        output_near = model_near(tensor_in)
        output_near = infer_near.pred_to_continous_depth(
            output_near).detach().cpu().numpy().squeeze()
        print("\nPass far model...\n")
        output_far = model_far(tensor_in)
        output_far = infer_far.pred_to_continous_depth(
            output_far).detach().cpu().numpy().squeeze()
        print("\nPass seg model...\n")
        output_seg = model_seg(tensor_in)

        output_seg = softmax(output_seg)

        output_seg = torch.argmax(output_seg,
                                  dim=1).detach().cpu().numpy().squeeze()

        # compose the image based on segmentation mask
        print("\nFuse results...\n")
        output = np.where(output_seg == 0, output_near, output_far)

    plt.imsave(args.out_path, output)
    seg_path = args.out_path.split('.')[0] + "_seg.png"
    near_path = args.out_path.split('.')[0] + "_near.png"
    far_path = args.out_path.split('.')[0] + "_far.png"
    plt.imsave(seg_path, output_seg)
    plt.imsave(near_path, output_near)
    plt.imsave(far_path, output_far)
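
A hypothetical invocation for this near/far variant (script name and checkpoint paths are placeholders):

# python eval_near_far.py --in-path input.png --out-path depth.png \
#     --ckpt_near near.pth --ckpt_far far.pth --ckpt_seg seg.pth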