# Beispiel #1 (Example #1)
# 0
class TrainModel():
    def name(self):
        return 'Train Model'

    def initialize(self, opt):
        """Configure training state from parsed options.

        Sets up sizes/paths, evaluation criteria, the visualizer, log files,
        RNG seeds and cuDNN, and sanity-checks the train/test/cuda flags.

        Fix: ``torch.cuda.is_available`` was referenced without calling it,
        so both device checks always evaluated truthy; call the function.
        """
        self.opt = opt
        # imageSize may be a single value; duplicate it to (H, W)
        self.opt.imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        self.gpu_ids = ''
        self.batchSize = self.opt.batchSize
        self.checkpoints_path = os.path.join(self.opt.checkpoints,
                                             self.opt.name)
        self.scheduler = None
        self.create_save_folders()

        # criterion to evaluate the val split
        self.criterion_eval = MSEScaledError()
        self.mse_scaled_error = MSEScaledError()

        self.opt.print_freq = self.opt.display_freq

        self.visualizer = Visualizer(opt)

        if self.opt.resume and self.opt.display_id > 0:
            self.load_plot_data()
        elif opt.train:
            self.start_epoch = 1
            self.best_val_error = 999.9  # sentinel: any real RMSE is lower
        # self.print_save_options()

        # Logfile (append mode so resumed runs keep their history)
        self.logfile = open(os.path.join(self.checkpoints_path, 'logfile.txt'),
                            'a')
        if opt.validate:
            self.logfile_val = open(
                os.path.join(self.checkpoints_path, 'logfile_val.txt'), 'a')

        # Fixed seed so every run draws the same random numbers
        self.random_seed = 123
        random.seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.manual_seed(self.random_seed)
        if opt.cuda:
            self.cuda = torch.device(
                'cuda:0')  # set externally. ToDo: set internally
            torch.cuda.manual_seed(self.random_seed)

        # uses the inbuilt cudnn auto-tuner to find the fastest convolution algorithms.
        cudnn.benchmark = self.opt.use_cudnn_benchmark  # using too much memory - use when not in astroboy
        cudnn.enabled = True

        if not opt.train and not opt.test and not opt.resume:
            raise Exception("You have to set --train or --test")

        if torch.cuda.is_available() and not opt.cuda:
            print(
                "WARNING: You have a CUDA device, so you should run WITHOUT --cpu"
            )
        if not torch.cuda.is_available() and opt.cuda:
            raise Exception("No GPU found, run WITH --cpu")

    def set_input(self, input):
        self.input = input

    def create_network(self):
        """Build the generator network described by the options and move it
        to the GPU when CUDA is enabled."""
        generator = networks.define_G(input_nc=self.opt.input_nc,
                                      output_nc=self.opt.output_nc,
                                      ngf=64,
                                      net_architecture=self.opt.net_architecture,
                                      opt=self.opt,
                                      gpu_ids='')
        return generator.cuda() if self.opt.cuda else generator

    def get_optimizerG(self, network, lr, weight_decay=0.0):
        generator_params = filter(lambda p: p.requires_grad,
                                  network.parameters())
        return optim.Adam(generator_params,
                          lr=lr,
                          betas=(self.opt.beta1, 0.999),
                          weight_decay=weight_decay)

    def get_checkpoint(self, epoch):
        pass

    def train_batch(self):
        """Each method has a different implementation"""
        pass

    def display_gradients_norms(self):
        return 'nothing yet'

    def get_current_errors_display(self):
        pass

    def get_regression_criterion(self):
        if self.opt.regression_loss == 'L1':
            return nn.L1Loss()

    def get_variable(self, tensor, requires_grad=False):
        if self.opt.cuda:
            tensor = tensor.cuda()
        return Variable(tensor, requires_grad=requires_grad)

    def restart_variables(self):
        self.it = 0
        self.rmse = 0
        self.n_images = 0

    def train(self, data_loader, val_loader=None):
        """Main optimization loop.

        Iterates epochs and batches, timing each step, printing/displaying
        errors, periodically validating, and checkpointing once per epoch.

        Args:
            data_loader: iterable of training batches; len() must be defined.
            val_loader: optional loader passed through to evaluate().
        """
        self.data_loader = data_loader
        self.len_data_loader = len(
            self.data_loader)  # check if gonna use elsewhere
        self.total_iter = 0
        # NOTE(review): self.start_epoch / self.best_val_error are only set in
        # initialize() on the resume-with-display or opt.train paths — confirm
        # every path that reaches train() sets them.
        for epoch in range(self.start_epoch, self.opt.nEpochs):
            self.restart_variables()
            self.data_iter = iter(self.data_loader)
            # self.pbar = tqdm(range(self.len_data_loader))
            self.pbar = range(self.len_data_loader)
            # while self.it < self.len_data_loader:
            for self.it in self.pbar:
                if self.opt.optim == 'SGD':
                    # NOTE(review): scheduler.step() runs once per batch, not
                    # per epoch — confirm this is the intended schedule.
                    self.scheduler.step()

                self.total_iter += self.opt.batchSize

                self.netG.train(True)  # ensure generator is in train mode

                iter_start_time = time.time()

                self.train_batch()

                # average wall-clock time per image in this batch
                d_time = (time.time() - iter_start_time) / self.opt.batchSize

                # print errors
                self.print_current_errors(epoch, d_time)

                # display errors
                self.display_current_results(epoch)

                # Validate
                self.evaluate(val_loader, epoch)

            # save checkpoint
            self.save_checkpoint(epoch, is_best=0)

        self.logfile.close()

        if self.opt.validate:
            self.logfile_val.close()

    def get_next_batch(self):
        # self.it += 1 # important for GANs
        rgb_cpu, depth_cpu = self.data_iter.next()
        # depth_cpu = depth_cpu[0]
        self.input.data.resize_(rgb_cpu.size()).copy_(rgb_cpu)
        # self.target.data.resize_(depth_cpu.size()).copy_(depth_cpu)

    def apply_valid_pixels_mask(self, *data, value=0.0):
        # self.nomask_outG = data[0].data   # for displaying purposes
        mask = (data[1].data > value).to(self.cuda, dtype=torch.float32)

        masked_data = []
        for d in data:
            masked_data.append(d * mask)

        return masked_data, mask.sum()

    def update_learning_rate(self, epoch):
        """Linearly decay both optimizers' learning rates once past
        opt.niter_decay epochs; only active when use_cgan is set."""
        if epoch > self.opt.niter_decay and self.opt.use_cgan:  # but independs if conditional or not
            # Linear decay for discriminator
            [self.opt.d_lr,
             self.optimD] = self._update_learning_rate(self.opt.niter_decay,
                                                       self.opt.d_lr,
                                                       self.optimD)
            # Linear decay for generator
            [self.opt.lr,
             self.optimG] = self._update_learning_rate(self.opt.niter_decay,
                                                       self.opt.lr,
                                                       self.optimG)

    def _update_learning_rate(self, niter_decay, old_lr, optim):
        lr = old_lr - old_lr / niter_decay
        for param_group in optim.param_groups:
            param_group['lr'] = lr
        return lr, optim

    # CONTROL FUNCTIONS OF THE ARCHITECTURE

    def _get_plot_data_filename(self, phase):
        return os.path.join(
            self.checkpoints_path,
            'plot_data' + ('' if phase == 'train' else '_' + phase) + '.p')

    def save_static_plot_image():
        return None

    def save_interactive_plot_image():
        return None

    def _save_plot_data(self, plot_data, filename):
        # save
        pickle.dump(plot_data, open(filename, 'wb'))

    def save_plot_data(self):
        self._save_plot_data(self.visualizer.plot_data,
                             self._get_plot_data_filename('train'))
        if self.opt.validate and self.total_iter > self.opt.val_freq:
            self._save_plot_data(self.visualizer.plot_data_val,
                                 self._get_plot_data_filename('val'))

    def _load_plot_data(self, filename):
        # verify if file exists
        if not os.path.isfile(filename):
            raise Exception(
                'In _load_plot_data file {} doesnt exist.'.format(filename))
        else:
            return pickle.load(open(filename, "rb"))

    def load_plot_data(self):
        self.visualizer.plot_data = self._load_plot_data(
            self._get_plot_data_filename('train'))
        if self.opt.validate:
            self.visualizer.plot_data_val = self._load_plot_data(
                self._get_plot_data_filename('val'))

    def save_checkpoint(self, epoch, is_best):
        """Serialize the current state every save_checkpoint_freq epochs, or
        immediately when `is_best` is truthy."""
        if epoch % self.opt.save_checkpoint_freq == 0 or is_best:
            checkpoint = self.get_checkpoint(epoch)
            # per-epoch file name, zero-padded to four digits
            checkpoint_filename = '{}/{:04}.pth.tar'.format(
                self.checkpoints_path, epoch)
            self._save_checkpoint(
                checkpoint, is_best=is_best, filename=checkpoint_filename
            )  # is_best is normally 0 here because validation has not run yet
            # save plot data as well
    def _save_checkpoint(self, state, is_best, filename):
        """Write `state` to `filename`, mirror it to latest.pth.tar and, when
        `is_best`, also to best.pth.tar in the checkpoints folder."""
        print("Saving checkpoint...")
        # keep the per-epoch file and refresh the rolling 'latest' copy
        torch.save(state, filename)
        shutil.copyfile(
            filename, os.path.join(os.path.dirname(filename),
                                   'latest.pth.tar'))

        # comment next 2 lines if necessary if using last two lines
        # filename = os.path.join(self.checkpoints_path, 'latest.pth.tar')
        # torch.save(state, os.path.join(self.checkpoints_path, 'latest.pth.tar'))

        if is_best:
            shutil.copyfile(
                filename, os.path.join(self.checkpoints_path, 'best.pth.tar'))

    def create_save_folders(self):
        if self.opt.train:
            os.system('mkdir -p {0}'.format(self.checkpoints_path))
        # if self.opt.save_samples:
        #     subfolders = ['input', 'target', 'results', 'output']
        #     self.save_samples_path = os.path.join('results/train_results/', self.opt.name)
        #     for subfolder in subfolders:
        #         path = os.path.join(self.save_samples_path, subfolder)
        #         os.system('mkdir -p {0}'.format(path))
        #     if self.opt.test:
        #         self.save_samples_path = os.path.join('results/test_results/', self.opt.name)
        #         self.save_samples_path = os.path.join(self.save_samples_path, self.opt.epoch)
        #         for subfolder in subfolders:
        #             path = os.path.join(self.save_samples_path, subfolder)
        #             os.system('mkdir -p {0}'.format(path))

    def print_save_options(self):
        options_file = open(os.path.join(self.checkpoints_path, 'options.txt'),
                            'w')
        args = dict((arg, getattr(self.opt, arg)) for arg in dir(self.opt)
                    if not arg.startswith('_'))
        print('---Options---')
        for k, v in sorted(args.items()):
            option = '{}: {}'.format(k, v)
            # print options
            print(option)
            # save options in file
            options_file.write(option + '\n')

        options_file.close()

    def mean_errors(self):
        pass

    def get_current_errors(self):
        pass

    def print_current_errors(self, epoch, d_time):
        """Every print_freq iterations, aggregate the running errors and print
        one progress line (epoch, iteration, per-image time)."""
        if self.total_iter % self.opt.print_freq == 0:
            self.mean_errors()
            errors = self.get_current_errors()
            message = self.visualizer.print_errors(errors, epoch, self.it,
                                                   self.len_data_loader,
                                                   d_time)

            # self.pbar.set_description(message)
            print(message)
        # self.pbar.refresh()

    # def print_epoch_error(error):
    #     pass

    def get_current_visuals(self):
        pass

    def display_current_results(self, epoch):
        """Push errors and current images to the visualizer every display_freq
        iterations (only when a display is configured), and append the printed
        errors to the logfile."""
        if self.opt.display_id > 0 and self.total_iter % self.opt.display_freq == 0:

            errors = self.get_current_errors_display()
            self.visualizer.display_errors(
                errors, epoch,
                float(self.it) / self.len_data_loader)

            visuals = self.get_current_visuals()

            self.visualizer.display_images(visuals, epoch)

            # save printed errors to logfile
            self.visualizer.save_errors_file(self.logfile)

    def evaluate(self, data_loader, epoch):
        """Run validation every val_freq iterations; display/log the errors,
        persist plot data, and checkpoint the model when the validation error
        improves."""
        if self.opt.validate and self.total_iter % self.opt.val_freq == 0:
            val_error = self.get_eval_error(data_loader, self.netG,
                                            self.criterion_eval, epoch)

            # errors = OrderedDict([('LossL1', self.e_reg if self.opt.reg_type == 'L1' else self.L1error),
            #                      ('ValError', val_error.item())])
            # NOTE(review): self.rmse_epoch is not set anywhere in this class —
            # presumably a subclass provides it; verify.
            errors = OrderedDict([('RMSE', self.rmse_epoch),
                                  ('RMSEVal', val_error)])
            self.visualizer.display_errors(errors,
                                           epoch,
                                           float(self.it) /
                                           self.len_data_loader,
                                           phase='val')
            message = self.visualizer.print_errors(errors, epoch, self.it,
                                                   len(data_loader), 0)
            print('[Validation] ' + message)
            self.visualizer.save_errors_file(self.logfile_val)
            self.save_plot_data()
            # save best models (lower validation error is better)
            is_best = self.best_val_error > val_error
            if is_best:  # and not self.opt.not_save_val_model:
                print("Updating BEST model (epoch {}, iters {})\n".format(
                    epoch, self.total_iter))
                self.best_val_error = val_error
                self.save_checkpoint(epoch, is_best)

    def get_eval_error(self, val_loader, model, criterion, epoch):
        """
        Compute the average RMSE of `model` over `val_loader`.

        Runs with train(False) only — no model.eval() — so dropout layers
        keep whatever mode the caller set. Returns the cumulated per-batch
        RMSE divided by the number of batches.
        """
        cumulated_rmse = 0
        batchSize = 1  # validation processes one image at a time
        # preallocated buffers, reused (via resize_/copy_) for every sample
        input = self.get_variable(torch.FloatTensor(batchSize, 3,
                                                    self.opt.imageSize[0],
                                                    self.opt.imageSize[1]),
                                  requires_grad=False)
        mask = self.get_variable(torch.FloatTensor(batchSize, 1,
                                                   self.opt.imageSize[0],
                                                   self.opt.imageSize[1]),
                                 requires_grad=False)
        target = self.get_variable(
            torch.FloatTensor(batchSize, 1, self.opt.imageSize[0],
                              self.opt.imageSize[1]))
        # model.eval()
        model.train(False)
        pbar_val = tqdm(val_loader)
        for i, (rgb_cpu, depth_cpu) in enumerate(pbar_val):
            pbar_val.set_description('[Validation]')
            input.data.resize_(rgb_cpu.size()).copy_(rgb_cpu)
            target.data.resize_(depth_cpu.size()).copy_(depth_cpu)

            if self.opt.use_padding:
                from torch.nn import ReflectionPad2d

                self.opt.padding = self.get_padding_image(input)

                input = ReflectionPad2d(self.opt.padding)(input)
                target = ReflectionPad2d(self.opt.padding)(target)

            # get output of the network
            with torch.no_grad():
                outG = model.forward(input)
            # apply mask
            nomask_outG = outG.data  # for displaying purposes
            # NOTE(review): get_mask as written references an undefined name
            # `target` in its body — this call likely raises NameError; verify.
            mask_ByteTensor = self.get_mask(target.data)
            mask.data.resize_(mask_ByteTensor.size()).copy_(mask_ByteTensor)
            outG = outG * mask
            target = target * mask
            cumulated_rmse += sqrt(criterion(outG, target, mask,
                                             no_mask=False))

            # display only the second validation sample
            if (i == 1):
                self.visualizer.display_images(OrderedDict([
                    ('input', input.data), ('gt', target.data),
                    ('output', nomask_outG)
                ]),
                                               epoch='val {}'.format(epoch),
                                               phase='val')

        return cumulated_rmse / len(val_loader)

    def get_mask(self, data, value=0.0):
        return (target.data > 0.0)

    def get_padding(self, dim):
        final_dim = (dim // 32 + 1) * 32
        return final_dim - dim

    def get_padding_image(self, img):
        # get tensor dimensions
        h, w = img.size()[2:]
        w_pad, h_pad = self.get_padding(w), self.get_padding(h)

        pwr = w_pad // 2
        pwl = w_pad - pwr
        phb = h_pad // 2
        phu = h_pad - phb

        # pwl, pwr, phu, phb
        return (pwl, pwr, phu, phb)

    def adjust_learning_rate(self, initial_lr, optimizer, epoch):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        lr = initial_lr * (0.1**(epoch // self.opt.niter_decay))
        if epoch % self.opt.niter_decay == 0:
            print("LEARNING RATE DECAY HERE: lr = {}".format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
# Beispiel #2 (Example #2)
# 0
class TestModel(GenericTestModel):
    def initialize(self, opt):
        """Set up the test-time model from parsed options: sizes, paths,
        color palette, the pretrained network and the visualizer."""
        # GenericTestModel.initialize(self, opt)
        self.opt = opt
        self.get_color_palette()
        # imageSize may be a single value; duplicate it to (H, W)
        self.opt.imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        self.gpu_ids = ''
        self.batchSize = self.opt.batchSize
        self.checkpoints_path = os.path.join(self.opt.checkpoints,
                                             self.opt.name)
        self.create_save_folders()
        # semantics are produced by the multitask/semantics model variants
        self.opt.use_semantics = (('multitask' in self.opt.model)
                                  or ('semantics' in self.opt.model))

        self.netG = self.load_network()
        # self.opt.dfc_preprocessing = 2
        # self.data_loader, _ = CreateDataLoader(opt, Dataset)

        # visualizer
        self.visualizer = Visualizer(self.opt)
        if 'semantics' in self.opt.tasks:
            from util.util import get_color_palette
            # dataset-specific palette overrides the hard-coded one above
            self.opt.color_palette = np.array(
                get_color_palette(self.opt.dataset_name))
            # self.opt.color_palette = list(self.opt.color_palette.reshape(-1))
            # st()

    # def initialize(self, opt):
    #     GenericTestModel.initialize(self, opt)
    #     self.get_color_palette()

    def name(self):
        """Return the human-readable identifier of this model."""
        return 'Raster Test Model'

    def get_color_palette(self):
        """Set self.opt.color_palette to the per-class RGB palette of the
        configured dataset ('dfc' or 'isprs'); other datasets are untouched."""
        if self.opt.dataset_name == 'dfc':
            # 21-entry palette for the DFC labels
            self.opt.color_palette = [[0, 0, 0], [0, 205, 0], [127, 255, 0],
                                      [46, 139, 87], [0, 139, 0], [0, 70, 0],
                                      [160, 82, 45], [0, 255, 255],
                                      [255, 255, 255], [216, 191, 216],
                                      [255, 0, 0], [170, 160, 150],
                                      [128, 128, 128], [160, 0, 0], [80, 0, 0],
                                      [232, 161, 24], [255, 255, 0],
                                      [238, 154, 0], [255, 0, 255],
                                      [0, 0, 255], [176, 196, 222]]
        elif self.opt.dataset_name == 'isprs':
            # 7-entry palette for the ISPRS labels
            self.opt.color_palette = [
                [0, 0, 0],
                [255, 255, 255],
                [0, 0, 255],
                [0, 255, 255],
                [0, 255, 0],
                [255, 255, 0],
                [255, 0, 0],
            ]

    def load_network(self):
        """Load the generator checkpoint selected by --epoch and rebuild netG.

        Returns the network with restored weights, or None (after printing a
        message) when the checkpoint file is missing.

        Fixes: the epoch-name test used ``is not`` on string literals and
        joined the two comparisons with ``or`` (always true); it now checks
        membership with ``not in``. The bare ``except`` around the optional
        checkpoint keys now catches only KeyError.
        """
        # numeric epochs are stored zero-padded to four digits (e.g. 0012.pth.tar)
        if self.opt.epoch not in ('latest', 'best'):
            self.opt.epoch = self.opt.epoch.zfill(4)
        checkpoint_file = os.path.join(self.checkpoints_path,
                                       self.opt.epoch + '.pth.tar')
        if os.path.isfile(checkpoint_file):
            print("Loading {} checkpoint of model {} ...".format(
                self.opt.epoch, self.opt.name))
            checkpoint = torch.load(checkpoint_file)
            self.start_epoch = int(checkpoint['epoch'])
            self.opt.net_architecture = checkpoint['arch_netG']
            try:
                # Optional keys, only present in raster/multitask checkpoints.
                self.opt.d_block_type = checkpoint['d_block_type']
                # Extra options for raster:
                self.opt.which_raster = checkpoint['which_raster']
                self.opt.model = checkpoint['model']
                self.opt.tasks = checkpoint['tasks']
                self.opt.outputs_nc = checkpoint['outputs_nc']
                self.opt.n_classes = checkpoint['n_classes']
            except KeyError:
                pass
            self.opt.use_skips = checkpoint['use_skips']
            self.opt.model = checkpoint['model']
            self.opt.dfc_preprocessing = checkpoint['dfc_preprocessing']
            self.opt.mtl_method = checkpoint['mtl_method']
            self.opt.tasks = checkpoint['tasks']
            self.opt.outputs_nc = checkpoint['outputs_nc']
            netG = self.create_G_network()
            pretrained_dict = checkpoint['state_dictG']
            # Old DenseNet checkpoints use dotted sub-module names
            # (norm.1, conv.2, ...); strip the dot so the keys match the
            # current module layout before loading the state dict.
            pattern = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
            )
            for key in list(pretrained_dict.keys()):
                res = pattern.match(key)
                if res:
                    new_key = res.group(1) + res.group(2)
                    pretrained_dict[new_key] = pretrained_dict[key]
                    del pretrained_dict[key]
            netG.load_state_dict(pretrained_dict)
            if self.opt.cuda:
                netG = netG.cuda()
            self.best_val_error = checkpoint['best_pred']

            print("Loaded model from epoch {}".format(self.start_epoch))
            return netG
        else:
            print("Couldn't find checkpoint on path: {}".format(
                self.checkpoints_path + '/' + self.opt.epoch))

    def save_raster_png(self, data, filename):
        """Save a semantics raster as an indexed-color PNG.

        Only acts when 'semantics' appears in `filename`; labels are mapped
        to RGB via the option palette, then quantized to a 256-color palette.
        """
        if 'semantics' in filename:
            from util.util import labels_to_colors
            # NOTE(review): astype(np.int8) overflows for channel values
            # above 127 — uint8 is probably intended; verify against the
            # output of labels_to_colors.
            image_save = Image.fromarray(np.squeeze(
                labels_to_colors(data,
                                 self.opt.color_palette).astype(np.int8)),
                                         mode='RGB').convert(
                                             'P',
                                             palette=Image.ADAPTIVE,
                                             colors=256)
            image_save.save(filename)

    def save_merged_rasters(self, datatype, fileroot=None):
        """Merge every '<fileroot>*.tif' tile under `datatype` into a single
        GeoTIFF, then write a PNG preview (and, for 'output'/'target' data,
        a height colormap).

        Fixes: compare to None with ``is``; drop the unused imports
        (rasterio.plot.show, argparse) and the no-op os.path.join wrapper;
        close the opened source rasters when the merge is done.
        """
        import rasterio
        from rasterio.merge import merge
        import glob

        if fileroot is None:
            fileroot = datatype

        root = '{}/{}/{}*.tif'.format(self.save_samples_path, datatype,
                                      fileroot)
        filename = '{}/{}/merged_{}.tif'.format(self.save_samples_path,
                                                datatype, fileroot)

        files = glob.glob(root)
        mosaic_rasters = [rasterio.open(file) for file in files]
        try:
            mosaic, out_transform = merge(mosaic_rasters)
            # copy the metadata of the first tile before closing the sources
            meta = dict(mosaic_rasters[0].meta)
        finally:
            for src in mosaic_rasters:
                src.close()

        meta.update({
            "driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_transform
        })

        with rasterio.open(filename, "w", **meta) as dest:
            dest.write(mosaic)

        filename = '{}/{}/{}_merged.png'.format(self.save_samples_path,
                                                datatype, datatype)
        self.save_raster_png(mosaic, filename)

        # fires when `datatype` (embedded in the path) is 'output'/'target'
        if 'output' in filename or 'target' in filename:
            self.save_height_colormap(filename, mosaic)

    def test_raster(self):
        """Dispatch raster testing: semantics models run without a depth
        target, all others run against depth labels."""
        if 'semantics' in self.opt.model:
            from dataloader.dataset_raster import load_rgb_and_label as load_data
            self.test_raster_notarget(load_data)
        else:
            from dataloader.dataset_raster import load_rgb_and_labels as load_data
            self.test_raster_target(load_data)

    def initialize_test_bayesian(self, opt):
        """Prepare the Bayesian (MC-dropout) test: list the DFC test files,
        preload every raster into memory and re-enable dropout on netG."""
        from dataloader.dataset_bank import dataset_dfc
        print('Test phase using {} split.'.format(self.opt.test_split))
        phase = 'test'

        input_list, target_path = dataset_dfc(
            self.opt.dataroot,
            data_split=self.opt.test_split,
            phase='test',
            model=self.opt.model,
            which_raster=self.opt.which_raster)

        # Sanity check : raise an error if some files do not exist
        for f in input_list + target_path:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        from dataloader.dataset_raster import load_rgb_and_labels as load_data
        # materialize the loader: one (rgb, depth, meta, shape) tuple per raster
        self.data_loader = [
            (rgb, depth, meta, depth_patch_shape)
            for rgb, depth, meta, depth_patch_shape in load_data(
                input_list,
                target_path,
                phase,
                self.opt.dfc_preprocessing,
                which_raster=self.opt.which_raster,
                use_semantics=False,
                save_semantics=self.opt.save_semantics)
        ]  # false because we do not have the GT
        # no error in save semantics, same value to both variables
        self.netG.eval()
        # keep stochasticity at test time: dropout layers go back to train mode
        self.netG.apply(self.activate_dropout)

    def activate_dropout(self, m):
        """Flip a Dropout module back to training mode (for MC-dropout runs).

        Intended for ``net.apply`` after ``net.eval()`` so only the dropout
        layers stay stochastic.

        Fix: use ``isinstance`` instead of ``type(m) == nn.Dropout`` so
        Dropout subclasses are handled too.
        """
        if isinstance(m, nn.Dropout):
            # print(m)
            m.train()

    def get_meta_data(self):
        """Return the metadata recorded for the raster last processed."""
        return self.meta_data

    def get_shape(self):
        """Return the depth-patch shape recorded for the raster last processed."""
        return self.shape

    def get_data_loader_size(self):
        """Return how many rasters the preloaded data loader holds."""
        return len(self.data_loader)

    def test_bayesian(self, it, n_iters):
        """Monte-Carlo (dropout) evaluation of raster `it`.

        Runs the network `n_iters` times over sliding-window crops of the
        raster, reconstructing the full prediction each time, and collects
        the per-run absolute-error and prediction mosaics.

        Returns:
            (error_list, outG_list, target_reconstructed) — the lists hold
            one full-raster numpy array per MC iteration.
        """
        error_list = []
        outG_list = []

        # the semantics head is ignored during the MC runs
        # NOTE(review): use_semantics is saved here but never restored before
        # returning — verify this is intended.
        use_semantics = self.opt.use_semantics
        self.opt.use_semantics = False
        # self.augmentation = augmentation
        imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        test_stride = self.opt.test_stride if len(
            self.opt.test_stride) == 2 else self.opt.test_stride * 2

        # create a matrix with a gaussian distribution to be the weights during reconstruction
        prob_matrix = self.gaussian_kernel(imageSize[0], imageSize[1])

        # for it, (input, target, meta_data, depth_patch_shape) in enumerate(tqdm(self.data_loader)):
        input, target, meta_data, depth_patch_shape = self.data_loader[it]
        # NOTE(review): the loop variable below shadows the `it` parameter.
        for it in (tqdm(range(n_iters))):
            rgb_cache = []
            depth_cache = []
            self.meta_data = meta_data
            self.shape = depth_patch_shape
            # pred = np.zeros(input.shape[-2:])
            # concatenate probability matrix
            pred = np.zeros([input.shape[-2], input.shape[-1]])
            if self.opt.reconstruction_method == 'gaussian':
                # channel 0: weighted sum of predictions; channel 1: weight sum
                pred = np.zeros([2, input.shape[-2], input.shape[-1]])
                pred_sem = np.zeros(
                    [self.opt.n_classes, input.shape[-2], input.shape[-1]])
            else:
                pred_sem = np.zeros([input.shape[-2], input.shape[-1]])
            target_reconstructed = np.zeros(input.shape[-2:])

            # input is a tensor
            rgb_cache = [
                crop for crop in self.sliding_window_coords(
                    input, test_stride, imageSize)
            ]
            depth_cache = [
                crop for crop in self.sliding_window_coords(
                    target, test_stride, imageSize)
            ]  # don't need both

            for input_crop_tuple, target_crop_tuple in tqdm(
                    zip(rgb_cache, depth_cache), total=len(rgb_cache)):
                input_crop, (x1, x2, y1, y2) = input_crop_tuple
                input_crop = self.get_variable(input_crop)
                # self.complete_padding = True
                # ToDo: Deal with padding later
                if self.opt.use_padding:
                    from torch.nn import ReflectionPad2d

                    self.opt.padding = self.get_padding_image_dims(input_crop)

                    input_crop = ReflectionPad2d(self.opt.padding)(input_crop)
                    (pwl, pwr, phu, phb) = self.opt.padding
                    # target_crop = ReflectionPad2d(self.opt.padding)(target_crop)

                with torch.no_grad():
                    outG, _ = self.netG.forward(input_crop)

                out_numpy = outG.data[0].cpu().float().numpy()
                if self.opt.reconstruction_method == 'concatenation':
                    if self.opt.use_padding:
                        # crop the reflection padding back off the output
                        pred[y1:y2,
                             x1:x2] = (out_numpy[0])[phu:phu +
                                                     self.opt.imageSize[1],
                                                     pwl:pwl +
                                                     self.opt.imageSize[0]]
                    else:
                        pred[y1:y2, x1:x2] = out_numpy[0]
                elif self.opt.reconstruction_method == 'gaussian':
                    # accumulate gaussian-weighted predictions and the weights
                    pred[0, y1:y2,
                         x1:x2] += np.multiply(out_numpy[0], prob_matrix)
                    pred[1, y1:y2, x1:x2] += prob_matrix

                target_reconstructed[y1:y2, x1:x2] = target_crop_tuple[0]

            if self.opt.reconstruction_method == 'gaussian':
                # normalize by the accumulated gaussian weights
                gaussian = pred[1]
                pred = np.divide(pred[0], gaussian)
                # pred_sem = np.divide(pred_sem, gaussian)

                # st()
                if self.opt.dfc_preprocessing == 0:
                    # resize outputs
                    pred = np.array(
                        Image.fromarray(pred).resize(
                            (pred.shape[1] // 10, pred.shape[0] // 10),
                            Image.BILINEAR))
                    target_reconstructed = np.array(
                        Image.fromarray(target_reconstructed).resize(
                            (target_reconstructed.shape[1] // 10,
                             target_reconstructed.shape[0] // 10),
                            Image.BILINEAR))
                error_list.append(np.abs(pred - target_reconstructed))
                outG_list.append(np.abs(pred))

        return error_list, outG_list, target_reconstructed

    def test_raster_notarget(self, load_data):
        """Sliding-window semantic inference over full rasters, no ground truth.

        Runs the generator on overlapping crops of each test raster and
        stitches the per-crop semantic predictions into a full-size map:
        gaussian-weighted blending of sigmoid class probabilities when
        ``reconstruction_method == 'gaussian'``, otherwise pasting per-crop
        argmax labels. Optionally saves the reconstructed semantic raster.

        Args:
            load_data: callable yielding one full-size input array per file
                returned by ``dataset_dfc``.
        """
        from dataloader.dataset_bank import dataset_dfc
        print('Test phase using {} split.'.format(self.opt.test_split))
        phase = 'test'

        # A single-element size/stride is duplicated into a two-element pair.
        imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        test_stride = self.opt.test_stride if len(
            self.opt.test_stride) == 2 else self.opt.test_stride * 2
        input_list = dataset_dfc(self.opt.dataroot,
                                 data_split=self.opt.test_split,
                                 phase='test',
                                 model=self.opt.model)

        # Sanity check : raise an error if some files do not exist
        for f in input_list:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        data_loader = [(rgb) for rgb in load_data(
            input_list, phase, dfc_preprocessing=self.opt.dfc_preprocessing)]

        self.netG.eval()

        # create a matrix with a gaussian distribution to be the weights during reconstruction
        prob_matrix = self.gaussian_kernel(imageSize[0], imageSize[1])

        for it, input in enumerate(tqdm(data_loader)):
            rgb_cache = []

            # Full-raster accumulators; gaussian mode keeps one plane per
            # class plus a running weight sum used for normalization later.
            pred_gaussian = np.zeros([input.shape[-2], input.shape[-1]])
            if self.opt.reconstruction_method == 'gaussian':
                pred_sem = np.zeros(
                    [self.opt.n_classes, input.shape[-2], input.shape[-1]])
            else:
                pred_sem = np.zeros([input.shape[-2], input.shape[-1]])
            target_reconstructed = np.zeros(input.shape[-2:])

            # input is a tensor
            rgb_cache = [
                crop for crop in self.sliding_window_coords(
                    input, test_stride, imageSize)
            ]

            for input_crop_tuple in tqdm(rgb_cache, total=len(rgb_cache)):
                input_crop, (x1, x2, y1, y2) = input_crop_tuple
                input_crop = self.get_variable(input_crop)

                # ToDo: Deal with padding later
                if self.opt.use_padding:
                    from torch.nn import ReflectionPad2d

                    self.opt.padding = self.get_padding_image_dims(input_crop)

                    input_crop = ReflectionPad2d(self.opt.padding)(input_crop)
                    (pwl, pwr, phu, phb) = self.opt.padding
                    # target_crop = ReflectionPad2d(self.opt.padding)(target_crop)

                with torch.no_grad():
                    outG_sem = self.netG.forward(input_crop)

                if self.opt.reconstruction_method == 'gaussian':
                    # Blend sigmoid class probabilities weighted by the
                    # gaussian kernel; pred_gaussian tracks per-pixel weight.
                    outG_sem_prob = nn.Sigmoid()(outG_sem)
                    seg_map = outG_sem_prob.cpu().data[0].numpy()
                    pred_sem[:, y1:y2,
                             x1:x2] += np.multiply(seg_map, prob_matrix)
                    pred_gaussian[y1:y2, x1:x2] += prob_matrix
                else:
                    # Hard labels: later crops simply overwrite earlier ones.
                    pred_sem[y1:y2,
                             x1:x2] = np.argmax(outG_sem.cpu().data[0].numpy(),
                                                axis=0)

                # visualize
                visuals = OrderedDict([
                    ('input', input_crop.data),
                    ('out_sem',
                     np.argmax(outG_sem.cpu().data[0].numpy(), axis=0))
                ])
                self.display_test_results(visuals)

            if self.opt.save_samples:
                if self.opt.reconstruction_method == 'gaussian':
                    # Normalize accumulated probabilities by the weight sum.
                    pred_sem = np.divide(pred_sem, pred_gaussian)
                self.save_raster_images_semantics_only(input,
                                                       pred_sem,
                                                       index=it + 1,
                                                       phase='test')

    def test_raster_target(self, load_data):
        """Sliding-window height inference over full rasters with targets.

        For each test raster: crops input and target with the same sliding
        window, predicts a height map per crop (plus a semantic head for
        'multitask' models), stitches crops back into a full raster —
        overwrite for 'concatenation', gaussian-weighted blending for
        'gaussian' — then saves the reconstructed rasters and merged
        mosaics. Per-raster wall-clock times are reported at the end.

        Args:
            load_data: callable yielding
                ``(rgb, depth, meta, depth_patch_shape)`` tuples for the
                dataset file lists.
        """
        from dataloader.dataset_bank import dataset_dfc
        print('Test phase using {} split.'.format(self.opt.test_split))
        phase = 'test'

        # NOTE(review): the flag is saved here but never restored below.
        use_semantics = self.opt.use_semantics
        self.opt.use_semantics = False
        # self.augmentation = augmentation
        # A single-element size/stride is duplicated into a two-element pair.
        imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        test_stride = self.opt.test_stride if len(
            self.opt.test_stride) == 2 else self.opt.test_stride * 2
        input_list, target_path = dataset_dfc(
            self.opt.dataroot,
            data_split=self.opt.test_split,
            phase='test',
            model=self.opt.model,
            which_raster=self.opt.which_raster)

        # Sanity check : raise an error if some files do not exist
        for f in input_list + target_path:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        data_loader = [(rgb, depth, meta, depth_patch_shape)
                       for rgb, depth, meta, depth_patch_shape in load_data(
                           input_list,
                           target_path,
                           phase,
                           self.opt.dfc_preprocessing,
                           which_raster=self.opt.which_raster,
                           use_semantics=False,
                           save_semantics=self.opt.save_semantics)
                       ]  # false because we do not have the GT
        # no error in save semantics, same value to both variables
        self.netG.eval()

        # if self.opt.normalize:
        #     from dataloader.dataset_raster import get_min_max
        #     import rasterio
        #     max_v, min_v = get_min_max([rasterio.open(path) for path in target_path], self.opt.which_raster)

        # create a matrix with a gaussian distribution to be the weights during reconstruction
        prob_matrix = self.gaussian_kernel(imageSize[0], imageSize[1])
        # st()
        time_array = np.zeros(len(data_loader))

        for it, (input, target, meta_data,
                 depth_patch_shape) in enumerate(tqdm(data_loader)):
            rgb_cache = []
            depth_cache = []
            start = time.time()

            # pred = np.zeros(input.shape[-2:])
            # concatenate probability matrix
            # Full-raster accumulators; in gaussian mode pred stacks the
            # weighted prediction (plane 0) and the weight sum (plane 1).
            pred = np.zeros([input.shape[-2], input.shape[-1]])
            if self.opt.reconstruction_method == 'gaussian':
                pred = np.zeros([2, input.shape[-2], input.shape[-1]])
                pred_sem = np.zeros(
                    [self.opt.n_classes, input.shape[-2], input.shape[-1]])
            else:
                pred_sem = np.zeros([input.shape[-2], input.shape[-1]])
            target_reconstructed = np.zeros(input.shape[-2:])

            # input is a tensor
            rgb_cache = [
                crop for crop in self.sliding_window_coords(
                    input, test_stride, imageSize)
            ]
            depth_cache = [
                crop for crop in self.sliding_window_coords(
                    target, test_stride, imageSize)
            ]  # don't need both
            # import cProfile
            # torch.cuda.synchronize()
            for input_crop_tuple, target_crop_tuple in tqdm(
                    zip(rgb_cache, depth_cache), total=len(rgb_cache)):
                # cp = cProfile.Profile()
                # cp.enable()
                # for input_crop_tuple, target_crop_tuple in zip(rgb_cache, depth_cache):
                input_crop, (x1, x2, y1, y2) = input_crop_tuple
                input_crop = self.get_variable(input_crop)

                with torch.no_grad():
                    # multitask models also emit a semantic output head
                    if 'multitask' in self.opt.model:
                        outG, outG_sem = self.netG.forward(input_crop)
                    else:
                        outG = self.netG.forward(input_crop)[0]
                out_numpy = outG.data[0].cpu().float().numpy()
                if self.opt.reconstruction_method == 'concatenation':
                    # later crops simply overwrite earlier ones
                    pred[y1:y2, x1:x2] = out_numpy[0]
                elif self.opt.reconstruction_method == 'gaussian':
                    pred[0, y1:y2,
                         x1:x2] += np.multiply(out_numpy[0], prob_matrix)
                    pred[1, y1:y2, x1:x2] += prob_matrix

                if self.opt.save_semantics:
                    # pred_sem[:,y1:y2,x1:x2] += outG_sem.cpu().data[0].numpy()
                    if self.opt.reconstruction_method == 'gaussian':
                        # seg_map = np.argmax(outG_sem.cpu().data[0].numpy(), axis=0)
                        # pred_sem[y1:y2,x1:x2] += np.multiply(seg_map, prob_matrix)
                        # blend sigmoid class probabilities with the kernel
                        outG_sem_prob = nn.Sigmoid()(outG_sem)
                        seg_map = outG_sem_prob.cpu().data[0].numpy()
                        pred_sem[:, y1:y2,
                                 x1:x2] += np.multiply(seg_map, prob_matrix)
                    else:
                        pred_sem[y1:y2, x1:x2] = np.argmax(
                            outG_sem.cpu().data[0].numpy(), axis=0)

                    # visualize takes a lot of time
                    # visuals = OrderedDict([('input', input_crop.data),
                    #             # ('gt', target_crop.data),
                    #             ('output', outG),
                    #             # ('gt_sem', self.target_sem.data[0].cpu().float().numpy()),
                    #             ('out_sem', np.argmax(outG_sem.cpu().data[0].numpy(), axis=0))
                    #             ])
                    # self.display_test_results(visuals)

                target_reconstructed[y1:y2, x1:x2] = target_crop_tuple[0]
                # break
                # st()
                # cp.disable()
                # cp.print_stats()
                # st()
            end = time.time()
            time_array[it] = end - start
            print('Time in seconds: {}'.format(end - start))
            # break elapsed seconds into d:h:m:s for the log line below
            t_time = end - start
            day = t_time // (24 * 3600)
            t_time = t_time % (24 * 3600)
            hour = t_time // 3600
            t_time %= 3600
            minutes = t_time // 60
            t_time %= 60
            seconds = t_time
            print("d:h:m:s-> %d:%d:%d:%d" % (day, hour, minutes, seconds))

            if self.opt.save_samples:
                if self.opt.reconstruction_method == 'gaussian':
                    # normalize accumulated values by the gaussian weight sum
                    gaussian = pred[1]
                    pred = np.divide(pred[0], gaussian)
                    pred_sem = np.divide(pred_sem, gaussian)
                    # if self.opt.normalize:
                    #     target_reconstructed = (target_reconstructed + 1) / 2
                    #     target_reconstructed = (target_reconstructed * (max_v - min_v)) + min_v
                    #     pred = (pred + 1) / 2
                    #     pred = (pred * (max_v - min_v)) + min_v
                    # print('Target: max[{}] min[{}]'.format(target_reconstructed.max(),target_reconstructed.min()))
                    # print('Pred: max[{}] min[{}]'.format(pred.max(),pred.min()))

                indexplus = 4
                if self.opt.save_semantics:
                    self.save_raster_images_semantics(
                        input, pred, target_reconstructed, pred_sem, meta_data,
                        depth_patch_shape, it + 1 + indexplus, 'test')
                else:
                    self.save_raster_images(input, pred, target_reconstructed,
                                            meta_data, depth_patch_shape,
                                            it + 1 + indexplus, 'test')
                # NOTE(review): 'gaussian' is only bound in the gaussian
                # branch above, so this del raises NameError for other
                # reconstruction methods — confirm whether intended.
                del input, pred, target_reconstructed, gaussian, pred_sem

        print('Test statistics')
        print('Mean and standard deviation in seconds {:.3f} {:.3f}'.format(
            np.mean(time_array), np.std(time_array)))

        print('Saving merged!')
        self.save_merged_rasters('output')
        self.save_merged_rasters('target')
        self.save_merged_rasters('semantics')

    def display_test_results(self, visuals):
        """Render the given test visuals through the visualizer (window id 1)."""
        window_id = 1
        self.visualizer.display_images(visuals, window_id)

    def get_padding(self, dim):
        """Return the padding needed to grow ``dim`` past a multiple of 32.

        Note: this always pads — an already-aligned ``dim`` is still grown
        by a full 32 (e.g. 32 -> 64), matching the original behavior.
        """
        next_multiple = 32 * (dim // 32 + 1)
        return next_multiple - dim

    def sliding_window_coords(self, data, step, window_size):
        """Yield ``(crop, [x1, x2, y1, y2])`` pairs for each sliding window.

        2-D inputs are yielded as raw numpy slices; higher-rank
        (channel-first) inputs are converted to a tensor with a leading
        batch dimension, as expected by the network downstream.
        """
        from dataloader.dataset_raster import sliding_window
        # data = data.data[0].cpu().float().numpy()
        for x1, x2, y1, y2 in sliding_window(data, step, window_size):
            coords = [x1, x2, y1, y2]
            if len(data.shape) == 2:
                yield data[y1:y2, x1:x2], coords
            else:
                crop = torch.from_numpy(data[:, y1:y2, x1:x2])
                yield crop.unsqueeze(0), coords

    def gaussian_kernel(self, width, height, sigma=0.2, mu=0.0):
        """Build an unnormalized 2-D gaussian weight matrix of shape (height, width).

        Used to weight overlapping crops during raster reconstruction so
        pixels near a crop's center contribute more than its borders.

        Bug fix: the original built both grid axes from ``height`` and
        ignored ``width`` entirely, which was only correct for square
        windows. Each axis now spans [-1, 1] with its own sample count,
        so non-square windows get a correctly shaped kernel; square
        windows are unchanged.

        Args:
            width: number of columns of the kernel.
            height: number of rows of the kernel.
            sigma: gaussian standard deviation in normalized coordinates.
            mu: gaussian mean in normalized radial distance.

        Returns:
            ``np.ndarray`` of shape (height, width); deliberately not
            normalized to sum to 1 (weights are renormalized per pixel
            during reconstruction).
        """
        x, y = np.meshgrid(np.linspace(-1, 1, width),
                           np.linspace(-1, 1, height))
        d = np.sqrt(x * x + y * y)
        gaussian_k = (np.exp(-(
            (d - mu)**2 / (2.0 * sigma**2)))) / np.sqrt(2 * np.pi * sigma**2)
        return gaussian_k  # / gaussian_k.sum()

    def get_padding_image_dims(self, img):
        """Compute a (left, right, top, bottom) padding tuple for ``img``.

        Pads an NCHW tensor so that each spatial dimension plus a 4-pixel
        margin reaches a multiple of 32, splitting the total padding
        between the two sides of each axis (the left/top side gets the
        extra pixel when the total is odd).
        """
        h, w = img.size()[2:]
        # self.opt.imageSize = (w + 4, h + 4)
        w_total = self.get_padding(w + 4) + 4
        h_total = self.get_padding(h + 4) + 4

        right = w_total // 2
        left = w_total - right
        bottom = h_total // 2
        top = h_total - bottom

        # order matches ReflectionPad2d: (left, right, top, bottom)
        return (left, right, top, bottom)

    def get_padding_image(self, img):
        """Padding tuple to bring ``img``'s spatial dims to multiples of 32.

        Side effect: records the unpadded size in ``self.opt.imageSize``
        as (w, h). The total padding of each axis is split between its two
        sides, with the left/top side taking the extra pixel when odd.
        """
        h, w = img.size()[2:]
        self.opt.imageSize = (w, h)
        w_total = self.get_padding(w)
        h_total = self.get_padding(h)

        right = w_total // 2
        left = w_total - right
        bottom = h_total // 2
        top = h_total - bottom

        # order matches ReflectionPad2d: (left, right, top, bottom)
        return (left, right, top, bottom)

    def save_height_colormap(self, filename, data, cmap='jet'):
        """Save the first band of ``data`` as a color-mapped figure.

        The destination name mirrors ``filename`` with the ``merged_*``
        prefix swapped for ``cmap_merged_*``; display values are clamped
        to the [0, 30] range and drawn at the array's native pixel size.
        """
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        dpi = 80
        data = data[0, :, :]
        height, width = data.shape
        figsize = (width / float(dpi), height / float(dpi))
        # change string
        if 'output' in filename:
            filename = filename.replace('merged_output', 'cmap_merged_output')
        else:
            filename = filename.replace('merged_target', 'cmap_merged_target')
        fig = plt.figure(figsize=figsize)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(data,
                  vmax=30,
                  vmin=0,
                  aspect='auto',
                  interpolation='spline16',
                  cmap=cmap)
        ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)

        fig.savefig(filename, dpi=dpi)
        del fig, data

    def save_dsm_as_raster(self, data, filename, meta_data, shape):
        """Resize a height map and write it as a single-band GeoTIFF.

        ``data`` is bilinearly resized to ``shape``, given a leading band
        axis, and written under ``self.save_samples_path`` via rasterio
        using the provided ``meta_data`` profile.
        """
        import rasterio
        filename = os.path.join(self.save_samples_path, filename)
        resized = Image.fromarray(data).resize(shape, Image.BILINEAR)
        depth_patch = np.expand_dims(np.array(resized), axis=0)

        with rasterio.open(filename, "w", **meta_data) as dest:
            # NOTE(review): write() appears to return None, so this branch
            # never fires; kept to preserve the original control flow.
            if dest.write(depth_patch) == False:
                print('Couldnt save image, sorry')

    def save_raster_images_semantics(self,
                                     input,
                                     output,
                                     target,
                                     semantics,
                                     meta_data,
                                     shape,
                                     index,
                                     phase='train',
                                     out_type='png'):
        """Save height output/target rasters plus a semantic-label raster.

        Delegates the height maps to ``save_raster_images`` (first metadata
        entry), then resizes the semantic map with nearest-neighbor
        interpolation and writes it as an indexed uint8 GeoTIFF using the
        second metadata entry.
        """
        from dataloader.dataset_raster import sliding_window
        self.save_raster_images(input, output, target, meta_data[0], shape,
                                index, phase)
        del input, output, target
        import gc
        gc.collect()
        filename = '{}/semantics/semantics_{:04}.tif'.format(
            self.save_samples_path, index)

        # gaussian reconstruction keeps per-class scores; collapse to labels
        if self.opt.reconstruction_method == 'gaussian':
            semantics = np.argmax(semantics, axis=0)

        labels = np.array(semantics, dtype=np.uint8)
        del semantics
        resized = Image.fromarray(labels, mode='P').resize(shape, Image.NEAREST)
        sem_patch = np.expand_dims(np.array(resized), axis=0)

        import rasterio
        with rasterio.open(filename, "w", **meta_data[1]) as dest:
            if dest.write(sem_patch) == False:
                print('Couldnt save image, sorry')
            # base_stride /= 2
        del sem_patch

    # def save_height_colormap(self, filename, data, cmap='jet'):
    #     import matplotlib.pyplot as plt
    #     plt.switch_backend('agg')
    #     dpi = 80
    #     data = data[0,:,:]
    #     height, width = data.shape
    #     figsize = width / float(dpi), height / float(dpi)
    #     # change string
    #     filename = filename.replace('output_', 'cmap_output_')
    #     fig = plt.figure(figsize=figsize)
    #     ax = fig.add_axes([0, 0, 1, 1])
    #     ax.axis('off')
    #     cax = ax.imshow(data, vmax=30, vmin=0, aspect='auto', interpolation='spline16', cmap=cmap)
    #     ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)

    #     fig.savefig(filename, dpi=dpi)

    def save_raster_images_semantics_only(self,
                                          input,
                                          semantics,
                                          index,
                                          shape=(1192, 1202),
                                          phase='train',
                                          out_type='png'):
        """Write only the semantic-label raster for sample ``index``.

        Collapses gaussian-blended class scores to argmax labels when
        needed, resizes with nearest-neighbor interpolation to ``shape``
        and writes an indexed uint8 GeoTIFF using the module-level
        ``OUT_META_SEM`` profile for this index.
        """
        print('Saving semantics...')
        import gc
        gc.collect()

        filename = '{}/semantics/semantics_{:04}.tif'.format(
            self.save_samples_path, index)
        import time
        t0 = time.time()
        if self.opt.reconstruction_method == 'gaussian':
            semantics = np.argmax(semantics, axis=0)
        print('Time to argmax: {}'.format(time.time() - t0))

        labels = np.array(semantics, dtype=np.uint8)
        resized = Image.fromarray(labels, mode='P').resize(shape, Image.NEAREST)
        sem_patch = np.expand_dims(np.array(resized), axis=0)

        import rasterio
        with rasterio.open(filename, "w", **OUT_META_SEM[index - 1]) as dest:
            if dest.write(sem_patch) == False:
                print('Couldnt save image, sorry')

    def save_raster_images(self,
                           input,
                           output,
                           target,
                           meta_data,
                           shape,
                           index,
                           phase='train',
                           out_type='png'):
        """Persist the predicted and ground-truth height maps as GeoTIFFs."""
        to_save = (
            (output, 'output/output_{:04}.tif'.format(index)),
            (target, 'target/target_{:04}.tif'.format(index)),
        )
        for array, rel_path in to_save:
            self.save_dsm_as_raster(array, rel_path, meta_data, shape)

    def get_variable(self, tensor):
        """Wrap ``tensor`` in an autograd Variable, on GPU when enabled.

        Bug fix: the original fell off the end and returned None whenever
        ``self.opt.cuda`` was False, which would crash every caller that
        uses the result (e.g. the sliding-window test loops). The CPU
        variable is now returned as well.

        Args:
            tensor: input tensor to wrap.

        Returns:
            The wrapped variable, moved to CUDA if ``self.opt.cuda`` is set.
        """
        variable = Variable(tensor)
        if self.opt.cuda:
            return variable.cuda()
        return variable

    def create_save_folders(
            self,
            subfolders=['input', 'target', 'results', 'output', 'semantics']):
        if self.opt.save_samples:
            if self.opt.test:
                self.save_samples_path = os.path.join(
                    'results/{}'.format(self.opt.dataset_name), self.opt.name,
                    self.opt.epoch)
                for subfolder in subfolders:
                    path = os.path.join(self.save_samples_path, subfolder)
                    os.system('mkdir -p {0}'.format(path))