예제 #1
0
    def _write_loss(self, phase, global_step):
        """Publish scalars and image grids for one logging step.

        Args:
            phase: ``'train'`` logs the learning rate, the meters enabled in
                ``cfg.LOSS_TYPES`` and several image grids; ``'test'`` logs
                only the mean validation IoU. Other values log nothing.
            global_step: Step index forwarded to every writer call.
        """
        enabled = self.cfg.LOSS_TYPES

        # Both snapshots are stored on the instance, not just used locally.
        self.label_show = self.label.data.cpu().numpy()
        self.source_modal_show = self.source_modal

        def log_meter(meter_name):
            # Tag is the meter name under the 'Seg/' prefix; value is its avg.
            self.writer.add_scalar('Seg/' + meter_name,
                                   self.loss_meters[meter_name].avg,
                                   global_step=global_step)

        def log_batch(tag, batch):
            # First six items of the batch, laid out three per row.
            grid = torchvision.utils.make_grid(batch[:6].clone().cpu().data,
                                               3,
                                               normalize=True)
            self.writer.add_image(tag, grid, global_step=global_step)

        def log_label_map(tag, labels):
            # Labels are colourised by util.color_label before being gridded.
            coloured = util.color_label(labels[:6],
                                        ignore=self.cfg.IGNORE_LABEL,
                                        dataset=self.cfg.DATASET)
            grid = torchvision.utils.make_grid(torch.from_numpy(coloured),
                                               3,
                                               normalize=True,
                                               range=(0, 255))
            self.writer.add_image(tag, grid, global_step=global_step)

        if phase == 'train':
            self.writer.add_scalar('LR',
                                   self.optimizer_ED.param_groups[0]['lr'],
                                   global_step=global_step)

            if 'CLS' in enabled:
                log_meter('TRAIN_CLS_LOSS')

            if 'SEMANTIC' in enabled:
                log_meter('TRAIN_SEMANTIC_LOSS_2DEPTH')
                log_meter('TRAIN_SEMANTIC_LOSS_2SEG')

            if 'PIX2PIX' in enabled:
                log_meter('TRAIN_PIX2PIX_LOSS_2DEPTH')
                log_meter('TRAIN_PIX2PIX_LOSS_2SEG')

            # gen[0] is paired with the depth target, gen[1] with the seg one.
            log_batch('Seg/Train_groundtruth_depth', self.target_depth)
            log_batch('Seg/Train_predicted_depth', self.gen[0].data)
            log_batch('Seg/Train_groundtruth_seg', self.target_seg)
            log_batch('Seg/Train_predicted_seg', self.gen[1].data)

            log_batch('Seg/Train_image', self.source_modal_show)

            if 'CLS' in enabled:
                log_label_map('Seg/Train_predicted_label',
                              self.Train_predicted_label)
                log_label_map('Seg/Train_ground_label', self.label_show)

        if phase == 'test':
            # Mean IoU is reported as a percentage.
            mean_iou_percent = float(self.val_iou.mean()) * 100.0
            self.writer.add_scalar('Seg/VAL_CLS_MEAN_IOU',
                                   mean_iou_percent,
                                   global_step=global_step)
예제 #2
0
    def _write_loss(self, phase, global_step):
        """Publish scalars and image grids for one logging step.

        Args:
            phase: ``'train'`` logs the learning rate, the enabled loss meters
                and image grids; ``'test'`` logs only the mean validation IoU.
                Other values log nothing.
            global_step: Step index forwarded to every writer call.
        """
        enabled = self.cfg.LOSS_TYPES

        # Both snapshots are stored on the instance, not just used locally.
        self.label_show = self.label.data.cpu().numpy()
        self.source_modal_show = self.source_modal

        def log_meter(meter_name):
            # Tag is the meter name under the 'Seg/' prefix; value is its avg.
            self.writer.add_scalar('Seg/' + meter_name,
                                   self.loss_meters[meter_name].avg,
                                   global_step=global_step)

        def log_batch(tag, batch):
            # First six items of the batch, laid out three per row.
            grid = torchvision.utils.make_grid(batch[:6].clone().cpu().data,
                                               3,
                                               normalize=True)
            self.writer.add_image(tag, grid, global_step=global_step)

        def log_label_map(tag, labels):
            # Labels are colourised by util.color_label (default arguments).
            coloured = util.color_label(labels[:6])
            grid = torchvision.utils.make_grid(torch.from_numpy(coloured),
                                               3,
                                               normalize=True,
                                               range=(0, 255))
            self.writer.add_image(tag, grid, global_step=global_step)

        if phase == 'train':
            self.writer.add_scalar('LR',
                                   self.optimizer_ED.param_groups[0]['lr'],
                                   global_step=global_step)

            # Only the CLS loss is logged; accuracy / train-IoU scalars are
            # not written by this method.
            if 'CLS' in enabled:
                log_meter('TRAIN_CLS_LOSS')

            # Semantic losses and the depth/seg grids are only logged while
            # the semantic branch is active.
            if 'SEMANTIC' in enabled and self.using_semantic_branch:
                log_meter('TRAIN_SEMANTIC_LOSS_2DEPTH')
                log_meter('TRAIN_SEMANTIC_LOSS_2SEG')

                log_batch('Seg/Train_groundtruth_depth', self.target_depth)
                log_batch('Seg/Train_predicted_depth', self.gen[0].data)
                log_batch('Seg/Train_groundtruth_seg', self.target_seg)
                log_batch('Seg/Train_predicted_seg', self.gen[1].data)

            log_batch('Seg/Train_image', self.source_modal_show)

            if 'CLS' in enabled:
                log_label_map('Seg/Train_predicted_label',
                              self.Train_predicted_label)
                log_label_map('Seg/Train_ground_label', self.label_show)

        if phase == 'test':
            # Validation images, loss and accuracy are not logged here; only
            # the mean IoU (as a percentage) is written.
            mean_iou_percent = float(self.val_iou.mean()) * 100.0
            self.writer.add_scalar('Seg/VAL_CLS_MEAN_IOU',
                                   mean_iou_percent,
                                   global_step=global_step)
    def _write_loss(self, phase, global_step):
        """Publish scalars and image grids for one logging step.

        Args:
            phase: ``'train'`` logs the learning rate, the enabled loss meters
                and training image grids; ``'test'`` logs validation image
                grids, the validation CLS loss and the mean IoU. Other values
                log nothing.
            global_step: Step index forwarded to every writer call.
        """
        enabled = self.cfg.LOSS_TYPES

        # All three snapshots are stored on the instance.
        self.label_show = self.label.data.cpu().numpy()
        self.source_modal_show = self.source_modal
        self.target_modal_show = self.target_modal

        def log_meter(meter_name):
            # Tag is the meter name under the 'Seg/' prefix; value is its avg.
            self.writer.add_scalar('Seg/' + meter_name,
                                   self.loss_meters[meter_name].avg,
                                   global_step=global_step)

        def log_batch(tag, batch):
            # First six items of the batch, laid out three per row.
            grid = torchvision.utils.make_grid(batch[:6].clone().cpu().data,
                                               3,
                                               normalize=True)
            self.writer.add_image(tag, grid, global_step=global_step)

        def log_label_map(tag, labels):
            # Labels are colourised by util.color_label before being gridded.
            coloured = util.color_label(labels[:6],
                                        ignore=self.cfg.IGNORE_LABEL)
            grid = torchvision.utils.make_grid(torch.from_numpy(coloured),
                                               3,
                                               normalize=True,
                                               range=(0, 255))
            self.writer.add_image(tag, grid, global_step=global_step)

        if phase == 'train':
            self.writer.add_scalar('Seg/LR',
                                   self.optimizer_ED.param_groups[0]['lr'],
                                   global_step=global_step)

            if 'CLS' in enabled:
                log_meter('TRAIN_CLS_LOSS')

            if self.trans:
                if 'SEMANTIC' in enabled:
                    log_meter('TRAIN_SEMANTIC_LOSS')
                if 'PIX2PIX' in enabled:
                    log_meter('TRAIN_PIX2PIX_LOSS')

                if isinstance(self.target_modal, list):
                    # One generated/target pair per list entry; the tag suffix
                    # is FINE_SIZE divided by 2**i (true division).
                    pairs = zip(self.gen, self.target_modal)
                    for level, (generated, wanted) in enumerate(pairs):
                        log_batch(
                            'Seg/2_Train_Gen_' +
                            str(self.cfg.FINE_SIZE / pow(2, level)),
                            generated)
                        log_batch(
                            'Seg/3_Train_Target_' +
                            str(self.cfg.FINE_SIZE / pow(2, level)),
                            wanted)
                else:
                    log_batch('Seg/Train_target', self.target_modal_show)
                    log_batch('Seg/Train_gen', self.gen.data)

            log_batch('Seg/Train_image', self.source_modal_show)

            if 'CLS' in enabled:
                log_label_map('Seg/Train_predicted',
                              self.Train_predicted_label)
                log_label_map('Seg/Train_label', self.label_show)

        if phase == 'test':
            log_batch('Seg/Val_image', self.source_modal_show)
            log_label_map('Seg/Val_predicted', self.val_predicted_label)
            log_label_map('Seg/Val_label', self.label_show)

            log_meter('VAL_CLS_LOSS')
            # Mean IoU is reported as a percentage.
            mean_iou_percent = float(self.val_iou.mean()) * 100.0
            self.writer.add_scalar('Seg/VAL_CLS_MEAN_IOU',
                                   mean_iou_percent,
                                   global_step=global_step)
    def write_loss(self, phase, global_step=1):
        """Publish scalars and image grids for one logging step.

        All tags are prefixed with ``cfg.TASK_TYPE``.

        Changes relative to the previous implementation:
          * ``<task>/Train_image`` used to be written twice for the same step
            when ``self.trans and not cfg.MULTI_MODAL``; it is now written
            once.
          * The label array and ``self.target_modal`` are only fetched in the
            branches that use them, so the test phase no longer converts the
            label batch for nothing.

        Args:
            phase: ``'train'`` logs the learning rate, the enabled loss meters
                and training image grids; ``'test'`` logs the validation image
                grid and the validation accuracy / IoU meters. Other values
                log nothing.
            global_step: Step index forwarded to every writer call.
        """
        loss_types = self.cfg.LOSS_TYPES
        task = self.cfg.TASK_TYPE
        source_modal_show = self.source_modal

        def uses_compl_data():
            # Evaluated on demand so cfg is only consulted where it was before.
            return 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA

        def log_meter_avg(meter_name):
            self.writer.add_scalar(task + '/' + meter_name,
                                   self.loss_meters[meter_name].avg,
                                   global_step=global_step)

        def log_meter_percent(meter_name):
            # Validation meters report their latest ``val`` as a percentage.
            self.writer.add_scalar(task + '/' + meter_name,
                                   self.loss_meters[meter_name].val * 100.0,
                                   global_step=global_step)

        def log_batch(tag, batch):
            # First six items of the batch, laid out three per row.
            grid = torchvision.utils.make_grid(batch[:6].clone().cpu().data,
                                               3,
                                               normalize=True)
            self.writer.add_image(tag, grid, global_step=global_step)

        def log_label_map(tag, labels):
            # Labels are colourised by util.color_label before being gridded.
            coloured = util.color_label(labels[:6],
                                        ignore=self.cfg.IGNORE_LABEL,
                                        dataset=self.cfg.DATASET)
            grid = torchvision.utils.make_grid(torch.from_numpy(coloured),
                                               3,
                                               normalize=True,
                                               range=(0, 255))
            self.writer.add_image(tag, grid, global_step=global_step)

        if phase == 'train':
            log_batch(task + '/Train_image', source_modal_show)
            self.writer.add_scalar(task + '/LR',
                                   self.optimizer.param_groups[0]['lr'],
                                   global_step=global_step)

            if 'CLS' in loss_types:
                log_meter_avg('TRAIN_CLS_LOSS')
                if uses_compl_data():
                    log_meter_avg('TRAIN_CLS_LOSS_COMPL')
                    log_batch(task + '/Compl_image',
                              self.result['compl_source'])

            if self.trans and not self.cfg.MULTI_MODAL:
                if 'SEMANTIC' in loss_types:
                    log_meter_avg('TRAIN_SEMANTIC_LOSS')
                if 'PIX2PIX' in loss_types:
                    log_meter_avg('TRAIN_PIX2PIX_LOSS')

                log_batch(task + '/Train_gen', self.gen.data)
                # '<task>/Train_image' was already written at the top of this
                # branch; it is intentionally not logged a second time.
                log_batch(task + '/Train_target', self.target_modal)

            if 'CLS' in loss_types and task == 'segmentation':
                train_pred = self.result['cls'].data.max(1)[1].cpu().numpy()
                log_label_map(task + '/Train_predicted', train_pred)

                # NOTE(review): the conversion is keyed on ``self.phase`` while
                # the branch is keyed on the ``phase`` argument; kept as-is --
                # confirm whether the two can ever differ.
                if self.phase == 'train':
                    label_show = self.label.data.cpu().numpy()
                else:
                    label_show = np.uint8(self.label.data.cpu())
                log_label_map(task + '/Train_label', label_show)

        elif phase == 'test':
            log_batch(task + '/Val_image', source_modal_show)

            if uses_compl_data():
                log_batch(task + '/Compl_image', self.result['compl_source'])

            log_meter_percent('VAL_CLS_ACC')
            log_meter_percent('VAL_CLS_MEAN_ACC')
            if task == 'segmentation':
                log_meter_percent('VAL_CLS_MEAN_IOU')