示例#1
0
    def _compute_metrics(self, outputs, target_l, target_ul, epoch):
        """Score the supervised and (in semi mode) unsupervised predictions.

        For each branch this prints a short banner, dumps the prediction and
        target tensors as ``.npy`` files under ``self.save_dir``, feeds the
        batch statistics from ``eval_metrics`` into ``_update_seg_metrics`` and
        stores the pixel accuracy, mIoU and per-class IoU returned by
        ``_get_seg_metrics`` on ``self`` (``_l`` / ``_ul`` suffixes).

        Args:
            outputs: mapping with ``'sup_pred'`` and, in semi mode,
                ``'unsup_pred'``.
            target_l: targets for the supervised branch.
            target_ul: targets for the unsupervised branch (read only when
                ``self.mode == 'semi'``).
            epoch: unused; kept for interface compatibility.
        """
        def score(banner, prediction, target, pred_file, target_file, labeled):
            # Dump the raw tensors first, then accumulate the metrics.
            print(banner)
            pred_array = prediction.cpu().data.numpy()
            target_array = target.cpu().data.numpy()
            np.save(self.save_dir + pred_file, pred_array)
            np.save(self.save_dir + target_file, target_array)
            batch_stats = eval_metrics(prediction, target,
                                       self.num_classes, self.ignore_index)
            self._update_seg_metrics(*batch_stats, labeled)
            return self._get_seg_metrics(labeled).values()

        (self.pixel_acc_l, self.mIoU_l,
         self.class_iou_l) = score('sup', outputs['sup_pred'], target_l,
                                   'test_sup_p.npy', 'test_sup_t.npy', True)

        if self.mode != 'semi':
            return
        (self.pixel_acc_ul, self.mIoU_ul,
         self.class_iou_ul) = score('unsup', outputs['unsup_pred'], target_ul,
                                    'test_unsup_p.npy', 'test_unsup_t.npy',
                                    False)
    def _compute_metrics(self, outputs, target_l, target_ul, epoch):
        """Accumulate segmentation metrics for one batch.

        The labelled prediction is always scored; the unlabelled one only when
        ``self.mode == 'semi'``. For each branch the statistics from
        ``eval_metrics`` are passed, with the batch size, to
        ``_update_seg_metrics``. The five values returned by
        ``_get_seg_metrics`` (pixel accuracy, mIoU, per-class IoU, mean Dice,
        per-class Dice) are then stored on ``self`` with an ``_l`` / ``_ul``
        suffix. ``epoch`` is unused.
        """
        def collect(prediction, target, labeled):
            stats = eval_metrics(prediction, target, self.num_classes,
                                 self.str_device)
            self._update_seg_metrics(*stats, target.size(0), labeled)
            return self._get_seg_metrics(labeled).values()

        (self.pixel_acc_l, self.mIoU_l, self.class_iou_l,
         self.mdice_l, self.class_dice_l) = collect(outputs['sup_pred'],
                                                    target_l, True)

        if self.mode == 'semi':
            (self.pixel_acc_ul, self.mIoU_ul, self.class_iou_ul,
             self.mdice_ul, self.class_dice_ul) = collect(outputs['unsup_pred'],
                                                          target_ul, False)
示例#3
0
    def forward(self, x, labels=None, return_output=False):
        """Encoder/decoder forward pass with skip connections.

        ``x`` goes through ``enc1``, ``enc2`` and every module of
        ``enc_blocks``, whose outputs are kept as skips. Each decoder block
        except the last is followed by a concatenation (dim 1) of the matching
        skip, taken from the end of the list starting at the second-to-last,
        with the decoder output cropped to the skip's size in dims 2 and 3.
        ``dec1``, ``dec2``, ``dec_upsample`` and ``final`` then produce the
        logits (presumably (N, C, H, W) -- confirm against the model).

        Returns:
            The logits when the module has no ``loss`` attribute; otherwise
            ``(loss, seg_metrics)``, or ``(loss, logits, seg_metrics)`` when
            ``return_output`` is true. ``seg_metrics`` comes from
            ``eval_metrics`` using ``logits.shape[1]`` as the class count.
        """
        h = self.enc2(self.enc1(x))

        skips = []
        for block in self.enc_blocks:
            h = block(h)
            skips.append(h)

        last = len(self.dec_blocks) - 1
        for idx, block in enumerate(self.dec_blocks):
            h = block(h)
            if idx == last:
                continue
            skip = skips[-(idx + 2)]
            cropped = h[:, :, :skip.shape[2], :skip.shape[3]]
            h = torch.cat([skip, cropped], dim=1)

        for layer in (self.dec1, self.dec2, self.dec_upsample, self.final):
            h = layer(h)

        if not hasattr(self, 'loss'):
            return h

        loss = self.loss(h, labels)
        seg_metrics = eval_metrics(h, labels, h.shape[1])
        return (loss, h, seg_metrics) if return_output else (loss, seg_metrics)
示例#4
0
 def _compute_metrics(self, outputs, target_l, epoch):
     """Accumulate metrics for the supervised prediction of one batch.

     The statistics from ``eval_metrics`` are passed to
     ``_update_seg_metrics(..., True)``. The pixel accuracy, mIoU and
     per-class IoU returned by ``_get_seg_metrics(True)`` are stored on
     ``self``. ``epoch`` is unused.
     """
     batch_stats = eval_metrics(outputs['sup_pred'], target_l,
                                self.num_classes, self.ignore_index)
     self._update_seg_metrics(*batch_stats, True)
     summary = self._get_seg_metrics(True)
     self.pixel_acc_l, self.mIoU_l, self.class_iou_l = summary.values()
     '''
    # NOTE(review): in this view the stray `'''` just above this definition
    # opens a string literal, so this multi-GPU variant of _compute_metrics
    # looks like disabled (commented-out) code -- confirm against the full file.
    def _compute_metrics(self, outputs, target_l, target_ul, epoch):
        """Compute labelled (and, if present, unlabelled) segmentation metrics.

        ``eval_metrics`` runs on every process. The results are accumulated
        and stored on ``self`` only when ``self.gpu == 0``. The unlabelled
        branch runs only when ``outputs`` contains ``'unsup_pred'``.
        ``epoch`` is unused.
        """
        seg_metrics_l = eval_metrics(outputs['sup_pred'], target_l,
                                     self.num_classes, self.ignore_index)

        # Only GPU 0 accumulates and stores the metrics.
        if self.gpu == 0:
            self._update_seg_metrics(*seg_metrics_l, True)
            seg_metrics_l = self._get_seg_metrics(True)
            self.pixel_acc_l, self.mIoU_l, self.class_iou_l = seg_metrics_l.values(
            )

        # Unlabelled predictions are optional (keyed presence, not self.mode).
        if 'unsup_pred' in outputs:
            seg_metrics_ul = eval_metrics(outputs['unsup_pred'], target_ul,
                                          self.num_classes, self.ignore_index)

            if self.gpu == 0:
                self._update_seg_metrics(*seg_metrics_ul, False)
                seg_metrics_ul = self._get_seg_metrics(False)
                self.pixel_acc_ul, self.mIoU_ul, self.class_iou_ul = seg_metrics_ul.values(
                )
示例#6
0
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and show a summary in the progress bar.

        Updates ``self.total_loss`` and the segmentation meters. Logs an
        ``'eval_log'`` string with the epoch, average loss, pixel accuracy and
        mIoU (progress bar only, not the logger). ``batch_idx`` is unused.

        Returns:
            The loss tensor of this batch.
        """
        inputs, labels = batch
        logits = self.model(inputs)
        step_loss = self.loss(logits, labels)
        self.total_loss.update(step_loss.item())

        self._update_seg_metrics(*eval_metrics(logits, labels, self.num_classes))
        pix_acc, mean_iou, _ = self._get_seg_metrics().values()

        message = 'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'.format(
            self.current_epoch, self.total_loss.average, pix_acc, mean_iou)
        self.log('eval_log', message, prog_bar=True, logger=False)
        return step_loss
示例#7
0
    def training_step(self, batch, batch_idx):
        """Run one training batch and show a summary in the progress bar.

        Updates ``self.total_loss`` and the segmentation meters. Logs a
        ``'train_log'`` string with the epoch, average loss, pixel accuracy and
        mIoU (progress bar only, not the logger). ``batch_idx`` is unused.

        Returns:
            The loss tensor of this batch (for the framework to back-propagate).
        """
        inputs, labels = batch
        logits = self.model(inputs)
        step_loss = self.loss(logits, labels)
        self.total_loss.update(step_loss.item())

        self._update_seg_metrics(*eval_metrics(logits, labels, self.num_classes))
        pix_acc, mean_iou, _ = self._get_seg_metrics().values()

        summary = 'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} |'.format(
            self.current_epoch, self.total_loss.average, pix_acc, mean_iou)
        self.log('train_log', summary, prog_bar=True, logger=False)
        return step_loss
def _load_test_model(config, num_classes, checkpoint_path, map_arg):
    """Build the CCT model in testing mode and load weights from ``checkpoint_path``.

    ``map_arg`` is the script's ``--map`` value. ``'gpu'`` is translated to
    ``'cuda'`` for ``torch.load``; any other value is passed through unchanged.
    Returns the model cast to float and in eval mode; the caller is
    responsible for moving it to CUDA.
    """
    config['model']['supervised'] = True
    config['model']['semi'] = False
    encoder = models.model.Encoder(True)
    model = models.model.CCT(encoder, num_classes=num_classes, conf=config['model'], testing=True)

    # 'gpu' is this script's own flag, not a torch device string; torch.load
    # only recognises 'cpu' / 'cuda[:N]' style locations.
    map_location = 'cuda' if map_arg == 'gpu' else map_arg
    checkpoint = torch.load(checkpoint_path, map_location)
    state_dict = checkpoint['state_dict']

    # The test model is never wrapped in DataParallel, so the 'module.' prefix
    # of DataParallel checkpoints must be removed for every target device.
    for key in list(state_dict.keys()):
        if 'module.' in key:
            state_dict[key.replace('module.', '')] = state_dict.pop(key)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        # load_state_dict reports missing/unexpected keys as RuntimeError.
        print(f'Some modules are missing: {e}')
        model.load_state_dict(state_dict, strict=False)
    model.float()
    model.eval()
    return model


def _save_test_outputs(out_dir, image_ids, predictions, labels, image, site, overlay):
    """Write prediction/label PNGs for one batch, plus contour overlays if requested.

    ``predictions`` and ``labels`` are the ``batch_scale``-d arrays. ``image``
    is the raw input batch tensor and may live on the GPU.
    """
    if overlay:
        # Extract contours before any file for this batch is written.
        prediction_contours = batch_contour(predictions)
        label_contours = batch_contour(labels)

    for i in range(predictions.shape[0]):
        PIL.Image.fromarray(predictions[i]).save(f'{out_dir}/{image_ids[i]}_prediction.png')
        PIL.Image.fromarray(labels[i]).save(f'{out_dir}/{image_ids[i]}_label.png')

    if not overlay:
        return

    # .cpu() is required before .numpy() when the batch was moved to CUDA.
    gray = batch_scale(np.squeeze(image.cpu().numpy(), axis=1))
    palette = contour_palette(site)
    color = (palette[0], palette[1], palette[2])
    for i in range(gray.shape[0]):
        image_gt = cv.cvtColor(gray[i].copy(), cv.COLOR_GRAY2RGB)
        image_pred = cv.cvtColor(gray[i].copy(), cv.COLOR_GRAY2RGB)
        cv.drawContours(image_gt, label_contours[i], -1, color, 1)
        cv.drawContours(image_pred, prediction_contours[i], -1, color, 1)
        cv.imwrite(f'{out_dir}/{image_ids[i]}_label_overlay.png', image_gt)
        cv.imwrite(f'{out_dir}/{image_ids[i]}_prediction_overlay.png', image_pred)


def main():
    """Run the trained CCT model over a site's test set and save the results.

    1. Parse CLI args and load the JSON config (``args.config`` is required).
    2. Build the test loader for ``args.site`` and the model, loading the
       checkpoint at ``args.model``.
    3. For every test image, accumulate pixel accuracy / IoU / Dice via
       ``eval_metrics`` and write prediction and label PNGs (plus contour
       overlays when ``args.overlay``) under ``outputs/<site>/<experiment>/``.
    4. Write the final metrics to ``outputs/<site>/<experiment>/test.txt``.

    ``args.map == 'gpu'`` runs the model on CUDA.
    """
    # get the argument from parser
    args = parse_arguments()

    # CONFIG -> assert if config is here
    assert args.config
    with open(args.config) as config_file:
        config = json.load(config_file)

    # DATA
    testdataset = base.testDataset(args.site)
    loader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=0)
    num_classes = testdataset.num_classes

    # MODEL
    use_gpu = args.map == 'gpu'
    model = _load_test_model(config, num_classes, args.model, args.map)
    if use_gpu:
        model.cuda()

    check_directory(args.site, args.experiment)
    out_dir = f'outputs/{args.site}/{args.experiment}'

    # LOOP OVER THE DATA
    tbar = tqdm(loader, ncols=100)

    # No loss is computed at test time; this meter only feeds the progress bar.
    total_loss_val = AverageMeter()
    total_inter, total_union = 0, 0
    total_correct, total_label = 0, 0
    total_dice = 0
    count = 0
    # Pre-initialised so an empty test set writes an empty test.txt instead
    # of raising NameError below.
    seg_metrics = {}

    for image, label, image_id in tbar:
        if use_gpu:
            image = image.cuda()

        # PREDICT
        with torch.no_grad():
            output = model(image)
            correct, labeled, inter, union, dice = eval_metrics(output, label, num_classes, args.map)
            total_inter, total_union = total_inter + inter, total_union + union
            total_correct, total_label = total_correct + correct, total_label + labeled
            # Sample-weighted running mean of the per-class Dice scores.
            total_dice = ((count * total_dice) + (dice * output.size(0))) / (count + output.size(0))
            count += output.size(0)
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            mdice = total_dice.mean()
            seg_metrics = {"Pixel_Accuracy": np.round(pixAcc, 3), "Mean_IoU": np.round(mIoU, 3),
                           "Mean_dice": np.round(mdice, 3),
                           "Class_IoU": dict(zip(range(num_classes), np.round(IoU, 3))),
                           "Class_dice": dict(zip(range(num_classes), np.round(total_dice, 3)))}
            tbar.set_description('EVAL | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} Mean Dice {:.2f} |'.format(
                total_loss_val.average, pixAcc, mIoU, mdice))
            output = torch.argmax(output, dim=1)

        # .cpu() is a no-op on CPU but required before .numpy() on CUDA.
        predictions = batch_scale(output.cpu().numpy())
        labels = batch_scale(label.numpy())

        # SAVE RESULTS
        _save_test_outputs(out_dir, image_id, predictions, labels, image,
                           testdataset.site, args.overlay)

    with open(f'{out_dir}/test.txt', 'w') as f:
        for k, v in seg_metrics.items():
            f.write(f'{k}:{v}\n')
示例#9
0
    def _valid_epoch(self, epoch):
        """Run one validation pass; on rank 0 also log metrics and sample images.

        Every rank moves the batch to ``self.rank`` and computes the model
        output and loss under CUDA autocast. Only rank 0 does the following:
        accumulates the loss and the confusion statistics over masked pixels,
        builds an image grid (optionally shown with matplotlib), writes
        TensorBoard scalars and builds the log dict.

        Args:
            epoch: current epoch, used as the TensorBoard step.

        Returns:
            On rank 0: ``{'val_loss': ..., **metrics}`` with ``metrics`` from
            ``get_metrics_full``. On other ranks: an empty dict.
        """
        # Defined for every rank so non-zero ranks return an empty log instead
        # of raising UnboundLocalError at the final ``return``.
        log = {}
        if self.rank == 0:
            logger.info('\n###### EVALUATION ######')
            wrt_mode = 'val'
            self._reset_metrics()
            val_img = []
            # Masked per-pixel labels, probabilities and binary predictions,
            # collected over *all* batches so get_metrics_full sees the whole
            # validation set rather than only the last batch.
            y_true_all = []
            y_score_all = []
            y_score_b_all = []
        self.model.eval()
        tbar = tqdm(self.val_loader, ncols=160)
        with torch.no_grad():

            for batch_idx, (img, gt, Sgt, Lgt, mask) in enumerate(tbar):
                img = img.to(self.rank, non_blocking=True)
                gt = gt.to(self.rank, non_blocking=True)
                mask = mask.to(self.rank, non_blocking=True)
                Sgt = Sgt.to(self.rank, non_blocking=True)
                Lgt = Lgt.to(self.rank, non_blocking=True)

                # LOSS -- the model returns 1, 2 or 3 outputs depending on
                # self.gt_num, and the loss takes the matching targets.
                with torch.cuda.amp.autocast(enabled=True):
                    if self.gt_num == 1:
                        predict = self.model(img)
                        loss = self.loss(predict, gt)
                    elif self.gt_num == 2:
                        s, predict = self.model(img)
                        loss = self.loss(predict, gt, s, Sgt)
                    else:
                        l, s, predict = self.model(img)
                        loss = self.loss(predict, gt, s, Sgt, l, Lgt)
                if self.rank == 0:
                    self.total_loss.update(loss.item())
                    predict = torch.sigmoid(predict).cpu().detach().numpy()
                    predict_b = np.where(predict >= 0.5, 1, 0)
                    # Only pixels inside the mask (mask == 1) are scored.
                    valid = mask.cpu().detach().numpy().ravel() == 1
                    y_true = gt.cpu().detach().numpy().ravel()[valid]
                    y_score = predict.ravel()[valid]
                    y_score_b = predict_b.ravel()[valid]
                    y_true_all.append(y_true)
                    y_score_all.append(y_score)
                    y_score_b_all.append(y_score_b)
                    # FOR EVAL and INFO
                    self._update_seg_metrics(*eval_metrics(y_true, y_score_b))
                    metrics = get_metrics(self.tn, self.fp, self.fn, self.tp)
                    tbar.set_description(
                        'EVAL ({}) | Loss: {:.4f} | Acc {:.4f} Pre {:.4f} Sen {:.4f} Spe {:.4f} f1 {:.4f} IOU {:.4f} |'.format(
                            epoch, self.total_loss.average, *metrics.values()))

                    # Keep (input, ground truth, binary prediction) of the
                    # first sample of each of the first 10 batches
                    # -> up to 30 grid tiles.
                    if batch_idx < 10:
                        val_img.extend([img[0].data.cpu(), gt[0].data.cpu(), torch.tensor(predict_b[0])])
            if self.rank == 0:
                val_img = torch.stack(val_img, 0)
                val_img = make_grid(val_img, nrow=3, padding=2)
                if self.show is True:
                    plt.figure(figsize=(12, 36))
                    plt.imshow(transforms.ToPILImage()(val_img.squeeze(0)).convert('L'), cmap='gray')
                    plt.show()

                # LOGGING & TENSORBOARD
                wrt_step = epoch
                metrics = get_metrics_full(self.tn, self.fp, self.fn, self.tp,
                                           np.concatenate(y_true_all),
                                           np.concatenate(y_score_all),
                                           np.concatenate(y_score_b_all))
                self.writer.add_image(f'{wrt_mode}/inputs_targets_predictions', val_img, wrt_step)
                self.writer.add_scalar(f'{wrt_mode}/loss', self.total_loss.average, wrt_step)
                for k, v in list(metrics.items())[:-1]:
                    self.writer.add_scalar(f'{wrt_mode}/{k}', v, wrt_step)
                log = {
                    'val_loss': self.total_loss.average,
                    **metrics
                }
        return log
示例#10
0
    def eval(self):
        """The function for the meta-evaluate (test) phase.

        Builds meta-test episodes from the ``'test'`` split. Each episode has
        ``args.way + 1`` classes (background included), ``args.train_query``
        support samples and ``args.test_query`` query samples per class.
        Weights are restored from ``args.eval_weights`` or, if that is None,
        ``<save_path>/max_iou.pth``. The per-episode mIoU is averaged and
        printed.

        For every query image the input (``<n>a.jpg``), the ground truth
        (``<n>b.png``) and the prediction (``<n>c.png``) are written with the
        prefix ``args.save_image_dir``. ``n`` is a global counter starting at 1.
        """
        # The training log is read (so the file must exist) but not used here.
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Meta-test data: episodes drawn by CategoriesSampler.
        self.test_set = Dataset('test', self.args)
        self.sampler = CategoriesSampler(
            self.test_set.labeln, self.args.num_batch, self.args.way + 1,
            self.args.train_query, self.args.test_query)
        self.loader = DataLoader(dataset=self.test_set, batch_sampler=self.sampler,
                                 num_workers=8, pin_memory=True)

        weights_file = self.args.eval_weights
        if weights_file is None:
            weights_file = osp.join(self.args.save_path, 'max_iou' + '.pth')
        self.model.load_state_dict(torch.load(weights_file)['params'])
        self.model.eval()

        miou_avg = Averager()
        n_cls = self.args.way + 1
        n_support = n_cls * self.args.train_query
        n_query_total = n_cls * self.args.test_query
        on_gpu = torch.cuda.is_available()

        image_no = 1
        for batch in self.loader:
            if on_gpu:
                data, labels, _ = [t.cuda() for t in batch]
            else:
                data, labels = batch[0], batch[1]

            # Leading n_support samples are the support set, the rest queries.
            im_train, im_test = data[:n_support], data[n_support:]
            labels = downlabel(labels, n_cls)  # per-episode label remapping
            out_train, out_test = labels[:n_support], labels[n_support:]
            if on_gpu:
                im_train, im_test = im_train.cuda(), im_test.cuda()
                out_train, out_test = out_train.cuda(), out_test.cuda()

            Ytr = onehot(out_train.reshape(-1), n_cls)
            Yte = out_test.reshape(out_test.shape[0], -1)
            if on_gpu:
                Ytr, Yte = Ytr.cuda(), Yte.cuda()

            GteT = torch.transpose(self.model(im_train, Ytr, im_test, Yte), 1, 2)

            # Episode mIoU (meters are reset per episode).
            self._reset_metrics()
            self._update_seg_metrics(*eval_metrics(GteT, Yte, n_cls))
            _, episode_miou, _ = self._get_seg_metrics(n_cls).values()
            miou_avg.add(episode_miou)

            for j in range(n_query_total):
                query_img = im_test[j].detach().cpu()
                query_gt = out_test[j].detach().cpu()
                pred_flat = torch.argmax(GteT[j].detach().cpu(), axis=0)
                # Predictions are flattened; the map is assumed to be square.
                side = int(math.sqrt(pred_flat.shape[0]))
                pred_map = pred_flat.reshape(side, side)

                rgb = transforms.ToPILImage()(query_img).convert("RGB")
                gt_img = Image.fromarray(decode_segmap(query_gt, n_cls))
                pred_img = Image.fromarray(decode_segmap(pred_map, n_cls))

                stem = self.args.save_image_dir + str(image_no)
                rgb.save(stem + 'a.jpg')
                gt_img.save(stem + 'b.png')
                pred_img.save(stem + 'c.png')
                image_no += 1

        mean_miou = miou_avg.item()
        print("=============================================================")
        print('Average Test mIoU: {:.4f}'.format(mean_miou))
        print("Images Saved!")
        print("=============================================================")
示例#11
0
    def train(self):
        """The function for the meta-train phase.

        For each epoch from ``initial_epoch`` to ``args.max_epoch``:
          * steps ``self.lr_scheduler`` once;
          * runs every episode of ``self.train_loader`` (the first ``K*N``
            samples are the support set, the rest the queries), optimising
            ``self.FL`` on the query predictions;
          * logs loss / pixel accuracy / mIoU to tqdm and TensorBoard;
          * if ``args.valdata == 'No'``, keeps the best model by training mIoU
            as ``max_iou``;
          * if ``args.valdata == 'Yes'``, runs the validation episodes (no
            optimiser steps) and keeps the best model by validation mIoU;
          * in either case, saves ``epoch<N>`` every 2 epochs;
          * saves the ``trlog`` dict to ``<save_path>/trlog``.
        """
        # Resume state ("Change when resuming training"): the first epoch and
        # the best mIoU / epoch so far are hard-coded to continue an earlier
        # run. For a fresh run these would presumably be 1, 0.0 and 0.
        initial_epoch = 25

        # Meta-train and meta-val logs.
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'train_acc': [],
            'train_iou': [],
            'val_loss': [],
            'val_acc': [],
            'val_iou': [],
            'max_iou': 0.2856,
            'max_iou_epoch': 4,
        }

        timer = Timer()
        writer = SummaryWriter(comment=self.args.save_path)

        K = self.args.way + 1  # classes per episode, background included
        N = self.args.train_query  # support samples per class
        p = K * N  # the leading p samples of each batch form the support set

        def run_episode(batch):
            """Split a sampled batch into support/query and run the model.

            Returns ``(GteT, Yte)``: the model output transposed on dims 1 and
            2 and the query labels flattened per sample.
            """
            if torch.cuda.is_available():
                data, labels, _ = [t.cuda() for t in batch]
            else:
                data = batch[0]
                labels = batch[1]
            im_train, im_test = data[:p], data[p:]

            # Adjust labels for each meta task.
            labels = downlabel(labels, K)
            out_train, out_test = labels[:p], labels[p:]
            if torch.cuda.is_available():
                im_train = im_train.cuda()
                im_test = im_test.cuda()
                out_train = out_train.cuda()
                out_test = out_test.cuda()

            Ytr = onehot(out_train.reshape(-1), K)  # one-hot support labels
            Yte = out_test.reshape(out_test.shape[0], -1)
            if torch.cuda.is_available():
                Ytr = Ytr.cuda()
                Yte = Yte.cuda()

            Gte = self.model(im_train, Ytr, im_test, Yte)
            return torch.transpose(Gte, 1, 2), Yte

        def episode_metrics(GteT, Yte):
            """Return (pixel accuracy, mIoU) of one episode (meters reset first)."""
            self._reset_metrics()
            self._update_seg_metrics(*eval_metrics(GteT, Yte, K))
            pixAcc, mIoU, _ = self._get_seg_metrics(K).values()
            return pixAcc, mIoU

        try:
            for epoch in range(initial_epoch, self.args.max_epoch + 1):
                print(
                    '----------------------------------------------------------------------------------------------------------------------------------------------------------'
                )

                # NOTE(review): the scheduler is stepped before this epoch's
                # optimizer steps (pre-1.1 PyTorch ordering). It is kept as-is
                # because moving it would shift the LR schedule by one epoch.
                self.lr_scheduler.step()

                # ---------------- meta-train ----------------
                self.model.train()
                train_loss_averager = Averager()
                train_acc_averager = Averager()
                train_iou_averager = Averager()

                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    GteT, Yte = run_episode(batch)

                    #loss = self.CD(GteT,Yte)
                    loss = self.FL(GteT, Yte)
                    #loss = self.LS(GteT,Yte)

                    pixAcc, mIoU = episode_metrics(GteT, Yte)

                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                            epoch, loss.item(), pixAcc * 100.0, mIoU))

                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(pixAcc)
                    train_iou_averager.add(mIoU)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                train_loss = train_loss_averager.item()
                train_acc = train_acc_averager.item()
                train_iou = train_iou_averager.item()

                writer.add_scalar('data/train_loss (Meta)',
                                  float(train_loss), epoch)
                writer.add_scalar('data/train_acc (Meta)',
                                  float(train_acc) * 100.0, epoch)
                writer.add_scalar('data/train_iou (Meta)',
                                  float(train_iou), epoch)

                # Without a validation set, keep the best model by training mIoU.
                if self.args.valdata == 'No':
                    if train_iou > trlog['max_iou']:
                        print("New Best!")
                        trlog['max_iou'] = train_iou
                        trlog['max_iou_epoch'] = epoch
                        self.save_model('max_iou')

                    # Save model every 2 epochs
                    if epoch % 2 == 0:
                        self.save_model('epoch' + str(epoch))

                trlog['train_loss'].append(train_loss)
                trlog['train_acc'].append(train_acc)
                trlog['train_iou'].append(train_iou)

                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
                print('Epoch:{}, Average Loss: {:.4f}, Average mIoU: {:.4f}'.
                      format(epoch, train_loss, train_iou))

                # ---------------- meta-val ----------------
                if self.args.valdata == 'Yes':
                    self.model.eval()
                    val_loss_averager = Averager()
                    val_acc_averager = Averager()
                    val_iou_averager = Averager()

                    tqdm_gen = tqdm.tqdm(self.val_loader)
                    for batch in tqdm_gen:
                        GteT, Yte = run_episode(batch)
                        # Loss of this validation episode; without it the last
                        # training batch's loss would be reported as val loss.
                        loss = self.FL(GteT, Yte)
                        pixAcc, mIoU = episode_metrics(GteT, Yte)

                        tqdm_gen.set_description(
                            'Epoch {}, Val Loss={:.4f} Val Acc={:.4f} Val IoU={:.4f}'
                            .format(epoch, loss.item(), pixAcc * 100.0, mIoU))

                        val_loss_averager.add(loss.item())
                        val_acc_averager.add(pixAcc)
                        val_iou_averager.add(mIoU)

                    val_loss = val_loss_averager.item()
                    val_acc = val_acc_averager.item()
                    val_iou = val_iou_averager.item()

                    writer.add_scalar('data/val_loss (Meta)',
                                      float(val_loss), epoch)
                    writer.add_scalar('data/val_acc (Meta)',
                                      float(val_acc) * 100.0, epoch)
                    writer.add_scalar('data/val_iou (Meta)',
                                      float(val_iou), epoch)

                    # Keep the best model by validation mIoU.
                    if val_iou > trlog['max_iou']:
                        print("New Best (Validation)")
                        trlog['max_iou'] = val_iou
                        trlog['max_iou_epoch'] = epoch
                        self.save_model('max_iou')

                    # Save model every 2 epochs
                    if epoch % 2 == 0:
                        self.save_model('epoch' + str(epoch))

                    trlog['val_loss'].append(val_loss)
                    trlog['val_acc'].append(val_acc)
                    trlog['val_iou'].append(val_iou)

                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))
                    print(
                        'Epoch:{}, Average Val Loss: {:.4f}, Average Val mIoU: {:.4f}'
                        .format(epoch, val_loss, val_iou))

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            print(
                '----------------------------------------------------------------------------------------------------------------------------------------------------------'
            )
        finally:
            # Flush/close TensorBoard even if training raises.
            writer.close()
示例#12
0
    def _train_epoch(self, epoch):
        """Train for one epoch and return the epoch's loss and segmentation metrics.

        Per batch:
          * zero both optimisers and run the model, which also returns a
            distance loss;
          * back-propagate ``5 * seg_loss + distance_loss``;
          * step both optimisers, then both LR schedulers with
            ``epoch=epoch - 1``;
          * update the loss, timing and segmentation meters.

        Every ``self.log_step`` batches the losses are written to TensorBoard.
        At the end, the segmentation metrics (all but the last entry) and every
        parameter group's learning rate are logged.

        Returns:
            dict: ``{'loss': average loss, **segmentation metrics}``.
        """
        self.logger.info('\n')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            net = (self.model.module
                   if isinstance(self.model, torch.nn.DataParallel)
                   else self.model)
            net.freeze_bn()
        self.wrt_mode = 'train'

        mark = time.time()
        self._reset_metrics()
        progress = tqdm(self.train_loader, ncols=130)
        optimizers = (self.kernel_optimizer, self.back_optimizer)
        schedulers = (self.kernel_lr_scheduler, self.back_lr_scheduler)

        for batch_idx, (data, target, distance) in enumerate(progress):
            self.data_time.update(time.time() - mark)
            # NOTE: inputs are not moved to a device here (the .to() call was
            # commented out in the original code).

            for opt in optimizers:
                opt.zero_grad()
            output, distance_loss = self.model(data, distance)
            assert output.size()[2:] == target.size()[1:]
            assert output.size()[1] == self.num_classes
            seg_loss = self.loss(output, target)
            loss = 5 * seg_loss + distance_loss
            loss.backward()
            for opt in optimizers:
                opt.step()
            for sched in schedulers:
                sched.step(epoch=epoch - 1)

            self.total_loss.update(loss.item())

            # Batch time covers data loading plus the optimisation step.
            self.batch_time.update(time.time() - mark)
            mark = time.time()

            if batch_idx % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(self.train_loader) + batch_idx
                self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(),
                                       self.wrt_step)
                self.writer.add_scalar(f'{self.wrt_mode}/seg_loss', seg_loss,
                                       self.wrt_step)
                self.writer.add_scalar(f'{self.wrt_mode}/dis_loss',
                                       distance_loss, self.wrt_step)

            self._update_seg_metrics(*eval_metrics(output, target, self.num_classes))
            pix_acc, mean_iou, _ = self._get_seg_metrics().values()

            progress.set_description(
                'TRAIN ({}) | Seg_Loss: {:.3f} | dis_Loss: {:.3f} |Acc {:.2f} mIoU {:.2f} | B {:.2f} D {:.2f} |'
                .format(epoch, 5 * seg_loss, distance_loss, pix_acc, mean_iou,
                        self.batch_time.average, self.data_time.average))

        seg_metrics = self._get_seg_metrics()
        for name, value in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{name}', value, self.wrt_step)

        for idx, group in enumerate(self.kernel_optimizer.param_groups):
            self.writer.add_scalar(f'{self.wrt_mode}/kernel_lr_{idx}',
                                   group['lr'], self.wrt_step)
        for idx, group in enumerate(self.back_optimizer.param_groups):
            self.writer.add_scalar(f'{self.wrt_mode}/back_lr_{idx}',
                                   group['lr'], self.wrt_step)

        return {'loss': self.total_loss.average, **seg_metrics}
示例#13
0
    def _valid_epoch(self, epoch):
        """Validate the model and log losses, metrics and visualisations.

        Per batch the model returns ``(output, distance_loss)`` and the
        combined loss is ``5 * seg_loss + distance_loss``. The first sample of
        each of the first 15 batches is collected into two TensorBoard image
        grids: RGB / label / prediction, and RGB / kernel / coarse estimation /
        learned SDF / target SDF.

        Args:
            epoch: current epoch number; scalars are logged at step
                ``epoch * len(self.val_loader)``.

        Returns:
            ``{'val_loss': <mean combined loss>, **seg_metrics}``, or ``{}``
            when no validation loader was configured.
        """
        if self.val_loader is None:
            self.logger.warning(
                'Not data loader was passed for the validation step, No validation is performed !'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        # Running sums of the loss components so the TensorBoard scalars are
        # epoch means rather than the values of whichever batch came last.
        seg_loss_sum = 0.0
        dis_loss_sum = 0.0
        num_batches = 0
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            kernel_visual = []
            for data, target, distance in tbar:
                output, distance_loss = self.model(data, distance)

                seg_loss = self.loss(output, target)
                loss = 5 * seg_loss + distance_loss
                self.total_loss.update(loss.item())
                seg_loss_sum += 5 * float(seg_loss)
                dis_loss_sum += float(distance_loss)
                num_batches += 1

                seg_metrics = eval_metrics(output, target, self.num_classes)
                self._update_seg_metrics(*seg_metrics)

                # LIST OF IMAGE TO VIZ (15 images)
                if len(val_visual) < 15:
                    # kernel_visualization is an extra forward pass, so only
                    # run it for batches that are actually visualised. The
                    # `.module` access assumes a DataParallel-wrapped model.
                    kernel, coarse_estimation, learned_sdf = (
                        self.model.module.kernel_visualization(data))

                    target_np = target.data.cpu().numpy()
                    kernel_np = kernel[:, 1, :, :].data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    learned_sdf_np = learned_sdf.data.cpu().numpy()
                    target_sdf_np = distance.data.cpu().numpy()
                    coarse_estimation_np = (
                        coarse_estimation[:, 1, :, :].data.cpu().numpy())

                    # RGB, segmentation label, predicted segmentation.
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])
                    # RGB, kernel, coarse prediction, learned SDF, target SDF.
                    kernel_visual.append([
                        data[0].data.cpu(), kernel_np[0],
                        coarse_estimation_np[0], learned_sdf_np[0],
                        target_sdf_np[0]
                    ])

                # PRINT INFO
                pixAcc, mIoU, _ = self._get_seg_metrics().values()
                tbar.set_description(
                    'EVAL ({}) | Seg_Loss: {:.3f}, Dis_Loss: {:.3f},PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, seg_loss, distance_loss, pixAcc, mIoU))

            # WRITING & VISUALIZING THE MASKS
            palette = self.train_loader.dataset.palette
            val_img = []
            for rgb, label, seg in val_visual:
                rgb = self.restore_transform(rgb)
                label = colorize_mask(label, palette)
                seg = colorize_mask(seg, palette)
                val_img.extend(
                    self.viz_transform(x.convert('RGB'))
                    for x in (rgb, label, seg))

            kernel_img = []
            for rgb, kernel, pro, learned_sdf, sdf in kernel_visual:
                rgb = self.restore_transform(rgb)
                kernel = self.kernel_transform(kernel)
                pro = self.kernel_transform(pro)
                learned_sdf, sdf = self.learned_transform(learned_sdf, sdf)
                learned_sdf, sdf = self.to_PIL(learned_sdf), self.to_PIL(sdf)
                kernel_img.extend(
                    self.viz_transform(x.convert('RGB'))
                    for x in (rgb, kernel, pro, learned_sdf, sdf))

            # Both lists are empty only if the loader yielded no batches;
            # torch.stack raises on an empty list, so skip the images then.
            if val_img:
                val_grid = make_grid(torch.stack(val_img, 0).cpu(),
                                     nrow=3, padding=5)
                kernel_grid = make_grid(torch.stack(kernel_img, 0).cpu(),
                                        nrow=5, padding=5)
                self.writer.add_image(
                    f'{self.wrt_mode}/inputs_targets_predictions', val_grid,
                    self.wrt_step)
                self.writer.add_image(f'{self.wrt_mode}/kernel_predictions',
                                      kernel_grid, self.wrt_step)

            # METRICS TO TENSORBOARD
            self.wrt_step = (epoch) * len(self.val_loader)
            denom = max(num_batches, 1)
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   self.total_loss.average, self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/seg_loss',
                                   seg_loss_sum / denom, self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/dis_loss',
                                   dis_loss_sum / denom, self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': self.total_loss.average, **seg_metrics}

        return log
示例#14
0
def main(args, config, resume):
    """Evaluate a trained checkpoint under four adversarial-attack settings.

    Loads the checkpoint at ``args.binname`` into the model described by
    ``config``. For every validation batch it builds adversarial examples
    with ``get_adv_examples_S`` at four settings, then scores them. It
    appends the running (cumulative) pixel accuracy and mIoU for each
    setting to ``./results.txt``.

    Args:
        args: parsed CLI arguments; only ``args.binname`` (checkpoint path)
            is read here.
        config: experiment configuration dict (loaders, ``arch``, ``loss``,
            ``ignore_index``).
        resume: unused; kept for interface compatibility.
    """
    # DATA LOADERS
    global myloss, bordermark
    train_loader = get_instance(dataloaders, 'train_loader', config)
    val_loader = get_instance(dataloaders, 'val_loader', config)
    num_classes = train_loader.dataset.num_classes
    palette = train_loader.dataset.palette

    model = getattr(models, config['arch']['type'])(num_classes,
                                                    **config['arch']['args'])

    availble_gpus = list(range(torch.cuda.device_count()))
    device = torch.device('cuda:0' if len(availble_gpus) > 0 else 'cpu')
    checkpoint = torch.load(args.binname, map_location=torch.device('cpu'))
    checkpoint = checkpoint["state_dict"]

    model.set_mean_std(train_loader.MEAN, train_loader.STD)
    model = torch.nn.DataParallel(model)
    model.to(device)
    model.load_state_dict(checkpoint)
    model.eval()

    myloss = getattr(losses,
                     config['loss'])(ignore_index=config['ignore_index'])

    # Positional arguments passed to get_adv_examples_S after
    # (data, targ, model, myloss); their meaning is defined by that helper.
    attack_settings = [(25, 10, 15), (50, 10, 15), (100, 10, 15),
                       (150, 10, 20)]
    # One cumulative metrics tracker per attack setting.
    attack_meters = [METRICS(num_classes) for _ in attack_settings]

    for cnt, (data, targ) in enumerate(val_loader, 1):
        torch.cuda.empty_cache()
        data = data.float().requires_grad_()
        targ = targ.cuda()

        results = []
        for setting, meter in zip(attack_settings, attack_meters):
            adv = get_adv_examples_S(data, targ, model, myloss, *setting,
                                     palette)
            # Scoring the adversarial batch needs no autograd graph.
            with torch.no_grad():
                output = model(adv)['out']
            meter.update_seg_metrics(*eval_metrics(output, targ, num_classes))
            pix_acc, miou, _ = meter.get_seg_metrics().values()
            results.append((pix_acc, miou))

        with open('./results.txt', 'a') as f:
            print("ROUND %d" % (cnt), file=f)
            for pix_acc, miou in results:
                print("%f %f" % (pix_acc, miou), file=f)
示例#15
0
    def _train_epoch(self, epoch):
        """Train for one epoch with mixed precision, possibly across ranks.

        Every rank runs forward / backward / optimizer step through
        ``self.scaler``. Only rank 0 does the bookkeeping: it tracks timing
        and loss, updates confusion-matrix metrics, refreshes the progress
        bar and writes TensorBoard scalars. The LR scheduler is stepped once
        per epoch on every rank.

        The loader yields ``(img, gt, Sgt, Lgt, mask)``. ``self.gt_num``
        selects whether the model returns 1, 2 or 3 outputs, and which
        auxiliary targets the loss receives.

        Args:
            epoch: current epoch number, used as the TensorBoard step.
        """
        self.model.train()
        if self.rank == 0:
            wrt_mode = 'train'

            y_true = []
            y_score = []
            y_score_b = []

            tic = time.time()
            self._reset_metrics()
        tbar = tqdm(self.train_loader, ncols=160)
        for img, gt, Sgt, Lgt, mask in tbar:
            if self.rank == 0:
                self.data_time.update(time.time() - tic)
            img = img.to(self.rank, non_blocking=True)
            gt = gt.to(self.rank, non_blocking=True)
            mask = mask.to(self.rank, non_blocking=True)
            Sgt = Sgt.to(self.rank, non_blocking=True)
            Lgt = Lgt.to(self.rank, non_blocking=True)

            # LOSS & OPTIMIZE (mixed precision)
            self.optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):
                if self.gt_num == 1:
                    predict = self.model(img)
                    loss = self.loss(predict, gt)
                elif self.gt_num == 2:
                    s, predict = self.model(img)
                    loss = self.loss(predict, gt, s, Sgt)
                else:
                    l, s, predict = self.model(img)
                    loss = self.loss(predict, gt, s, Sgt, l, Lgt)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Everything below is rank-0 bookkeeping only.
            if self.rank != 0:
                continue

            self.total_loss.update(loss.item())
            # measure elapsed time
            self.batch_time.update(time.time() - tic)
            tic = time.time()

            # Flatten predictions and keep only pixels where mask == 1.
            predict = torch.sigmoid(predict).cpu().detach().numpy().ravel()
            predict_b = np.where(predict >= 0.5, 1, 0)
            valid = mask.cpu().detach().numpy().ravel() == 1
            # NOTE(review): these are overwritten every batch, so the
            # get_metrics_full call after the loop sees only the last
            # batch's pixels even though they start out as lists.
            # Accumulating over the epoch may have been intended, but it
            # would hold every training pixel in memory; confirm.
            y_true = gt.cpu().detach().numpy().ravel()[valid]
            y_score = predict[valid]
            y_score_b = predict_b[valid]

            # FOR EVAL and INFO
            self._update_seg_metrics(*eval_metrics(y_true, y_score_b))
            metrics = get_metrics(self.tn, self.fp, self.fn, self.tp)
            tbar.set_description(
                'TRAIN ({}) | Loss: {:.4f} | Acc {:.4f} Pre {:.4f} Sen {:.4f} Spe {:.4f} f1 {:.4f} IOU {:.4f} |B {:.2f} D {:.2f} |'.format(
                    epoch, self.total_loss.average, *metrics.values(), self.batch_time.average,
                    self.data_time.average))

        # METRICS TO TENSORBOARD
        if self.rank == 0:
            metrics = get_metrics_full(self.tn, self.fp, self.fn, self.tp, y_true, y_score, y_score_b)
            self.writer.add_scalar(f'{wrt_mode}/loss', self.total_loss.average, epoch)
            for k, v in list(metrics.items())[:-1]:
                self.writer.add_scalar(f'{wrt_mode}/{k}', v, epoch)
            for i, opt_group in enumerate(self.optimizer.param_groups):
                self.writer.add_scalar(f'{wrt_mode}/Learning_rate_{i}', opt_group['lr'], epoch)

        self.lr_scheduler.step()
示例#16
0
    def _valid_epoch(self, epoch):
        """Evaluate on ``self.val_loader`` and push results to TensorBoard.

        Loss and segmentation metrics are accumulated over all batches. The
        first sample of each of the first 15 batches is kept for an
        input / target / prediction image grid.

        Args:
            epoch: current epoch; scalars are logged at step
                ``epoch * len(self.val_loader)``.

        Returns:
            ``{'val_loss': <mean loss>, **seg_metrics}``, or ``{}`` when no
            validation loader is configured.
        """
        if self.val_loader is None:
            self.logger.warning(
                'Not data loader was passed for the validation step, No validation is performed !'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'
        self._reset_metrics()

        progress = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            samples = []
            for images, labels in progress:
                preds = self.model(images)
                batch_loss = self.loss(preds, labels)
                # A DataParallel-wrapped criterion returns one value per device.
                if isinstance(self.loss, torch.nn.DataParallel):
                    batch_loss = batch_loss.mean()
                self.total_loss.update(batch_loss.item())
                self._update_seg_metrics(
                    *eval_metrics(preds, labels, self.num_classes))

                # Keep the first sample of each of the first 15 batches.
                if len(samples) < 15:
                    samples.append([
                        images[0].data.cpu(),
                        labels.data.cpu().numpy()[0],
                        preds.data.max(1)[1].cpu().numpy()[0],
                    ])

                pix_acc, mean_iou, _ = self._get_seg_metrics().values()
                progress.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, self.total_loss.average, pix_acc, mean_iou))

            # Colourise each (input, target, prediction) triplet for the grid.
            palette = self.train_loader.dataset.palette
            tiles = []
            for rgb, gt_mask, pred_mask in samples:
                panel = (
                    self.restore_transform(rgb),
                    colorize_mask(gt_mask, palette),
                    colorize_mask(pred_mask, palette),
                )
                tiles.extend(
                    self.viz_transform(img.convert('RGB')) for img in panel)
            grid = make_grid(torch.stack(tiles, 0).cpu(), nrow=3, padding=5)
            self.writer.add_image(
                f'{self.wrt_mode}/inputs_targets_predictions', grid,
                self.wrt_step)

            # Scalars are logged at the end-of-epoch validation step.
            self.wrt_step = epoch * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   self.total_loss.average, self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            # The final entry is skipped (presumably the per-class IoU
            # mapping, which is not a scalar; confirm).
            for name, value in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{name}', value,
                                       self.wrt_step)

            result = {'val_loss': self.total_loss.average, **seg_metrics}

        return result
示例#17
0
    def train(self):
        """Run the pre-train phase: train, validate and checkpoint each epoch.

        Each epoch runs two phases:

        * Training on ``self.train_loader`` with ``self.CD``.
        * Validation on ``self.val_loader`` in the few-shot layout: the first
          ``way * shot`` samples of each batch form the support set and the
          remainder the query set.

        The model with the best validation mIoU is saved as ``'max_iou'``,
        and a snapshot is saved every 10 epochs. Per-epoch statistics go to
        TensorBoard and are written to ``<save_path>/trlog`` with
        ``torch.save``.
        """
        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['train_iou'] = []
        trlog['val_iou'] = []
        trlog['max_iou'] = 0.0
        trlog['max_iou_epoch'] = 0

        timer = Timer()
        global_count = 0
        writer = SummaryWriter(comment=self.args.save_path)

        for epoch in range(1, self.args.pre_max_epoch + 1):
            # ---------------- Training ----------------
            self.model.train()
            self.model.mode = 'train'
            train_loss_averager = Averager()
            train_acc_averager = Averager()
            train_iou_averager = Averager()

            tqdm_gen = tqdm.tqdm(self.train_loader)
            for batch in tqdm_gen:
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, label = [t.cuda() for t in batch]
                else:
                    data = batch[0]
                    label = batch[1]

                logits = self.model(data)
                # CD loss is modified in the whole project to incorporate only
                # Cross Entropy loss. Modify as per requirement.
                loss = self.CD(logits, label)

                # Metrics are reset every batch, so pixAcc/mIoU are per-batch
                # values that the averagers then average over the epoch.
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label,
                                           self.args.num_classes)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(
                    self.args.num_classes).values()

                train_loss_averager.add(loss.item())
                train_acc_averager.add(pixAcc)
                train_iou_averager.add(mIoU)

                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f} IOU={:.4f}'.format(
                        epoch, train_loss_averager.item(),
                        train_acc_averager.item() * 100.0,
                        train_iou_averager.item()))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Step the LR schedule once per epoch, after this epoch's
            # optimizer updates (the order required since PyTorch 1.1;
            # stepping first skips the initial learning rate).
            self.lr_scheduler.step()

            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()
            train_iou_averager = train_iou_averager.item()

            writer.add_scalar('data/train_loss(Pre)',
                              float(train_loss_averager), epoch)
            writer.add_scalar('data/train_acc(Pre)',
                              float(train_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/train_iou (Pre)',
                              float(train_iou_averager), epoch)

            print(
                'Epoch {}, Train: Loss={:.4f}, Acc={:.4f}, IoU={:.4f}'.format(
                    epoch, train_loss_averager, train_acc_averager * 100.0,
                    train_iou_averager))

            # ---------------- Validation ----------------
            self.model.eval()
            self.model.mode = 'val'

            val_loss_averager = Averager()
            val_acc_averager = Averager()
            val_iou_averager = Averager()

            print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(
                trlog['max_iou_epoch'], trlog['max_iou']))

            # Validation needs no autograd graph.
            with torch.no_grad():
                for batch in self.val_loader:
                    if torch.cuda.is_available():
                        data, labels, _ = [t.cuda() for t in batch]
                    else:
                        # The third batch element is unused here too.
                        data = batch[0]
                        labels = batch[1]
                    # Support set = first way*shot samples; query = the rest.
                    p = self.args.way * self.args.shot
                    data_shot, data_query = data[:p], data[p:]
                    label_shot, label = labels[:p], labels[p:]

                    logits = self.model((data_shot, label_shot, data_query))
                    loss = self.CD(logits, label)

                    self._reset_metrics()
                    seg_metrics = eval_metrics(logits, label, self.args.way)
                    self._update_seg_metrics(*seg_metrics)
                    pixAcc, mIoU, _ = self._get_seg_metrics(
                        self.args.way).values()

                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(pixAcc)
                    val_iou_averager.add(mIoU)

            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            val_iou_averager = val_iou_averager.item()

            writer.add_scalar('data/val_loss(Pre)', float(val_loss_averager),
                              epoch)
            writer.add_scalar('data/val_acc(Pre)',
                              float(val_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/val_iou (Pre)', float(val_iou_averager),
                              epoch)

            print('Epoch {}, Val: Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager * 100.0,
                val_iou_averager))

            # Keep the checkpoint with the best validation mIoU so far.
            if val_iou_averager > trlog['max_iou']:
                trlog['max_iou'] = val_iou_averager
                trlog['max_iou_epoch'] = epoch
                print("model saved in max_iou")
                self.save_model('max_iou')

            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)
            trlog['train_iou'].append(train_iou_averager)
            trlog['val_iou'].append(val_iou_averager)

            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            # The loop runs to pre_max_epoch, so that is the total used for
            # the remaining-time estimate.
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(),
                timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()
示例#18
0
    def eval(self):
        """Run the meta-test phase and save query inputs, labels and predictions.

        Builds the test loader from ``mDataset('test', ...)`` with a
        ``CategoriesSampler``. It then loads either ``args.eval_weights`` or
        ``<save_path>/max_iou.pth``. In each episode the first
        ``teshot * way`` samples are the support set and the rest the query
        set.

        For every query sample three files are written under the
        ``args.save_image_dir`` prefix, with class ids divided by
        ``way - 1``:

        * ``<n>a.jpg``: the input.
        * ``<n>b.png``: the label.
        * ``<n>c.png``: the arg-max prediction.

        The mean per-episode and cumulative metrics are printed at the end.
        """
        # NOTE(review): loaded but not used below; kept so a missing
        # pre-train log still raises here as before.
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set
        self.test_set = mDataset('test', self.args)
        self.sampler = CategoriesSampler(
            self.test_set.labeln, self.args.num_batch, self.args.way,
            self.args.teshot + self.args.test_query, self.args.teshot)
        self.loader = DataLoader(dataset=self.test_set,
                                 batch_sampler=self.sampler, num_workers=8,
                                 pin_memory=True)

        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict then copies the tensors onto the model.
        if self.args.eval_weights is not None:
            weights_path = self.args.eval_weights
        else:
            weights_path = osp.join(self.args.save_path, 'max_iou' + '.pth')
        self.model.load_state_dict(
            torch.load(weights_path, map_location='cpu')['params'])
        self.model.eval()

        ave_acc = Averager()
        self._reset_metrics()
        to_pil = transforms.ToPILImage()
        # Divisor mapping class ids to [0, 1] for greyscale image export.
        scale = 1.0 * (self.args.way - 1)
        count = 1
        with torch.no_grad():
            for batch in self.loader:
                if torch.cuda.is_available():
                    data, labels, _ = [t.cuda() for t in batch]
                else:
                    data = batch[0]
                    labels = batch[1]
                p = self.args.teshot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                label_shot, label = labels[:p], labels[p:]
                logits = self.model((data_shot, label_shot, data_query))
                # Metrics are not reset per episode, so these are cumulative.
                seg_metrics = eval_metrics(logits, label, self.args.way)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(
                    self.args.way).values()
                ave_acc.add(pixAcc)

                # Save test image, ground truth and prediction per query.
                for j in range(len(data_query)):
                    x1 = data_query[j].detach().cpu()
                    y1 = label[j].detach().cpu()
                    z1 = logits[j].detach().cpu()

                    x = to_pil(x1).convert("RGB")
                    y = to_pil(y1 / scale).convert("LA")
                    im = torch.tensor(np.argmax(np.array(z1), axis=0) / scale)
                    im = im.type(torch.FloatTensor)
                    z = to_pil(im).convert("LA")

                    px = self.args.save_image_dir + str(count) + 'a.jpg'
                    py = self.args.save_image_dir + str(count) + 'b.png'
                    pz = self.args.save_image_dir + str(count) + 'c.png'
                    x.save(px)
                    y.save(py)
                    z.save(pz)
                    count = count + 1

        final_acc, final_iou, _ = self._get_seg_metrics(
            self.args.way).values()
        print('Test: Mean episode PixelAcc={:.4f}, PixelAcc={:.4f}, mIoU={:.4f}'
              .format(ave_acc.item(), final_acc, final_iou))
示例#19
0
    def _train_fastadt(self, epoch):
        """Train one epoch with single-step, L1-bounded adversarial training.

        Each batch is handled as follows:

        1. A random perturbation ``delta`` is drawn.
        2. ``delta`` takes one step of size ``alpha`` along the per-sample
           L1-normalised loss gradient.
        3. It is projected back onto the per-sample L1 ball of radius
           ``eps``.
        4. It is clipped so that ``data + delta`` stays in ``[0, 1]``.

        The model is then updated on ``data + delta``.

        Returns:
            ``{'loss': <mean loss>, **seg_metrics}`` for the epoch.
        """
        TEMP = 1  # logit temperature (1 means no scaling)
        # alpha: step size; eps: total per-sample L1 budget of delta.
        alpha = 250
        eps = 1000

        def step(x, g, alpha):
            """Return ``x + alpha * g / ||g||_1``, with the norm per sample."""
            l = len(x.shape) - 1
            g_norm = torch.norm(g.view(g.shape[0], -1), 1,
                                dim=1).view(-1, *([1] * l))
            scaled_g = g / (g_norm + 1e-10)
            return x + scaled_g * alpha

        self.logger.info('\n')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            if isinstance(self.model, torch.nn.DataParallel):
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        self.wrt_mode = 'train'

        tic = time.time()
        self._reset_metrics()
        tbar = tqdm(self.train_loader, ncols=130)
        for batch_idx, (data, target) in enumerate(tbar):
            torch.cuda.empty_cache()
            # Map out-of-range labels to 255 (presumably the ignore index).
            target[target >= self.num_classes] = 255
            target[target < 0] = 255
            self.data_time.update(time.time() - tic)

            self.lr_scheduler.step(epoch=epoch - 1)

            # --- Attack: random start plus one normalised-gradient step ---
            max_perturb = eps / (data.shape[1] * data.shape[2] * data.shape[3])
            # zeros_like already places delta on data's device.
            delta = torch.zeros_like(data).uniform_(-max_perturb, max_perturb)
            delta.requires_grad = True
            # NOTE: the model is in train mode here, so this forward pass may
            # update batch-norm statistics (unless freeze_bn is set).
            output = self.model(data + delta)['out'] / TEMP
            loss = self.loss(output, target).mean()
            # Only d(loss)/d(delta) is needed; autograd.grad avoids also
            # filling parameter gradients as loss.backward() would.
            grad = torch.autograd.grad(loss, delta)[0].detach()
            with torch.no_grad():
                delta = step(delta, grad, alpha).renorm(p=1, dim=0,
                                                        maxnorm=eps)
                # Keep data + delta inside [0, 1].
                delta = torch.max(torch.min(1 - data, delta), 0 - data)

            # --- Training step on the adversarial batch ---
            output = self.model(data + delta)['out'] / TEMP
            # .mean() also reduces the per-device losses of a DataParallel
            # criterion.
            loss = self.loss(output, target).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.total_loss.update(loss.item())

            # measure elapsed time
            self.batch_time.update(time.time() - tic)
            tic = time.time()

            # LOGGING & TENSORBOARD
            if batch_idx % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(
                    self.train_loader) + batch_idx
                self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(),
                                       self.wrt_step)

            # FOR EVAL
            seg_metrics = eval_metrics(output, target, self.num_classes)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics().values()

            # PRINT INFO
            tbar.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} |  '.
                format(epoch, self.total_loss.average, pixAcc, mIoU))

        # METRICS TO TENSORBOARD
        seg_metrics = self._get_seg_metrics()
        for k, v in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step)
        for i, opt_group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(f'{self.wrt_mode}/Learning_rate_{i}',
                                   opt_group['lr'], self.wrt_step)

        # RETURN LOSS & METRICS
        log = {'loss': self.total_loss.average, **seg_metrics}
        return log
示例#20
0
    def _train_CR1(self, epoch):
        """Train one epoch with a CR-consistency regulariser.

        The loss is the segmentation loss on the clean batch plus
        ``alpha * mean(|CR(model(data2)) - CR(model(data))|)``, where:

        * ``CR`` is ``tools.get_CR``.
        * ``data2 = clamp(data + gamma * noise, 0, 1)``, with ``noise``
          standard Gaussian.

        The per-epoch sum of the regulariser term is appended to
        ``./result_ADVTR.txt``.

        Returns:
            ``{'loss': <mean loss>, **seg_metrics}`` for the epoch.
        """
        from tools import get_CR

        TEMP = 1  # logit temperature (1 means no scaling)
        alpha = 2  # weight of the consistency term
        gamma = 0.05  # scale of the Gaussian input noise
        DCR_sum = 0.0
        self.logger.info('\n')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            if isinstance(self.model, torch.nn.DataParallel):
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        self.wrt_mode = 'train'

        tic = time.time()
        self._reset_metrics()
        tbar = tqdm(self.train_loader, ncols=130)
        for batch_idx, (data, target) in enumerate(tbar):
            torch.cuda.empty_cache()
            # Map out-of-range labels to 255 (presumably the ignore index).
            target[target >= self.num_classes] = 255
            target[target < 0] = 255
            self.data_time.update(time.time() - tic)

            self.lr_scheduler.step(epoch=epoch - 1)

            # LOSS & OPTIMIZE
            self.optimizer.zero_grad()
            output = self.model(data)
            output = output['out'] / TEMP

            if self.config['arch']['type'][:3] == 'PSP':
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output[0], target)
                loss += self.loss(output[1], target) * 0.4
                output = output[0]
            else:
                assert output.size()[2:] == target.size()[1:]
                assert output.size()[1] == self.num_classes
                loss = self.loss(output, target)

            # .mean() also reduces per-device losses of a DataParallel
            # criterion.
            loss1 = loss.mean()

            # Consistency regulariser between clean and noisy inputs.
            CR1 = get_CR(output)
            # Noise is drawn on the CPU generator and moved to data's device,
            # which matches the former .cuda() when data is on the GPU.
            noise = torch.randn(data.shape).to(data.device)
            data2 = torch.clamp(data + gamma * noise, 0, 1)
            CR2 = get_CR(self.model(data2)['out'] / TEMP)
            loss2 = alpha * torch.mean(torch.abs(CR2 - CR1))
            # Accumulate a Python float: summing the differentiable tensor
            # would keep extending autograd history across the epoch.
            DCR_sum += loss2.item()

            loss = loss1 + loss2

            loss.backward()
            self.optimizer.step()
            self.total_loss.update(loss.item())

            # measure elapsed time
            self.batch_time.update(time.time() - tic)
            tic = time.time()

            # LOGGING & TENSORBOARD
            if batch_idx % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(
                    self.train_loader) + batch_idx
                self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(),
                                       self.wrt_step)

            # FOR EVAL
            seg_metrics = eval_metrics(output, target, self.num_classes)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics().values()

            # PRINT INFO (Seg = segmentation loss, DCR = consistency term)
            tbar.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} | Seg {:.3f} DCR {:.3f} |'
                .format(epoch, self.total_loss.average, pixAcc, mIoU,
                        loss1.item(), loss2.item()))
        with open('./result_ADVTR.txt', 'a') as file_handle:
            file_handle.write("%f " % (DCR_sum))
            file_handle.write('\n')
        # METRICS TO TENSORBOARD
        seg_metrics = self._get_seg_metrics()
        for k, v in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step)
        for i, opt_group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(f'{self.wrt_mode}/Learning_rate_{i}',
                                   opt_group['lr'], self.wrt_step)

        # RETURN LOSS & METRICS
        log = {'loss': self.total_loss.average, **seg_metrics}
        return log
示例#21
0
    def _train_epoch(self, epoch):
        """Train the model for one epoch using gradient accumulation.

        Gradients are accumulated over ``self.batch_stride`` consecutive
        batches. They are zeroed on the first batch of each group, and the
        optimizer steps after the last batch of each group or after the final
        batch of the loader. Each batch loss is divided by ``batch_stride``
        before ``backward()``, while the unscaled value is what gets recorded
        and logged.

        Args:
            epoch: Epoch number. ``epoch - 1`` offsets the TensorBoard step and
                is passed to ``self.lr_scheduler.step`` at the end.

        Returns:
            dict: ``{'loss': <average training loss>, **<seg metrics>}``.
        """
        info = self.logger.info
        info('\n')
        info(f'Epoch: {epoch}')
        info(f'Learning rate: {self.optimizer.param_groups[0]["lr"]}')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            # Unwrap DataParallel so freeze_bn runs on the underlying model.
            if isinstance(self.model, torch.nn.DataParallel):
                bn_owner = self.model.module
            else:
                bn_owner = self.model
            bn_owner.freeze_bn()
        self.wrt_mode = 'train'

        # Unscaled losses collected since the last TensorBoard write.
        window_losses = []
        self._reset_metrics()
        progress = tqdm(self.train_loader, ncols=130)
        num_batches = len(self.train_loader)
        uses_psp = self.config['arch']['type'][:3] == 'PSP'

        for step, (images, labels) in enumerate(progress):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # First batch of an accumulation group: drop old gradients.
            if step % self.batch_stride == 0:
                self.optimizer.zero_grad()

            logits = self.model(images)
            if uses_psp:
                # PSP models return (main, auxiliary); the auxiliary loss is
                # weighted 0.4 and only the main output is evaluated.
                main_out = logits[0]
                assert main_out.size()[2:] == labels.size()[1:]
                assert main_out.size()[1] == self.num_classes
                loss = self.loss(main_out, labels)
                loss += self.loss(logits[1], labels) * 0.4
                logits = main_out
            else:
                assert logits.size()[2:] == labels.size()[1:]
                assert logits.size()[1] == self.num_classes
                loss = self.loss(logits, labels)

            if isinstance(self.loss, torch.nn.DataParallel):
                # A DataParallel-wrapped criterion yields per-replica values.
                loss = loss.mean()

            unscaled = loss.item()
            (loss / self.batch_stride).backward()

            # Last batch of a group, or of the epoch: apply the update.
            is_group_end = (step + 1) % self.batch_stride == 0
            if is_group_end or step + 1 == num_batches:
                self.optimizer.step()

            self.total_loss.update(unscaled)
            window_losses.append(unscaled)

            if step % self.log_step == 0:
                self.wrt_step = (epoch - 1) * num_batches + step
                self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                       np.mean(window_losses), self.wrt_step)
                window_losses.clear()

            # Running segmentation metrics on the training predictions.
            self._update_seg_metrics(
                *eval_metrics(logits, labels, self.num_classes))
            pix_acc, miou, _ = self._get_seg_metrics().values()

            progress.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} |'.format(
                    epoch, self.total_loss.average, pix_acc, miou))

        epoch_metrics = self._get_seg_metrics()
        # The final entry is not written as a scalar (presumably the
        # per-class breakdown — TODO confirm against _get_seg_metrics).
        for name, value in list(epoch_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{name}', value,
                                   self.wrt_step)
        for group_idx, group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(
                f'{self.wrt_mode}/Learning_rate_{group_idx}', group['lr'],
                self.wrt_step)

        log = {'loss': self.total_loss.average, **epoch_metrics}

        # Scheduler is stepped once per epoch, after the summary is built.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(epoch - 1)
        return log
示例#22
0
    def _train_epoch(self, epoch):
        """Train for one epoch with per-batch LR scheduling and timing stats.

        The LR scheduler is stepped with ``epoch=epoch - 1`` before every
        batch. PSP models add a 0.4-weighted auxiliary loss. AUNet models add
        the loss of each extra output, bilinearly resized to the label size,
        weighted 0.2. Only the primary output feeds the segmentation metrics.
        Inputs are passed to the model as the loader yields them (no explicit
        device move here).

        Args:
            epoch: Epoch number; ``epoch - 1`` offsets the TensorBoard step.

        Returns:
            dict: ``{'loss': <average training loss>, **<seg metrics>}``.
        """
        self.logger.info('\n')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            wrapped = isinstance(self.model, torch.nn.DataParallel)
            (self.model.module if wrapped else self.model).freeze_bn()
        self.wrt_mode = 'train'

        arch = self.config['arch']['type']
        last_mark = time.time()
        self._reset_metrics()
        progress = tqdm(self.train_loader, ncols=130)
        for step, (images, labels) in enumerate(progress):
            # Time spent waiting on the data loader for this batch.
            self.data_time.update(time.time() - last_mark)
            self.lr_scheduler.step(epoch=epoch - 1)

            self.optimizer.zero_grad()
            logits = self.model(images)
            if arch[:3] == 'PSP':
                assert logits[0].size()[2:] == labels.size()[1:]
                assert logits[0].size()[1] == self.num_classes
                loss = self.loss(logits[0], labels)
                loss += self.loss(logits[1], labels) * 0.4
                logits = logits[0]
            elif arch == 'AUNet':
                # Extra outputs are upsampled to the label resolution and each
                # contributes with weight 0.2.
                label_size = labels.shape[-2:]
                loss = self.loss(logits[0], labels)
                for side_out in logits[1:]:
                    upsampled = F.interpolate(side_out,
                                              size=label_size,
                                              mode='bilinear',
                                              align_corners=True)
                    loss += self.loss(upsampled, labels) * 0.2
                logits = logits[0]
            else:
                assert logits.size()[2:] == labels.size()[1:]
                assert logits.size()[1] == self.num_classes
                loss = self.loss(logits, labels)

            if isinstance(self.loss, torch.nn.DataParallel):
                loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            loss_value = loss.item()
            self.total_loss.update(loss_value)

            # Full iteration time (loading + forward/backward/step).
            self.batch_time.update(time.time() - last_mark)
            last_mark = time.time()

            if step % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(self.train_loader) + step
                self.writer.add_scalar(f'{self.wrt_mode}/loss', loss_value,
                                       self.wrt_step)

            self._update_seg_metrics(
                *eval_metrics(logits, labels, self.num_classes))
            pix_acc, miou, _ = self._get_seg_metrics().values()

            progress.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} | B {:.2f} D {:.2f} |'
                .format(epoch, self.total_loss.average, pix_acc, miou,
                        self.batch_time.average, self.data_time.average))

        epoch_metrics = self._get_seg_metrics()
        # All but the last metric entry are written as scalars.
        for name, value in list(epoch_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{name}', value,
                                   self.wrt_step)
        for group_idx, group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(
                f'{self.wrt_mode}/Learning_rate_{group_idx}', group['lr'],
                self.wrt_step)

        return {'loss': self.total_loss.average, **epoch_metrics}
示例#23
0
    def _valid_epoch(self, epoch):
        """Evaluate on the validation set with spatially padded inputs.

        Each batch is reflect-padded on the bottom/right to the label size
        rounded up to a multiple of 8. It is then run through the model, and
        the logits are cropped back to the label size. Correct/labelled pixel
        counts and intersection/union are accumulated across batches to
        compute pixel accuracy and IoU. Up to 15 (image, target, prediction)
        triples are kept for TensorBoard.

        Args:
            epoch: Epoch number; the TensorBoard step is
                ``epoch * len(self.val_loader)``.

        Returns:
            dict: ``{'val_loss', 'Pixel_Accuracy', 'Mean_IoU', 'Class_IoU'}``,
            or ``{}`` when no validation loader is configured.
        """
        if self.val_loader is None:
            self.logger.warning(
                'Not data loader was passed for the validation step, No validation is performed !'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'
        loss_meter = AverageMeter()
        sum_inter = sum_union = 0
        sum_correct = sum_labeled = 0

        progress = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            for images, labels in progress:
                labels = labels.cuda(non_blocking=True)
                images = images.cuda(non_blocking=True)

                # Pad bottom/right so the padded size is the label size rounded
                # up to a multiple of 8, then crop logits back to label size.
                height, width = labels.size(1), labels.size(2)
                pad_bottom = ceil(height / 8) * 8 - images.size(2)
                pad_right = ceil(width / 8) * 8 - images.size(3)
                images = F.pad(images,
                               pad=(0, pad_right, 0, pad_bottom),
                               mode='reflect')
                logits = self.model(images)[:, :, :height, :width]

                loss = F.cross_entropy(logits,
                                       labels,
                                       ignore_index=self.ignore_index)
                loss_meter.update(loss.item())

                correct, labeled, inter, union = eval_metrics(
                    logits, labels, self.num_classes, self.ignore_index)
                sum_inter = sum_inter + inter
                sum_union = sum_union + union
                sum_correct = sum_correct + correct
                sum_labeled = sum_labeled + labeled

                # Keep a handful of samples for visualisation.
                if len(val_visual) < 15:
                    if isinstance(images, list):
                        images = images[0]
                    label_np = labels.data.cpu().numpy()
                    pred_np = logits.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [images[0].data.cpu(), label_np[0], pred_np[0]])

                # Running metrics over everything seen so far.
                pix_acc = 1.0 * sum_correct / (np.spacing(1) + sum_labeled)
                iou = 1.0 * sum_inter / (np.spacing(1) + sum_union)
                miou = iou.mean()
                seg_metrics = {
                    "Pixel_Accuracy": np.round(pix_acc, 3),
                    "Mean_IoU": np.round(miou, 3),
                    "Class_IoU": dict(
                        zip(range(self.num_classes), np.round(iou, 3))),
                }

                progress.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, loss_meter.average, pix_acc, miou))

            self._add_img_tb(val_visual, 'val')

            self.wrt_step = (epoch) * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   loss_meter.average, self.wrt_step)
            # Class_IoU (the last entry) is a dict, so it is not a scalar.
            for name, value in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{name}', value,
                                       self.wrt_step)

            log = {'val_loss': loss_meter.average, **seg_metrics}
            # NOTE(review): keyword spelled 'seg_resuts' — presumably matches
            # html_results.add_results' signature; confirm before renaming.
            self.html_results.add_results(epoch=epoch, seg_resuts=log)
            self.html_results.save()

            # Save a checkpoint once more than 22 hours have elapsed since
            # self.start_time.
            if (time.time() - self.start_time) / 3600 > 22:
                self._save_checkpoint(epoch, save_best=self.improved)
        return log
示例#24
0
    def test(self, epoch=1, log="wandb"):
        """Run inference on the test set, optionally save visualisations, and log metrics.

        Sigmoid probabilities are collected for every test sample and
        binarised at 0.5. When ``self.save_pic`` is True, each group of
        ``self.group`` samples is tiled into an (image, ground truth,
        prediction) stripe and written with ``visualize``. Metrics are
        computed only over pixels whose mask value is 1.

        Args:
            epoch: Step used for TensorBoard scalars (unused for wandb).
            log: ``"wandb"`` logs through ``wandb.log``; any other value logs
                through ``self.writer``. The last metric entry is only printed,
                not logged.
        """
        logger.info('###### TEST EVALUATION ######')
        wrt_mode = 'test'
        self.model.eval()
        tbar = tqdm(self.test_loader, ncols=50)
        imgs, predicts, gts, masks = [], [], [], []
        tic1 = time.time()

        with torch.no_grad():
            for img, gt, mask in tbar:
                # Only the image goes to the GPU; gt and mask are consumed on
                # the CPU, so moving them there and back is wasted transfer.
                img = img.cuda(non_blocking=True)
                with torch.cuda.amp.autocast(enabled=True):
                    # Multi-output models return a tuple; the last element is
                    # the prediction used here.
                    if self.gt_num == 1:
                        predict = self.model(img)
                    elif self.gt_num == 2:
                        _, predict = self.model(img)
                    else:
                        _, _, predict = self.model(img)
                imgs.extend(img.cpu().detach().numpy())
                gts.extend(gt.cpu().detach().numpy())
                predicts.extend(torch.sigmoid(predict).cpu().detach().numpy())
                masks.extend(mask.cpu().detach().numpy())

            imgs = np.asarray(imgs)
            gts = np.asarray(gts)
            predicts = np.asarray(predicts)
            masks = np.asarray(masks)
        test_time = time.time() - tic1
        logger.info(f'test time:  {test_time}')
        predicts_b = np.where(predicts >= 0.5, 1, 0)

        if self.save_pic is True:
            assert predicts.shape[0] % self.group == 0
            for i in range(predicts.shape[0] // self.group):
                rows = slice(i * self.group, (i + 1) * self.group)
                orig_stripe = group_images(imgs[rows, :, :, :], self.group)
                gt_stripe = group_images(gts[rows, :, :, :], self.group)
                pred_stripe = group_images(predicts_b[rows, :, :, :],
                                           self.group)
                total_img = np.concatenate(
                    (np.tile(orig_stripe, 3), np.tile(gt_stripe, 3),
                     np.tile(pred_stripe, 3)),
                    axis=0)
                visualize(
                    total_img, self.checkpoint_dir +
                    "/RGB_Original_GroundTruth_Prediction" + str(i))

        # LOGGING & TENSORBOARD
        wrt_step = epoch
        # Time the metric computation on its own; inference time is
        # reported separately above.
        tic_metrics = time.time()
        # Evaluate only inside the mask; index each array once.
        valid = masks == 1
        gts_valid = gts[valid]
        probs_valid = predicts[valid]
        preds_valid = predicts_b[valid]
        metrics = get_metrics_full(*eval_metrics(gts_valid, preds_valid),
                                   gts_valid, probs_valid, preds_valid)
        metrics_time = time.time() - tic_metrics
        logger.info(f'metrics time:  {metrics_time}')

        for k, v in list(metrics.items())[:-1]:
            if log == "wandb":
                wandb.log({f'{wrt_mode}/{k}': v})
            else:
                self.writer.add_scalar(f'{wrt_mode}/{k}', v, wrt_step)
        for k, v in metrics.items():
            logger.info(f'         {str(k):15s}: {v}')
示例#25
0
    def _valid_epoch(self, epoch):
        """Evaluate the classifier on the validation set.

        Accumulates loss, top-1/top-2 accuracy and ``self.confusion_matrix``
        over the validation loader and logs the averages to TensorBoard. It
        then prints the confusion matrix and appends it to
        ``<checkpoint_dir>/confusion.txt``.

        Args:
            epoch: Epoch number; the TensorBoard step is
                ``epoch * len(self.val_loader)``.

        Returns:
            dict: ``{'losses': ..., 'top1': ..., 'top2': ...}``, or ``{}``
            when no validation loader is configured.
        """
        if self.val_loader is None:
            self.logger.warning(
                'Not data loader was passed for the validation step, No validation is performed !'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            for data, target, _image_path in tbar:
                # Tensors are only moved when running on cuda:0; otherwise
                # they are used as produced by the loader.
                if self.device == torch.device("cuda:0"):
                    data, target = data.to(self.device), target.to(self.device)

                # LOSS
                output = self.model(data)
                loss = self.loss(output, target)
                if isinstance(self.loss, torch.nn.DataParallel):
                    loss = loss.mean()
                self.total_loss.update(loss.item())

                topk = eval_metrics(output, target, topk=(1, 2))
                self._update_metrics(topk)
                self.confusion_matrix = confusionMatrix(
                    output, target, self.confusion_matrix)

                # PRINT INFO
                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f} | Top1Acc {:.2f} Top2Acc {:.2f} |'
                    .format(epoch, self.total_loss.average,
                            self.precision_top1.average.item(),
                            self.precision_top2.average.item()))

            top1 = self.precision_top1.average.item()
            top2 = self.precision_top2.average.item()

            # METRICS TO TENSORBOARD
            self.wrt_step = (epoch) * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/losses',
                                   self.total_loss.average, self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/top1', top1,
                                   self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/top2', top2,
                                   self.wrt_step)

            # RETURN LOSS & METRICS
            log = {
                'losses': self.total_loss.average,
                "top1": top1,
                "top2": top2
            }

            # Class names: first whitespace-separated token of each line.
            label_path = os.path.join(
                self.config["train_loader"]["args"]["data_dir"], "labels.txt")
            with open(label_path, 'r') as f:
                labels = [line.split()[0] for line in f]

            num_classes = self.train_loader.dataset.num_classes
            cm = self.confusion_matrix
            # Print the confusion matrix and append it to confusion.txt; the
            # context manager guarantees the file handle is closed.
            with open(os.path.join(self.checkpoint_dir, "confusion.txt"),
                      'a+') as confusion_file:
                print("{0:10}".format(""), end="")
                confusion_file.write("{0:8}".format(""))
                for name in labels:
                    print("{0:10}".format(name), end="")
                    confusion_file.write("{0:8}".format(name))
                print("{0:10}".format("Precision"))
                confusion_file.write("{0:8}\n".format("Precision"))
                for i in range(num_classes):
                    print("{0:10}".format(labels[i]), end="")
                    confusion_file.write("{0:8}".format(labels[i]))
                    for j in range(num_classes):
                        # Diagonal cells are wrapped in dashes.
                        cell = str(cm[i][j])
                        if i == j:
                            cell = "-" + cell + "-"
                        print("{0:10}".format(cell), end="")
                        confusion_file.write("{0:8}".format(cell))
                    # Row-normalised diagonal. NOTE(review): whether this is
                    # precision or recall depends on confusionMatrix's
                    # row/column convention — confirm.
                    precision = 0.0 + cm[i][i] / sum(cm[i])
                    print("{0:.4f}".format(precision))
                    confusion_file.write("{0:8}\n".format(precision))

        return log
示例#26
0
    def _train_epoch(self, epoch):
        """Train the classifier for one epoch and track top-1/top-2 accuracy.

        Args:
            epoch: Epoch number; ``epoch - 1`` offsets the TensorBoard step.

        Returns:
            dict: ``{'losses': <average loss>, 'top1': <avg top-1>,
            'top2': <avg top-2>}``.
        """
        self.logger.info('\n')

        # Training mode, optionally with BatchNorm frozen.
        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            target_model = (self.model.module
                            if isinstance(self.model, torch.nn.DataParallel)
                            else self.model)
            target_model.freeze_bn()
        self.wrt_mode = 'train'

        mark = time.time()
        self._reset_metrics()  # resets the loss / top-1 / top-2 meters
        progress = tqdm(self.train_loader, ncols=130)
        for step, (inputs, labels, _paths) in enumerate(progress):
            # Time spent waiting for this batch from the loader.
            self.data_time.update(time.time() - mark)

            if self.device == torch.device('cuda:0'):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.loss.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(inputs)
            loss = self.loss(logits, labels)

            # A DataParallel-wrapped criterion yields per-replica values;
            # average them into one scalar.
            if isinstance(self.loss, torch.nn.DataParallel):
                loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            self.total_loss.update(loss.item())

            # Full iteration time (loading + training step).
            self.batch_time.update(time.time() - mark)
            mark = time.time()

            if step % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(self.train_loader) + step
                self.writer.add_scalar(f'{self.wrt_mode}/losses', loss.item(),
                                       self.wrt_step)

            # eval_metrics returns tensors for the requested top-k accuracies.
            self._update_metrics(eval_metrics(logits, labels, topk=(1, 2)))

            running_top1 = self.precision_top1.average.item()
            running_top2 = self.precision_top2.average.item()
            progress.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Top1Acc {:.2f} Top2Acc {:.2f} | B {:.2f} D {:.2f} |'
                .format(epoch, self.total_loss.average, running_top1,
                        running_top2, self.batch_time.average,
                        self.data_time.average))

        # Epoch summary to TensorBoard.
        top1 = self.precision_top1.average.item()
        top2 = self.precision_top2.average.item()
        self.writer.add_scalar(f'{self.wrt_mode}/top1', top1, self.wrt_step)
        self.writer.add_scalar(f'{self.wrt_mode}/top2', top2, self.wrt_step)
        for group_idx, group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(
                f'{self.wrt_mode}/Learning_rate_{group_idx}', group['lr'],
                self.wrt_step)

        log = {
            'losses': self.total_loss.average,
            "top1": top1,
            "top2": top2,
        }

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log
示例#27
0
    def _train_epoch(self, epoch):
        """Train the segmentation model for one epoch.

        The LR scheduler is stepped with ``epoch=epoch - 1`` after every
        optimizer step. For PSP models the auxiliary output's loss is added
        with weight 0.4, and only the main output is evaluated. Inputs are
        used as the loader yields them (no explicit device move here).

        Args:
            epoch: Epoch number; ``epoch - 1`` offsets the TensorBoard step.

        Returns:
            dict: ``{'loss': <average training loss>, **<seg metrics>}``.
        """
        self.logger.info('\n')

        self.model.train()
        if self.config['arch']['args']['freeze_bn']:
            is_parallel = isinstance(self.model, torch.nn.DataParallel)
            bn_model = self.model.module if is_parallel else self.model
            bn_model.freeze_bn()
        self.wrt_mode = 'train'

        last = time.time()
        self._reset_metrics()
        progress = tqdm(self.train_loader, ncols=130)
        for step, (images, labels) in enumerate(progress):
            # Elapsed since the previous iteration ended, i.e. loader wait.
            self.data_time.update(time.time() - last)

            self.optimizer.zero_grad()
            predictions = self.model(images)
            if self.config['arch']['type'][:3] != 'PSP':
                assert predictions.size()[2:] == labels.size()[1:]
                assert predictions.size()[1] == self.num_classes
                loss = self.loss(predictions, labels)
            else:
                main_pred = predictions[0]
                assert main_pred.size()[2:] == labels.size()[1:]
                assert main_pred.size()[1] == self.num_classes
                loss = self.loss(main_pred, labels)
                # The auxiliary branch's loss is added with a small weight.
                loss += self.loss(predictions[1], labels) * 0.4
                predictions = main_pred

            if isinstance(self.loss, torch.nn.DataParallel):
                loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step(epoch=epoch - 1)
            batch_loss = loss.item()
            self.total_loss.update(batch_loss)

            self.batch_time.update(time.time() - last)
            last = time.time()

            if step % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(self.train_loader) + step
                self.writer.add_scalar(f'{self.wrt_mode}/loss', batch_loss,
                                       self.wrt_step)

            # Metrics computed on this training batch's predictions.
            self._update_seg_metrics(
                *eval_metrics(predictions, labels, self.num_classes))
            pix_acc, miou, _ = self._get_seg_metrics().values()

            progress.set_description(
                'TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} | B {:.2f} D {:.2f} |'
                .format(epoch, self.total_loss.average, pix_acc, miou,
                        self.batch_time.average, self.data_time.average))

        # Epoch summary to TensorBoard (all but the last metric entry).
        summary = self._get_seg_metrics()
        for name, value in list(summary.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{name}', value,
                                   self.wrt_step)
        for group_idx, group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(
                f'{self.wrt_mode}/Learning_rate_{group_idx}', group['lr'],
                self.wrt_step)

        return {'loss': self.total_loss.average, **summary}
示例#28
0
    def _valid_epoch(self, epoch):
        """Evaluate on the validation set with segmentation and per-class pixel metrics.

        The loss depends on the architecture name:
        - 'IC*' and 'Enc*' models pass the whole output to the criterion.
        - '*OCR' and '*Nearest*' models add a 0.4-weighted auxiliary loss.
        - 'DANet*' models add two 0.2-weighted auxiliary losses.
        Only the primary output is evaluated. Up to 15 samples are rendered
        to TensorBoard.

        Args:
            epoch: Epoch number; images and scalars are logged at step
                ``epoch * len(self.val_loader)``.

        Returns:
            dict: ``{'val_loss': ..., **seg metrics}``, or ``{}`` when no
            validation loader is configured.
        """
        if self.val_loader is None:
            self.logger.warning(
                'Not data loader was passed for the validation step, No validation is performed !'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=130)
        arch = self.config['arch']['type']
        with torch.no_grad():
            val_visual = []
            cls_total_pix_correct = np.zeros(self.num_classes)
            cls_total_pix_labeled = np.zeros(self.num_classes)
            for data, target in tbar:
                # LOSS
                output = self.model(data)
                if arch[:2] == 'IC':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output, target)
                    output = output[0]
                elif arch[-3:] == 'OCR' or 'Nearest' in arch:
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output[0], target)
                    loss += self.loss(output[1], target) * 0.4
                    output = output[0]
                elif arch[:3] == 'Enc':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output, target)
                    output = output[0]
                elif arch[:5] == 'DANet':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output[0], target)
                    loss += self.loss(output[1], target) * 0.2
                    loss += self.loss(output[2], target) * 0.2
                    output = output[0]
                else:
                    assert output.size()[2:] == target.size()[1:]
                    assert output.size()[1] == self.num_classes
                    loss = self.loss(output, target)

                if isinstance(self.loss, torch.nn.DataParallel):
                    loss = loss.mean()
                self.total_loss.update(loss.item())

                seg_metrics = eval_metrics(output, target, self.num_classes)
                self._update_seg_metrics(*seg_metrics)

                for i in range(self.num_classes):
                    cls_pix_correct, cls_pix_labeled = batch_class_pixel_accuracy(
                        output, target, i)
                    cls_total_pix_correct[i] += cls_pix_correct
                    cls_total_pix_labeled[i] += cls_pix_labeled

                # LIST OF IMAGE TO VIZ (15 images)
                if len(val_visual) < 15:
                    target_np = target.data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])

                # PRINT INFO
                pixAcc, mIoU, _ = self._get_seg_metrics().values()
                # Classes with no labelled pixels yet give 0/0 = nan; keep the
                # nan but silence the RuntimeWarning emitted on every batch.
                with np.errstate(divide='ignore', invalid='ignore'):
                    cls_pix_acc = np.round(
                        cls_total_pix_correct / cls_total_pix_labeled, 3)
                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f}, cls_pix_acc: {} |'
                    .format(epoch, self.total_loss.average, pixAcc, mIoU,
                            str(cls_pix_acc)))

            # Set the validation step before any logging so the images and
            # the scalars below share the same step.
            self.wrt_step = (epoch) * len(self.val_loader)

            # WRITING & VISUALIZING THE MASKS (nothing to stack if the loader
            # produced no batches)
            if val_visual:
                val_img = []
                palette = self.train_loader.dataset.palette
                for d, t, o in val_visual:
                    d = self.restore_transform(d)
                    t, o = colorize_mask(t, palette), colorize_mask(o, palette)
                    d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert(
                        'RGB')
                    [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]
                    val_img.extend([d, t, o])
                val_img = torch.stack(val_img, 0)
                val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
                self.writer.add_image(
                    f'{self.wrt_mode}/inputs_targets_predictions', val_img,
                    self.wrt_step)

            # METRICS TO TENSORBOARD
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   self.total_loss.average, self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': self.total_loss.average, **seg_metrics}

        return log