コード例 #1
0
    def execute_on_dataloader(
            self, test_loader: torch.utils.data.dataloader.DataLoader):
        """Run single-image inference on every sample and save gray label maps.

        Each prediction is written to ``<save_folder>/gray/<image_name>.png``.
        Images whose PNG already exists are skipped, so an interrupted run can
        be resumed. If ``self.args.save_folder`` is ``'default'`` it is replaced
        by a path under ``_ROOT/temp_files`` built from the model name, dataset,
        scales string and base size.

        Args:
        -   test_loader: loader yielding ``(input, _)`` pairs. Batch size must
                be 1 (``np.squeeze(..., axis=0)`` below requires it). Sample
                ``i`` is paired with ``self.data_list[i]``, so the loader
                presumably must not shuffle -- TODO confirm.

        Returns:
        -   None
        """
        if self.args.save_folder == 'default':
            self.args.save_folder = f'{_ROOT}/temp_files/{self.args.model_name}_{self.args.dataset}_universal_{self.scales_str}/{self.args.base_size}'

        os.makedirs(self.args.save_folder, exist_ok=True)
        gray_folder = os.path.join(self.args.save_folder, 'gray')
        self.gray_folder = gray_folder
        check_mkdir(self.gray_folder)

        data_time = AverageMeter()
        batch_time = AverageMeter()
        end = time.time()

        for i, (input, _) in enumerate(test_loader):
            logger.info(f'On image {i}')
            data_time.update(time.time() - end)

            # determine path for grayscale label map
            image_path, _ = self.data_list[i]
            if self.args.img_name_unique:
                image_name = Path(image_path).stem
            else:
                image_name = get_unique_stem_from_last_k_strs(image_path)
            gray_path = os.path.join(self.gray_folder, image_name + '.png')
            if Path(gray_path).exists():
                # Reset the timer so the next processed image's data/batch
                # times do not absorb the time spent on skipped images.
                end = time.time()
                continue

            # Convert PyTorch tensor -> NumPy, drop the batch axis, and move
            # axis 0 to the end (presumably CHW -> HWC), then feed forward.
            input = np.squeeze(input.numpy(), axis=0)
            image = np.transpose(input, (1, 2, 0))
            gray_img = self.execute_on_img(image)

            batch_time.update(time.time() - end)
            end = time.time()
            cv2.imwrite(gray_path, gray_img)

            # todo: update to time remaining.
            if ((i + 1) % self.args.print_freq == 0) or (i + 1
                                                         == len(test_loader)):
                logger.info(
                    'Test: [{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}).'.
                    format(i + 1,
                           len(test_loader),
                           data_time=data_time,
                           batch_time=batch_time))
コード例 #2
0
    def execute_on_dataloader(
            self, test_loader: torch.utils.data.dataloader.DataLoader):
        """Run single-image inference on every sample and pickle all label maps.

        Predicted label maps are collected in memory in a dict keyed by image
        id and dumped once, at the end, to ``<save_folder>/gray/label_maps.pkl``
        with ``mmcv.dump``. If ``self.args.save_folder`` is ``'default'`` it is
        replaced by a path under ``_ROOT/temp_files``.

        Args:
        -   test_loader: loader yielding ``(input, _)`` pairs. Batch size must
                be 1 (``np.squeeze(..., axis=0)`` below requires it). Sample
                ``i`` is paired with ``self.data_list[i]``, so the loader
                presumably must not shuffle -- TODO confirm.

        Returns:
        -   None
        """
        if self.args.save_folder == 'default':
            self.args.save_folder = f'{_ROOT}/temp_files/{self.args.model_name}_{self.args.dataset}_universal_{self.scales_str}/{self.args.base_size}'

        os.makedirs(self.args.save_folder, exist_ok=True)
        gray_folder = os.path.join(self.args.save_folder, 'gray')
        self.gray_folder = gray_folder
        check_mkdir(self.gray_folder)

        data_time = AverageMeter()
        batch_time = AverageMeter()
        end = time.time()
        results = dict()  # image id -> label map
        num_batches = len(test_loader)

        for i, (input, _) in enumerate(tqdm.tqdm(test_loader)):
            data_time.update(time.time() - end)
            # convert Pytorch tensor -> Numpy
            input = np.squeeze(input.numpy(), axis=0)
            image = np.transpose(input, (1, 2, 0))
            gray_img = self.execute_on_img_single(image)
            batch_time.update(time.time() - end)
            end = time.time()

            image_path, _ = self.data_list[i]
            # Image id: the path with its first len(self.input_file) characters
            # dropped (presumably the input-root prefix -- TODO confirm).
            img_id = image_path[len(self.input_file):]
            results[img_id] = gray_img

            # tqdm already reports per-iteration progress, so the timing
            # summary is logged only once, after the last batch.
            # todo: update to time remaining.
            if i + 1 == num_batches:
                logger.info(
                    'Test: [{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}).'.
                    format(i + 1,
                           num_batches,
                           data_time=data_time,
                           batch_time=batch_time))
        mmcv.dump(results, os.path.join(gray_folder, 'label_maps.pkl'))
コード例 #3
0
    def execute_on_dataloader_batched(
            self, test_loader: torch.utils.data.dataloader.DataLoader):
        """Optimize throughput through the network by batched inference, instead of single image inference.

        Each batch is passed to ``self.execute_on_batch`` and every resulting
        label map is written to ``<save_folder>/gray/<image_name>.png``. If
        ``self.args.save_folder`` is ``'default'`` it is replaced by a path
        under ``_ROOT/temp_files``.

        Args:
        -   test_loader: loader yielding ``(input, _)`` batches. Samples are
                paired with ``self.data_list`` in loader order, so the loader
                presumably must not shuffle -- TODO confirm.

        Returns:
        -   None
        """
        if self.args.save_folder == 'default':
            self.args.save_folder = f'{_ROOT}/temp_files/{self.args.model_name}_{self.args.dataset}_universal_{self.scales_str}/{self.args.base_size}'

        os.makedirs(self.args.save_folder, exist_ok=True)
        gray_folder = os.path.join(self.args.save_folder, 'gray')
        self.gray_folder = gray_folder
        check_mkdir(self.gray_folder)

        data_time = AverageMeter()
        batch_time = AverageMeter()
        end = time.time()

        # Index into self.data_list of the first sample of the current batch.
        # A running count stays correct even if the loader's batch size
        # differs from self.args.batch_size_val (or the last batch is short).
        sample_offset = 0

        for i, (input, _) in enumerate(test_loader):
            logger.info(f"On batch {i}")
            data_time.update(time.time() - end)

            gray_batch = self.execute_on_batch(input)
            batch_sz = input.shape[0]
            # dump results to disk
            for j in range(batch_sz):
                # determine path for grayscale label map
                image_path, _ = self.data_list[sample_offset + j]
                if self.args.img_name_unique:
                    image_name = Path(image_path).stem
                else:
                    image_name = get_unique_stem_from_last_k_strs(image_path)
                gray_path = os.path.join(self.gray_folder, image_name + '.png')
                cv2.imwrite(gray_path, gray_batch[j])
            sample_offset += batch_sz

            batch_time.update(time.time() - end)
            end = time.time()

            if ((i + 1) % self.args.print_freq == 0) or (i + 1
                                                         == len(test_loader)):
                logger.info(
                    f'Test: [{i+1}/{len(test_loader)}] '
                    f'Data {data_time.val:.3f} (avg={data_time.avg:.3f}) '
                    f'Batch {batch_time.val:.3f} (avg={batch_time.avg:.3f})')
コード例 #4
0
ファイル: train.py プロジェクト: xingyizhou/mseg-semantic
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return loss and segmentation metrics.

    The model is put in eval mode and autograd is disabled. In
    multiprocessing-distributed mode the per-batch loss is averaged over all
    processes with ``all_reduce``. Segmentation statistics are accumulated by
    ``SegmentationAverageMeter``. Only the main process logs.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: segmentation network. Its output is bilinearly resized to the
            target's spatial size when ``args.zoom_factor != 8``.
        criterion: loss function, called as ``criterion(output, target)``.

    Returns:
        Tuple ``(avg_loss, mIoU, mAcc, allAcc)``.
    """
    if main_process():
        logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    sam = SegmentationAverageMeter()

    model.eval()
    end = time.time()
    # Evaluation needs no gradients; disabling autograd avoids building (and
    # holding memory for) a backward graph on every batch.
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            data_time.update(time.time() - end)
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            output = model(input)
            if args.zoom_factor != 8:
                output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True)
            loss = criterion(output, target)

            n = input.size(0)
            if args.multiprocessing_distributed:
                # Global mean = sum(local loss * local n) / sum(local n).
                loss = loss * n  # not considering ignore pixels
                count = target.new_tensor([n], dtype=torch.long)
                dist.all_reduce(loss), dist.all_reduce(count)
                n = count.item()
                loss = loss / n
            else:
                loss = torch.mean(loss)

            # Argmax over dim 1 (presumably the class dimension).
            output = output.max(1)[1]
            sam.update_metrics_gpu(output, target, args.classes, args.ignore_label, args.multiprocessing_distributed)
            # Weight by the number of samples this (possibly all-reduced) loss
            # is averaged over, matching train().
            loss_meter.update(loss.item(), n)
            batch_time.update(time.time() - end)
            end = time.time()
            if ((i + 1) % args.print_freq == 0) and main_process():
                logger.info('Test: [{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '
                            'Accuracy {accuracy:.4f}.'.format(i + 1, len(val_loader),
                                                              data_time=data_time,
                                                              batch_time=batch_time,
                                                              loss_meter=loss_meter,
                                                              accuracy=sam.accuracy))

    iou_class, accuracy_class, mIoU, mAcc, allAcc = sam.get_metrics()
    if main_process():
        logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
        for i in range(args.classes):
            logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))
        logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return loss_meter.avg, mIoU, mAcc, allAcc
コード例 #5
0
ファイル: train.py プロジェクト: xingyizhou/mseg-semantic
def _apply_learning_rate(optimizer, current_lr: float, args) -> None:
    """Write ``current_lr`` into ``optimizer``'s param groups according to ``args.arch``."""
    if args.arch == 'psp':
        # Groups [0, index_split) get the base rate; the remaining groups get
        # 10x the base rate unless fine-tuning.
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        later_groups_lr = current_lr if args.finetune else current_lr * 10
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = later_groups_lr
    elif args.arch in ('hrnet', 'hrnet_ocr'):
        optimizer.param_groups[0]['lr'] = current_lr


def _format_hms(seconds: float) -> str:
    """Format a duration given in seconds as ``HH:MM:SS``."""
    t_m, t_s = divmod(seconds, 60)
    t_h, t_m = divmod(t_m, 60)
    return '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))


def train(train_loader, model, optimizer, epoch: int):
    """Train ``model`` for one epoch over ``train_loader``.

    Profiling notes:
    No MGDA -- whole iteration takes 0.31 sec.
    0.24 sec to run typical backward pass (with no MGDA)

    With MGDA -- whole iteration takes 1.10 sec.
    1.05 sec to run backward pass w/ MGDA subroutine -- scale_loss_and_gradients() in every iteration.

    TODO: Profile which part of Frank-Wolfe is slow

    The forward/backward pass is delegated to ``forward_backward_mgda`` (when
    ``args.use_mgda``) or ``forward_backward_full_sync``, followed by
    ``optimizer.step()``. The learning rate follows ``poly_learning_rate``
    and is applied after each step.

    Returns:
        Tuple ``(main_loss_avg, mIoU, mAcc, allAcc)`` for the epoch.
    """
    import time

    import torch
    import torch.distributed as dist

    from mseg_semantic.multiobjective_opt.dist_mgda_utils import scale_loss_and_gradients
    from mseg_semantic.utils.avg_meter import AverageMeter, SegmentationAverageMeter
    from mseg_semantic.utils.training_utils import poly_learning_rate

    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    sam = SegmentationAverageMeter()

    model.train()
    end = time.time()
    max_iter = args.max_iters
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode='bilinear', align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        if args.use_mgda:
            output, loss, main_loss, aux_loss, scales = forward_backward_mgda(input, target, model, optimizer, args)
        else:
            output, loss, main_loss, aux_loss = forward_backward_full_sync(input, target, model, optimizer, args)
        optimizer.step()

        n = input.size(0)
        if args.multiprocessing_distributed:
            # Batch-size-weighted mean of each loss over all processes (not
            # considering ignore pixels). The backward pass has already run,
            # so detach to keep these reductions out of the autograd graph.
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss.detach() * n, loss.detach() * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        sam.update_metrics_gpu(output, target, args.classes, args.ignore_label, args.multiprocessing_distributed)

        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        # Iteration 0 is left out of the batch-time average (presumably to
        # exclude one-off warm-up cost).
        if i > 0:
            batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1 + args.resume_iter
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter, power=args.power)
        _apply_learning_rate(optimizer, current_lr, args)

        remain_time = _format_hms((max_iter - current_iter) * batch_time.avg)

        if current_iter % args.print_freq == 0:
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'LR {current_lr:.8f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          current_lr=current_lr,
                                                          loss_meter=loss_meter,
                                                          accuracy=sam.accuracy) + f' current_iter: {current_iter}' + f' rank: {args.rank} ')
            if args.use_mgda and main_process():
                # Scales identical in each process, so print out only in main process.
                scales_str = ' , '.join(f'{d}: {scale:.2f}' for d, scale in scales.items())
                logger.info(f'Scales: {scales_str}')

        # Early exit to prevent iter number not matching between gpus.
        # NOTE(review): only the main process breaks here; other ranks keep
        # iterating and may block in their next collective op -- confirm this
        # is intended.
        if main_process() and current_iter == max_iter - 5:
            break

    iou_class, accuracy_class, mIoU, mAcc, allAcc = sam.get_metrics()
    logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch+1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc