Example #1
    def __init__(self, backbone, in_channels, feature_key, low_level_channels,
                 low_level_key, low_level_channels_project, decoder_channels,
                 atrous_rates, num_classes, semantic_loss,
                 semantic_loss_weight, center_loss, center_loss_weight,
                 offset_loss, offset_loss_weight, **kwargs):
        decoder = PanopticDeepLabDecoder(in_channels, feature_key,
                                         low_level_channels, low_level_key,
                                         low_level_channels_project,
                                         decoder_channels, atrous_rates,
                                         num_classes, **kwargs)
        super(PanopticDeepLab, self).__init__(backbone, decoder)

        self.semantic_loss = semantic_loss
        self.semantic_loss_weight = semantic_loss_weight
        self.loss_meter_dict = OrderedDict()
        self.loss_meter_dict['Loss'] = AverageMeter()
        self.loss_meter_dict['Semantic loss'] = AverageMeter()

        if kwargs.get('has_instance', False):
            self.center_loss = center_loss
            self.center_loss_weight = center_loss_weight
            self.offset_loss = offset_loss
            self.offset_loss_weight = offset_loss_weight
            self.loss_meter_dict['Center loss'] = AverageMeter()
            self.loss_meter_dict['Offset loss'] = AverageMeter()
        else:
            self.center_loss = None
            self.center_loss_weight = 0
            self.offset_loss = None
            self.offset_loss_weight = 0

        # Initialize parameters.
        self._init_params()
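
The loss meters above rely on an `AverageMeter` utility that is not shown in these examples. Below is a minimal sketch consistent with how it is used here (`.update(val, n)`, `.val`, `.avg`, `.reset()`); the project's own implementation may differ.

class AverageMeter(object):
    """Tracks the latest value and a running (weighted) average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value passed to update()
        self.sum = 0.0    # weighted sum of all values seen so far
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count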
Example #2
def train(train_loader, net, criterion, optimizer, epoch, train_args):
    net.train()
    train_loss = AverageMeter()
    curr_iter = (epoch - 1) * len(train_loader)
    for i, data in enumerate(train_loader):
        inputs, labels = data
        print("#################")
        print(inputs.size(), labels.size())
        print(labels.numpy())
        assert inputs.size()[2:] == labels.size()[1:]
        N = inputs.size(0)
        h, w = labels.size()[1:]
        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        print(outputs.size())
        assert outputs.size()[2:] == labels.size()[1:]
        assert outputs.size()[1] == num_classes

        loss = criterion(outputs, labels) / N / h / w
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), N)

        curr_iter += 1
        #writer.add_scalar('train_loss', train_loss.avg, curr_iter)

        if (i + 1) % train_args['print_freq'] == 0:
            print('[epoch %d], [iter %d / %d], [train loss %.5f]' %
                  (epoch, i + 1, len(train_loader), train_loss.avg))
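
A hedged usage sketch for the function above: dividing by `N`, `h`, and `w` only yields a per-pixel loss if `criterion` sums over all pixels, so the caller presumably constructs it with `reduction='sum'`. The `ignore_index`, the hyperparameters, and the `'epoch_num'` key below are assumptions for illustration, not values taken from the original script.

import torch.nn as nn
import torch.optim as optim

# Assumed setup; `net`, `train_loader`, and `train_args` come from the surrounding script.
criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=255)
optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)

for epoch in range(1, train_args['epoch_num'] + 1):
    train(train_loader, net, criterion, optimizer, epoch, train_args)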
Example #3
    def __init__(self, backbone, in_channels, feature_key, decoder_channels,
                 atrous_rates, num_classes, semantic_loss,
                 semantic_loss_weight, **kwargs):
        decoder = DeepLabV3Decoder(in_channels, feature_key, decoder_channels,
                                   atrous_rates, num_classes)
        super(DeepLabV3, self).__init__(backbone, decoder)

        self.semantic_loss = semantic_loss
        self.semantic_loss_weight = semantic_loss_weight

        self.loss_meter_dict = OrderedDict()
        self.loss_meter_dict['Loss'] = AverageMeter()

        # Initialize parameters.
        self._init_params()
Example #4
def main():
    args = parse_args()

    logger = logging.getLogger('segmentation')
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called
        setup_logger(output=config.OUTPUT_DIR, distributed_rank=args.local_rank)

    logger.info(pprint.pformat(args))
    logger.info(config)

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )

    # build model
    model = build_segmentation_model_from_cfg(config)
    logger.info("Model:\n{}".format(model))

    logger.info("Rank of current process: {}. World size: {}".format(comm.get_rank(), comm.get_world_size()))

    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)

    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )

    data_loader = build_train_loader_from_cfg(config)
    optimizer = build_optimizer(config, model)
    lr_scheduler = build_lr_scheduler(config, optimizer)

    data_loader_iter = iter(data_loader)

    start_iter = 0
    max_iter = config.TRAIN.MAX_ITER
    best_param_group_id = get_lr_group_id(optimizer)

    # initialize model
    if os.path.isfile(config.MODEL.WEIGHTS):
        model_weights = torch.load(config.MODEL.WEIGHTS)
        get_module(model, distributed).load_state_dict(model_weights, strict=False)
        logger.info('Pre-trained model from {}'.format(config.MODEL.WEIGHTS))
    elif not config.MODEL.BACKBONE.PRETRAINED:
        if os.path.isfile(config.MODEL.BACKBONE.WEIGHTS):
            pretrained_weights = torch.load(config.MODEL.BACKBONE.WEIGHTS)
            get_module(model, distributed).backbone.load_state_dict(pretrained_weights, strict=False)
            logger.info('Pre-trained backbone from {}'.format(config.MODEL.BACKBONE.WEIGHTS))
        else:
            logger.info('No pre-trained weights for backbone, training from scratch.')

    # load model
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(config.OUTPUT_DIR, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            start_iter = checkpoint['start_iter']
            get_module(model, distributed).load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            logger.info('Loaded checkpoint (starting from iter {})'.format(checkpoint['start_iter']))

    data_time = AverageMeter()
    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    # Debug output.
    if config.DEBUG.DEBUG:
        debug_out_dir = os.path.join(config.OUTPUT_DIR, 'debug_train')
        PathManager.mkdirs(debug_out_dir)

    # Train loop.
    try:
        for i in range(start_iter, max_iter):
            # data
            start_time = time.time()
            data = next(data_loader_iter)
            if not distributed:
                data = to_cuda(data, device)
            _data_time = time.time()
            data_time.update(_data_time - start_time)

            image = data.pop('image')
            out_dict = model(image, data)

            loss = out_dict['loss']

            torch.cuda.synchronize(device)
            _forward_time = time.time()
            if args.gpumem:
                gpumem = torch.cuda.memory_allocated(device)
                peak_usage = torch.cuda.max_memory_allocated(device)
                torch.cuda.reset_peak_memory_stats(device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Get lr.
            lr = optimizer.param_groups[best_param_group_id]["lr"]
            lr_scheduler.step()

            _batch_time = time.time()
            batch_time.update(_batch_time - start_time)
            loss_meter.update(loss.detach().cpu().item(), image.size(0))

            if args.timing:
                logger.info('timing - forward %f' % (_forward_time - _data_time))
                logger.info('timing - both %f' % (_batch_time - _data_time))
            if args.gpumem:
                logger.info('gpumem - %f' % gpumem)
                logger.info('gpumem - peak %f' % peak_usage)


            if i == 0 or (i + 1) % config.PRINT_FREQ == 0:
                msg = '[{0}/{1}] LR: {2:.7f}\t' \
                      'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                      'Data: {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'.format(
                        i + 1, max_iter, lr, batch_time=batch_time, data_time=data_time)
                msg += get_loss_info_str(get_module(model, distributed).loss_meter_dict)
                logger.info(msg)
            if i == 0 or (i + 1) % config.DEBUG.DEBUG_FREQ == 0:
                # TODO: Add interface for save_debug_images
                # if comm.is_main_process() and config.DEBUG.DEBUG:
                #     save_debug_images(
                #         dataset=data_loader.dataset,
                #         batch_images=image,
                #         batch_targets=data,
                #         batch_outputs=out_dict,
                #         out_dir=debug_out_dir,
                #         iteration=i,
                #         target_keys=config.DEBUG.TARGET_KEYS,
                #         output_keys=config.DEBUG.OUTPUT_KEYS,
                #         iteration_to_remove=i - config.DEBUG.KEEP_INTERVAL
                #     )
                if i > 0 and (args.gpumem or args.timing):
                    break
            if i == 0 or (i + 1) % config.CKPT_FREQ == 0:
                if comm.is_main_process():
                    torch.save({
                        'start_iter': i + 1,
                        'state_dict': get_module(model, distributed).state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                    }, os.path.join(config.OUTPUT_DIR, 'checkpoint.pth.tar'))
    except Exception:
        logger.exception("Exception during training:")
        raise
    finally:
        if comm.is_main_process():
            torch.save(get_module(model, distributed).state_dict(),
                       os.path.join(config.OUTPUT_DIR, 'final_state.pth'))
        logger.info("Training finished.")
Example #5
def main():
    args = parse_args()

    logger = logging.getLogger('demo')
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called
        setup_logger(output=args.output_dir, name='demo')

    logger.info(pprint.pformat(args))
    logger.info(config)

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.TEST.GPUS)
    if len(gpus) > 1:
        raise ValueError('Test only supports single core.')
    device = torch.device('cuda:{}'.format(gpus[0]))

    # build model
    model = build_segmentation_model_from_cfg(config)

    # Change ASPP image pooling
    # output_stride = 2 ** (5 - sum(config.MODEL.BACKBONE.DILATION))
    # train_crop_h, train_crop_w = config.TEST.CROP_SIZE
    # scale = 1. / output_stride
    # pool_h = int((float(train_crop_h) - 1.0) * scale + 1.0)
    # pool_w = int((float(train_crop_w) - 1.0) * scale + 1.0)

    # model.set_image_pooling((pool_h, pool_w))

    logger.info("Model:\n{}".format(model))
    model = model.to(device)

    try:
        # build data_loader
        data_loader = build_test_loader_from_cfg(config)
        meta_dataset = data_loader.dataset
        save_intermediate_outputs = True
    except Exception:
        logger.warning(
            "Cannot build data loader, using default meta data. This will disable visualizing intermediate outputs."
        )
        if 'cityscapes' in config.DATASET.DATASET:
            meta_dataset = CityscapesMeta()
        else:
            raise ValueError("Unsupported dataset: {}".format(
                config.DATASET.DATASET))
        save_intermediate_outputs = False

    # load model
    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(config.OUTPUT_DIR, 'final_state.pth')

    if os.path.isfile(model_state_file):
        model_weights = torch.load(model_state_file)
        if 'state_dict' in model_weights.keys():
            model_weights = model_weights['state_dict']
            logger.info('Evaluating an intermediate checkpoint.')
        model.load_state_dict(model_weights, strict=False)
        logger.info('Test model loaded from {}'.format(model_state_file))
    else:
        if not config.DEBUG.DEBUG:
            raise ValueError('Cannot find test model.')

    # load images
    input_list = []
    if os.path.exists(args.input_files):
        if os.path.isfile(args.input_files):
            # inference on a single file, extract extension
            ext = os.path.splitext(os.path.basename(args.input_files))[1]
            if ext in ['.png', '.jpg', '.jpeg']:
                # image file
                input_list.append(args.input_files)
            elif ext in ['.mpeg']:
                # video file
                # TODO: decode video and convert to image list
                raise NotImplementedError(
                    "Inference on video is not supported yet.")
            else:
                raise ValueError("Unsupported extension: {}.".format(ext))
        else:
            # inference on a directory
            for fname in glob.glob(
                    os.path.join(args.input_files, '*' + args.extension)):
                input_list.append(fname)
    else:
        raise ValueError('Input file or directory does not exist: {}'.format(
            args.input_files))

    if isinstance(input_list[0], str):
        logger.info("Inference on images")
        logger.info(input_list)
    else:
        logger.info("Inference on video")

    # dir to save intermediate raw outputs
    raw_out_dir = os.path.join(args.output_dir, 'raw')
    PathManager.mkdirs(raw_out_dir)

    # dir to save semantic outputs
    semantic_out_dir = os.path.join(args.output_dir, 'semantic')
    PathManager.mkdirs(semantic_out_dir)

    # dir to save instance outputs
    instance_out_dir = os.path.join(args.output_dir, 'instance')
    PathManager.mkdirs(instance_out_dir)

    # dir to save panoptic outputs
    panoptic_out_dir = os.path.join(args.output_dir, 'panoptic')
    PathManager.mkdirs(panoptic_out_dir)

    # Test loop
    model.eval()

    # build image demo transform
    transforms = T.Compose(
        [T.ToTensor(),
         T.Normalize(config.DATASET.MEAN, config.DATASET.STD)])

    net_time = AverageMeter()
    post_time = AverageMeter()
    try:
        with torch.no_grad():
            for i, fname in enumerate(input_list):
                if isinstance(fname, str):
                    # load image
                    raw_image = read_image(fname, 'RGB')
                else:
                    raise NotImplementedError(
                        "Inference on video is not supported yet.")

                # pad image
                raw_shape = raw_image.shape[:2]
                raw_h = raw_shape[0]
                raw_w = raw_shape[1]
                new_h = (raw_h + 31) // 32 * 32 + 1
                new_w = (raw_w + 31) // 32 * 32 + 1
                input_image = np.zeros((new_h, new_w, 3), dtype=np.uint8)
                input_image[:, :] = config.DATASET.MEAN
                input_image[:raw_h, :raw_w, :] = raw_image

                image, _ = transforms(input_image, None)
                image = image.unsqueeze(0).to(device)

                # network
                start_time = time.time()
                out_dict = model(image)
                torch.cuda.synchronize(device)
                net_time.update(time.time() - start_time)

                # post-processing
                start_time = time.time()
                semantic_pred = get_semantic_segmentation(out_dict['semantic'])

                panoptic_pred, center_pred = get_panoptic_segmentation(
                    semantic_pred,
                    out_dict['center'],
                    out_dict['offset'],
                    thing_list=meta_dataset.thing_list,
                    label_divisor=meta_dataset.label_divisor,
                    stuff_area=config.POST_PROCESSING.STUFF_AREA,
                    void_label=(meta_dataset.label_divisor *
                                meta_dataset.ignore_label),
                    threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
                    nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
                    top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
                    foreground_mask=None)
                torch.cuda.synchronize(device)
                post_time.update(time.time() - start_time)

                logger.info(
                    '[{}/{}]\t'
                    'Network Time: {net_time.val:.3f}s ({net_time.avg:.3f}s)\t'
                    'Post-processing Time: {post_time.val:.3f}s ({post_time.avg:.3f}s)\t'
                    .format(i,
                            len(input_list),
                            net_time=net_time,
                            post_time=post_time))

                # save predictions
                semantic_pred = semantic_pred.squeeze(0).cpu().numpy()
                panoptic_pred = panoptic_pred.squeeze(0).cpu().numpy()

                # crop predictions
                semantic_pred = semantic_pred[:raw_h, :raw_w]
                panoptic_pred = panoptic_pred[:raw_h, :raw_w]

                if save_intermediate_outputs:
                    # Raw outputs
                    save_debug_images(
                        dataset=meta_dataset,
                        batch_images=image,
                        batch_targets={},
                        batch_outputs=out_dict,
                        out_dir=raw_out_dir,
                        iteration=i,
                        target_keys=[],
                        output_keys=['semantic', 'center', 'offset'],
                        is_train=False,
                    )

                save_annotation(semantic_pred,
                                semantic_out_dir,
                                'semantic_pred_%d' % i,
                                add_colormap=True,
                                colormap=meta_dataset.create_label_colormap(),
                                image=raw_image if args.merge_image else None)
                pan_to_sem = panoptic_pred // meta_dataset.label_divisor
                save_annotation(pan_to_sem,
                                semantic_out_dir,
                                'panoptic_to_semantic_pred_%d' % i,
                                add_colormap=True,
                                colormap=meta_dataset.create_label_colormap(),
                                image=raw_image if args.merge_image else None)
                ins_id = panoptic_pred % meta_dataset.label_divisor
                pan_to_ins = panoptic_pred.copy()
                pan_to_ins[ins_id == 0] = 0
                save_instance_annotation(
                    pan_to_ins,
                    instance_out_dir,
                    'panoptic_to_instance_pred_%d' % i,
                    image=raw_image if args.merge_image else None)
                save_panoptic_annotation(
                    panoptic_pred,
                    panoptic_out_dir,
                    'panoptic_pred_%d' % i,
                    label_divisor=meta_dataset.label_divisor,
                    colormap=meta_dataset.create_label_colormap(),
                    image=raw_image if args.merge_image else None)
    except Exception:
        logger.exception("Exception during demo:")
        raise
    finally:
        logger.info("Demo finished.")
        if save_intermediate_outputs:
            logger.info("Intermediate outputs saved to {}".format(raw_out_dir))
        logger.info(
            "Semantic predictions saved to {}".format(semantic_out_dir))
        logger.info(
            "Instance predictions saved to {}".format(instance_out_dir))
        logger.info(
            "Panoptic predictions saved to {}".format(panoptic_out_dir))
Example #6
def main():
    # Command-line arguments and logger (assumed to mirror the other demo scripts,
    # since `args` and `logger` are referenced later in this function).
    args = parse_args()
    logger = logging.getLogger('demo')
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called
        setup_logger(output=args.output_dir, name='demo')

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.TEST.GPUS)
    if len(gpus) > 1:
        raise ValueError('Test only supports single core.')
    device = torch.device('cuda:{}'.format(gpus[0]))
    # build model
    model = build_segmentation_model_from_cfg(config)
    logger.info("Model:\n{}".format(model))
    model = model.to(device)

    # build data_loader
    # TODO: still need it for thing_list
    data_loader = build_test_loader_from_cfg(config)

    # load model
    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(config.OUTPUT_DIR, 'final_state.pth')

    if os.path.isfile(model_state_file):
        model_weights = torch.load(model_state_file)
        if 'state_dict' in model_weights.keys():
            model_weights = model_weights['state_dict']
            logger.info('Evaluating an intermediate checkpoint.')
        model.load_state_dict(model_weights, strict=True)
        logger.info('Test model loaded from {}'.format(model_state_file))
    else:
        if not config.DEBUG.DEBUG:
            raise ValueError('Cannot find test model.')

    # dir to save intermediate raw outputs
    raw_out_dir = os.path.join(args.output_dir, 'raw')
    PathManager.mkdirs(raw_out_dir)

    # dir to save semantic outputs
    semantic_out_dir = os.path.join(args.output_dir, 'semantic')
    PathManager.mkdirs(semantic_out_dir)

    # dir to save instance outputs
    instance_out_dir = os.path.join(args.output_dir, 'instance')
    PathManager.mkdirs(instance_out_dir)

    # dir to save panoptic outputs
    panoptic_out_dir = os.path.join(args.output_dir, 'panoptic')
    PathManager.mkdirs(panoptic_out_dir)

    # Test loop
    model.eval()

    # build image demo transform

    net_time = AverageMeter()
    post_time = AverageMeter()

    # dataset
    source = "/home/muyun99/Desktop/MyGithub/cnsoftbei-video/inference/input/video-clip_2-4.mp4"
    imgsz = 1024
    dataset = LoadImages(source, img_size=imgsz)

    try:
        with torch.no_grad():
            for i, (path, image, im0s, vid_cap) in enumerate(dataset):
                (_, raw_h, raw_w) = image.shape
                image = torch.from_numpy(image).to(device)
                image = image.float()
                image /= 255.0  # 0 - 255 to 0.0 - 1.0
                if image.ndimension() == 3:
                    image = image.unsqueeze(0)

                # network
                start_time = time.time()
                out_dict = model(image)
                torch.cuda.synchronize(device)
                net_time.update(time.time() - start_time)

                # post-processing
                start_time = time.time()
                semantic_pred = get_semantic_segmentation(out_dict['semantic'])
                panoptic_pred, center_pred = get_panoptic_segmentation(
                    semantic_pred,
                    out_dict['center'],
                    out_dict['offset'],
                    thing_list=data_loader.dataset.thing_list,
                    label_divisor=data_loader.dataset.label_divisor,
                    stuff_area=config.POST_PROCESSING.STUFF_AREA,
                    void_label=(data_loader.dataset.label_divisor *
                                data_loader.dataset.ignore_label),
                    threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
                    nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
                    top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
                    foreground_mask=None)
                torch.cuda.synchronize(device)
                post_time.update(time.time() - start_time)

                logger.info(
                    'Network Time: {net_time.val:.3f}s ({net_time.avg:.3f}s)\t'
                    'Post-processing Time: {post_time.val:.3f}s ({post_time.avg:.3f}s)\t'
                    .format(net_time=net_time, post_time=post_time))

                # save predictions
                semantic_pred = semantic_pred.squeeze(0).cpu().numpy()
                panoptic_pred = panoptic_pred.squeeze(0).cpu().numpy()

                # crop predictions
                semantic_pred = semantic_pred[:raw_h, :raw_w]
                panoptic_pred = panoptic_pred[:raw_h, :raw_w]

                # Raw outputs
                save_debug_images(
                    dataset=data_loader.dataset,
                    batch_images=image,
                    batch_targets={},
                    batch_outputs=out_dict,
                    out_dir=raw_out_dir,
                    iteration=i,
                    target_keys=[],
                    output_keys=['semantic', 'center', 'offset'],
                    is_train=False,
                )

                save_annotation(
                    semantic_pred,
                    semantic_out_dir,
                    'semantic_pred_%d' % i,
                    add_colormap=True,
                    colormap=data_loader.dataset.create_label_colormap())
                pan_to_sem = panoptic_pred // data_loader.dataset.label_divisor
                save_annotation(
                    pan_to_sem,
                    semantic_out_dir,
                    'pan_to_sem_pred_%d' % i,
                    add_colormap=True,
                    colormap=data_loader.dataset.create_label_colormap())
                ins_id = panoptic_pred % data_loader.dataset.label_divisor
                pan_to_ins = panoptic_pred.copy()
                pan_to_ins[ins_id == 0] = 0
                save_instance_annotation(pan_to_ins, instance_out_dir,
                                         'pan_to_ins_pred_%d' % i)
                save_panoptic_annotation(
                    panoptic_pred,
                    panoptic_out_dir,
                    'panoptic_pred_%d' % i,
                    label_divisor=data_loader.dataset.label_divisor,
                    colormap=data_loader.dataset.create_label_colormap())
    except Exception:
        logger.exception("Exception during demo:")
        raise
    finally:
        logger.info("Demo finished.")

    vid_path = args.output_dir
    pic2video(filelist=glob.glob(os.path.join(semantic_out_dir, "pan_to_sem_pred_*")),
              vid_path=os.path.join(vid_path, "semantic.mp4"),
              size=(1024, 576))
    pic2video(filelist=glob.glob(os.path.join(instance_out_dir, "pan_to_ins_pred_*")),
              vid_path=os.path.join(vid_path, "instance.mp4"),
              size=(1024, 576))
    pic2video(filelist=glob.glob(os.path.join(panoptic_out_dir, "panoptic_pred_*")),
              vid_path=os.path.join(vid_path, "panoptic.mp4"),
              size=(1024, 576))
    logger.info("Video saved")
Example #7
def main():
    args = parse_args()

    logger = logging.getLogger('segmentation')
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called
        setup_logger(output=config.OUTPUT_DIR, distributed_rank=args.local_rank)

    # logger.info(pprint.pformat(args))
    # logger.info(config)

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )

    # build model
    model = build_segmentation_model_from_cfg(config)
    # logger.info("Model:\n{}".format(model))

    logger.info("Rank of current process: {}. World size: {}".format(comm.get_rank(), comm.get_world_size()))

    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )

    data_loader = build_train_loader_from_cfg(config)
    optimizer = build_optimizer(config, model)
    lr_scheduler = build_lr_scheduler(config, optimizer)

    data_loader_iter = iter(data_loader)

    start_iter = 0
    max_iter = config.TRAIN.MAX_ITER
    best_param_group_id = get_lr_group_id(optimizer)

    # initialize model
    if os.path.isfile(config.MODEL.WEIGHTS):
        model_weights = torch.load(config.MODEL.WEIGHTS)
        get_module(model, distributed).load_state_dict(model_weights, strict=False)
        logger.info('Pre-trained model from {}'.format(config.MODEL.WEIGHTS))
    elif config.MODEL.BACKBONE.PRETRAINED:
        if os.path.isfile(config.MODEL.BACKBONE.WEIGHTS):
            pretrained_weights = torch.load(config.MODEL.BACKBONE.WEIGHTS)
            get_module(model, distributed).backbone.load_state_dict(pretrained_weights, strict=False)
            logger.info('Pre-trained backbone from {}'.format(config.MODEL.BACKBONE.WEIGHTS))
        else:
            logger.info('No pre-trained weights for backbone, training from scratch.')

    # load model
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(config.OUTPUT_DIR, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            start_iter = checkpoint['start_iter']
            get_module(model, distributed).load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            logger.info('Loaded checkpoint (starting from iter {})'.format(checkpoint['start_iter']))

    data_time = AverageMeter()
    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    # Log the number of model parameters (in millions).
    def get_parameter_number(net):
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        # return {'Total': total_num/1000000, 'Trainable': trainable_num/1000000}
        logger.info('Total:{}M, Trainable:{}M'.format(total_num/1000000, trainable_num/1000000))
    get_parameter_number(model)

    # Debug output.
    if config.DEBUG.DEBUG:
        debug_out_dir = os.path.join(config.OUTPUT_DIR, 'debug_train')
        PathManager.mkdirs(debug_out_dir)

    # Train loop.
    try:
        for i in range(start_iter, max_iter):
            # data
            start_time = time.time()
            data = next(data_loader_iter)
            if not distributed:
                data = to_cuda(data, device)
            data_time.update(time.time() - start_time)
            # Unpack the mini-batch images and labels.
            image = data.pop('image')
            label = data.pop('label')
            # import imageio
            # import numpy as np
            # print(label.shape)
            # label_image = np.array(label.cpu()[0])
            # print(label_image.shape)
            # imageio.imwrite('%s/%d_%s.png' % ('./', 1, 'debug_batch_label'), label_image.transpose(1, 2, 0))
            # Forward pass.
            out_dict = model(image, data)
            # The loss (cost function) computed by the model.
            loss = out_dict['loss']
            # Zero the gradients before the backward pass.
            optimizer.zero_grad()
            # Backward pass.
            loss.backward()
            # Update the trainable parameters.
            optimizer.step()
            # Get lr.
            lr = optimizer.param_groups[best_param_group_id]["lr"]
            lr_scheduler.step()

            batch_time.update(time.time() - start_time)
            loss_meter.update(loss.detach().cpu().item(), image.size(0))

            if i == 0 or (i + 1) % config.PRINT_FREQ == 0:
                msg = '[{0}/{1}] LR: {2:.7f}\t' \
                      'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                      'Data: {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'.format(
                        i + 1, max_iter, lr, batch_time=batch_time, data_time=data_time)
                msg += get_loss_info_str(get_module(model, distributed).loss_meter_dict)
                logger.info(msg)
            if i == 0 or (i + 1) % config.DEBUG.DEBUG_FREQ == 0:
                if comm.is_main_process() and config.DEBUG.DEBUG:
                    save_debug_images(
                        dataset=data_loader.dataset,
                        label=label,
                        batch_images=image,
                        batch_targets=data,
                        batch_outputs=out_dict,
                        out_dir=debug_out_dir,
                        iteration=i,
                        target_keys=config.DEBUG.TARGET_KEYS,
                        output_keys=config.DEBUG.OUTPUT_KEYS,
                        iteration_to_remove=i - config.DEBUG.KEEP_INTERVAL
                    )
            if i == 0 or (i + 1) % config.CKPT_FREQ == 0:
                if comm.is_main_process():
                    torch.save({
                        'start_iter': i + 1,
                        'state_dict': get_module(model, distributed).state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                    }, os.path.join(config.OUTPUT_DIR, 'checkpoint.pth.tar'))
    except Exception:
        logger.exception("Exception during training:")
        raise
    finally:
        if comm.is_main_process():
            torch.save(get_module(model, distributed).state_dict(),
                       os.path.join(config.OUTPUT_DIR, 'final_state.pth'))
        logger.info("Training finished.")
Example #8
def main():
    args = parse_args()

    logger = logging.getLogger('segmentation_test')
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called
        setup_logger(output=config.OUTPUT_DIR, name='segmentation_test')

    logger.info(pprint.pformat(args))
    logger.info(config)

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.TEST.GPUS)
    if len(gpus) > 1:
        raise ValueError('Test only supports single core.')
    device = torch.device('cuda:{}'.format(gpus[0]))

    # build model
    model = build_segmentation_model_from_cfg(config)

    # Change ASPP image pooling
    output_stride = 2 ** (5 - sum(config.MODEL.BACKBONE.DILATION))
    train_crop_h, train_crop_w = config.TEST.CROP_SIZE
    scale = 1. / output_stride
    pool_h = int((float(train_crop_h) - 1.0) * scale + 1.0)
    pool_w = int((float(train_crop_w) - 1.0) * scale + 1.0)

    model.set_image_pooling((pool_h, pool_w))

    logger.info("Model:\n{}".format(model))
    model = model.to(device)

    # build data_loader
    data_loader = build_test_loader_from_cfg(config)

    # load model
    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(config.OUTPUT_DIR, 'final_state.pth')

    if os.path.isfile(model_state_file):
        model_weights = torch.load(model_state_file)
        if 'state_dict' in model_weights.keys():
            model_weights = model_weights['state_dict']
            logger.info('Evaluating an intermediate checkpoint.')
        model.load_state_dict(model_weights, strict=True)
        logger.info('Test model loaded from {}'.format(model_state_file))
    else:
        if not config.DEBUG.DEBUG:
            raise ValueError('Cannot find test model.')

    data_time = AverageMeter()
    net_time = AverageMeter()
    post_time = AverageMeter()
    timing_warmup_iter = 10

    semantic_metric = SemanticEvaluator(
        num_classes=data_loader.dataset.num_classes,
        ignore_label=data_loader.dataset.ignore_label,
        output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.SEMANTIC_FOLDER),
        train_id_to_eval_id=data_loader.dataset.train_id_to_eval_id()
    )

    instance_metric = None
    panoptic_metric = None

    if config.TEST.EVAL_INSTANCE:
        if 'cityscapes' in config.DATASET.DATASET:
            instance_metric = CityscapesInstanceEvaluator(
                output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.INSTANCE_FOLDER),
                train_id_to_eval_id=data_loader.dataset.train_id_to_eval_id(),
                gt_dir=os.path.join(config.DATASET.ROOT, 'gtFine', config.DATASET.TEST_SPLIT)
            )
        elif 'coco' in config.DATASET.DATASET:
            instance_metric = COCOInstanceEvaluator(
                output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.INSTANCE_FOLDER),
                train_id_to_eval_id=data_loader.dataset.train_id_to_eval_id(),
                gt_dir=os.path.join(config.DATASET.ROOT, 'annotations',
                                    'instances_{}.json'.format(config.DATASET.TEST_SPLIT))
            )
        else:
            raise ValueError('Undefined evaluator for dataset {}'.format(config.DATASET.DATASET))

    if config.TEST.EVAL_PANOPTIC:
        if 'cityscapes' in config.DATASET.DATASET:
            panoptic_metric = CityscapesPanopticEvaluator(
                output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.PANOPTIC_FOLDER),
                train_id_to_eval_id=data_loader.dataset.train_id_to_eval_id(),
                label_divisor=data_loader.dataset.label_divisor,
                void_label=data_loader.dataset.label_divisor * data_loader.dataset.ignore_label,
                gt_dir=config.DATASET.ROOT,
                split=config.DATASET.TEST_SPLIT,
                num_classes=data_loader.dataset.num_classes
            )
        elif 'coco' in config.DATASET.DATASET:
            panoptic_metric = COCOPanopticEvaluator(
                output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.PANOPTIC_FOLDER),
                train_id_to_eval_id=data_loader.dataset.train_id_to_eval_id(),
                label_divisor=data_loader.dataset.label_divisor,
                void_label=data_loader.dataset.label_divisor * data_loader.dataset.ignore_label,
                gt_dir=config.DATASET.ROOT,
                split=config.DATASET.TEST_SPLIT,
                num_classes=data_loader.dataset.num_classes
            )
        else:
            raise ValueError('Undefined evaluator for dataset {}'.format(config.DATASET.DATASET))

    foreground_metric = None
    if config.TEST.EVAL_FOREGROUND:
        foreground_metric = SemanticEvaluator(
            num_classes=2,
            ignore_label=data_loader.dataset.ignore_label,
            output_dir=os.path.join(config.OUTPUT_DIR, config.TEST.FOREGROUND_FOLDER)
        )

    image_filename_list = [
        os.path.splitext(os.path.basename(ann))[0] for ann in data_loader.dataset.ann_list]

    # Debug output.
    if config.TEST.DEBUG:
        debug_out_dir = os.path.join(config.OUTPUT_DIR, 'debug_test')
        PathManager.mkdirs(debug_out_dir)

    # Test loop.
    try:
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                if i == timing_warmup_iter:
                    data_time.reset()
                    net_time.reset()
                    post_time.reset()

                # data
                start_time = time.time()
                for key in data.keys():
                    try:
                        data[key] = data[key].to(device)
                    except Exception:  # non-tensor entries stay on the CPU
                        pass

                image = data.pop('image')
                torch.cuda.synchronize(device)
                data_time.update(time.time() - start_time)

                start_time = time.time()
                # out_dict = model(image, data)
                out_dict = model(image)
                torch.cuda.synchronize(device)
                net_time.update(time.time() - start_time)

                start_time = time.time()
                semantic_pred = get_semantic_segmentation(out_dict['semantic'])
                if 'foreground' in out_dict:
                    foreground_pred = get_semantic_segmentation(out_dict['foreground'])
                else:
                    foreground_pred = None

                # Oracle experiment
                if config.TEST.ORACLE_SEMANTIC:
                    # Use predicted semantic for foreground
                    foreground_pred = torch.zeros_like(semantic_pred)
                    for thing_class in data_loader.dataset.thing_list:
                        foreground_pred[semantic_pred == thing_class] = 1
                    # Use gt semantic
                    semantic_pred = data['semantic']
                    # Set it to a stuff label
                    stuff_label = 0
                    while stuff_label in data_loader.dataset.thing_list:
                        stuff_label += 1
                    semantic_pred[semantic_pred == data_loader.dataset.ignore_label] = stuff_label
                if config.TEST.ORACLE_FOREGROUND:
                    foreground_pred = data['foreground']
                if config.TEST.ORACLE_CENTER:
                    out_dict['center'] = data['center']
                if config.TEST.ORACLE_OFFSET:
                    out_dict['offset'] = data['offset']

                if config.TEST.EVAL_INSTANCE or config.TEST.EVAL_PANOPTIC:
                    panoptic_pred, center_pred = get_panoptic_segmentation(
                        semantic_pred,
                        out_dict['center'],
                        out_dict['offset'],
                        thing_list=data_loader.dataset.thing_list,
                        label_divisor=data_loader.dataset.label_divisor,
                        stuff_area=config.POST_PROCESSING.STUFF_AREA,
                        void_label=(
                                data_loader.dataset.label_divisor *
                                data_loader.dataset.ignore_label),
                        threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
                        nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
                        top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
                        foreground_mask=foreground_pred)
                else:
                    panoptic_pred = None
                torch.cuda.synchronize(device)
                post_time.update(time.time() - start_time)
                logger.info('[{}/{}]\t'
                            'Data Time: {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                            'Network Time: {net_time.val:.3f}s ({net_time.avg:.3f}s)\t'
                            'Post-processing Time: {post_time.val:.3f}s ({post_time.avg:.3f}s)\t'.format(
                             i, len(data_loader), data_time=data_time, net_time=net_time, post_time=post_time))

                semantic_pred = semantic_pred.squeeze(0).cpu().numpy()
                if panoptic_pred is not None:
                    panoptic_pred = panoptic_pred.squeeze(0).cpu().numpy()
                if foreground_pred is not None:
                    foreground_pred = foreground_pred.squeeze(0).cpu().numpy()

                # Crop padded regions.
                image_size = data['size'].squeeze(0).cpu().numpy()
                semantic_pred = semantic_pred[:image_size[0], :image_size[1]]
                if panoptic_pred is not None:
                    panoptic_pred = panoptic_pred[:image_size[0], :image_size[1]]
                if foreground_pred is not None:
                    foreground_pred = foreground_pred[:image_size[0], :image_size[1]]

                # Resize back to the raw image size.
                raw_image_size = data['raw_size'].squeeze(0).cpu().numpy()
                if raw_image_size[0] != image_size[0] or raw_image_size[1] != image_size[1]:
                    semantic_pred = cv2.resize(semantic_pred.astype(np.float32), (raw_image_size[1], raw_image_size[0]),
                                               interpolation=cv2.INTER_NEAREST).astype(np.int32)
                    if panoptic_pred is not None:
                        panoptic_pred = cv2.resize(panoptic_pred.astype(np.float32),
                                                   (raw_image_size[1], raw_image_size[0]),
                                                   interpolation=cv2.INTER_NEAREST).astype(np.int32)
                    if foreground_pred is not None:
                        foreground_pred = cv2.resize(foreground_pred.astype(np.float32),
                                                     (raw_image_size[1], raw_image_size[0]),
                                                     interpolation=cv2.INTER_NEAREST).astype(np.int32)

                # Evaluates semantic segmentation.
                semantic_metric.update(semantic_pred,
                                       data['raw_label'].squeeze(0).cpu().numpy(),
                                       image_filename_list[i])

                # Optional: evaluates instance segmentation.
                if instance_metric is not None:
                    raw_semantic = F.softmax(out_dict['semantic'][:, :, :image_size[0], :image_size[1]], dim=1)
                    center_hmp = out_dict['center'][:, :, :image_size[0], :image_size[1]]
                    if raw_image_size[0] != image_size[0] or raw_image_size[1] != image_size[1]:
                        raw_semantic = F.interpolate(raw_semantic,
                                                     size=(raw_image_size[0], raw_image_size[1]),
                                                     mode='bilinear',
                                                     align_corners=False)  # Consistent with OpenCV.
                        center_hmp = F.interpolate(center_hmp,
                                                   size=(raw_image_size[0], raw_image_size[1]),
                                                   mode='bilinear',
                                                   align_corners=False)  # Consistent with OpenCV.

                    raw_semantic = raw_semantic.squeeze(0).cpu().numpy()
                    center_hmp = center_hmp.squeeze(1).squeeze(0).cpu().numpy()

                    instances = get_cityscapes_instance_format(panoptic_pred,
                                                               raw_semantic,
                                                               center_hmp,
                                                               label_divisor=data_loader.dataset.label_divisor,
                                                               score_type=config.TEST.INSTANCE_SCORE_TYPE)
                    instance_metric.update(instances, image_filename_list[i])

                # Optional: evaluates panoptic segmentation.
                if panoptic_metric is not None:
                    image_id = '_'.join(image_filename_list[i].split('_')[:3])
                    panoptic_metric.update(panoptic_pred,
                                           image_filename=image_filename_list[i],
                                           image_id=image_id)

                # Optional: evaluates foreground segmentation.
                if foreground_metric is not None:
                    semantic_label = data['raw_label'].squeeze(0).cpu().numpy()
                    foreground_label = np.zeros_like(semantic_label)
                    for sem_lab in np.unique(semantic_label):
                        # Both `stuff` and `ignore` are background.
                        if sem_lab in data_loader.dataset.thing_list:
                            foreground_label[semantic_label == sem_lab] = 1

                    # Use semantic segmentation as foreground segmentation.
                    if foreground_pred is None:
                        foreground_pred = np.zeros_like(semantic_pred)
                        for sem_lab in np.unique(semantic_pred):
                            if sem_lab in data_loader.dataset.thing_list:
                                foreground_pred[semantic_pred == sem_lab] = 1

                    foreground_metric.update(foreground_pred,
                                             foreground_label,
                                             image_filename_list[i])

                if config.TEST.DEBUG:
                    # Raw outputs
                    save_debug_images(
                        dataset=data_loader.dataset,
                        batch_images=image,
                        batch_targets=data,
                        batch_outputs=out_dict,
                        out_dir=debug_out_dir,
                        iteration=i,
                        target_keys=config.DEBUG.TARGET_KEYS,
                        output_keys=config.DEBUG.OUTPUT_KEYS,
                        is_train=False,
                    )
                    if panoptic_pred is not None:
                        # Processed outputs
                        save_annotation(semantic_pred, debug_out_dir, 'semantic_pred_%d' % i,
                                        add_colormap=True, colormap=data_loader.dataset.create_label_colormap())
                        pan_to_sem = panoptic_pred // data_loader.dataset.label_divisor
                        save_annotation(pan_to_sem, debug_out_dir, 'pan_to_sem_pred_%d' % i,
                                        add_colormap=True, colormap=data_loader.dataset.create_label_colormap())
                        ins_id = panoptic_pred % data_loader.dataset.label_divisor
                        pan_to_ins = panoptic_pred.copy()
                        pan_to_ins[ins_id == 0] = 0
                        save_instance_annotation(pan_to_ins, debug_out_dir, 'pan_to_ins_pred_%d' % i)

                        save_panoptic_annotation(panoptic_pred, debug_out_dir, 'panoptic_pred_%d' % i,
                                                 label_divisor=data_loader.dataset.label_divisor,
                                                 colormap=data_loader.dataset.create_label_colormap())
    except Exception:
        logger.exception("Exception during testing:")
        raise
    finally:
        logger.info("Inference finished.")
        semantic_results = semantic_metric.evaluate()
        logger.info(semantic_results)
        if instance_metric is not None:
            instance_results = instance_metric.evaluate()
            logger.info(instance_results)
        if panoptic_metric is not None:
            panoptic_results = panoptic_metric.evaluate()
            logger.info(panoptic_results)
        if foreground_metric is not None:
            foreground_results = foreground_metric.evaluate()
            logger.info(foreground_results)
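
A short arithmetic check of the ASPP image-pooling setup near the top of this example, using assumed values for the dilation flags and the test crop size (placeholders, not the project's config):

# A plain ResNet downsamples by 2**5 = 32; each dilated stage keeps resolution,
# so the output stride is 2**(5 - number of dilated stages).
dilation = (False, False, True)               # assumed config.MODEL.BACKBONE.DILATION
output_stride = 2 ** (5 - sum(dilation))      # -> 16

train_crop_h, train_crop_w = 1025, 2049       # assumed config.TEST.CROP_SIZE
scale = 1.0 / output_stride
pool_h = int((float(train_crop_h) - 1.0) * scale + 1.0)  # (1025 - 1) / 16 + 1 = 65
pool_w = int((float(train_crop_w) - 1.0) * scale + 1.0)  # (2049 - 1) / 16 + 1 = 129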
Example #9
    def main(self, frame, index, total):
        self.model.eval()

        # build image demo transform
        transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(config.DATASET.MEAN, config.DATASET.STD)
        ])

        net_time = AverageMeter()
        post_time = AverageMeter()
        try:
            with torch.no_grad():
                raw_image = frame
                # pad image
                raw_shape = raw_image.shape[:2]
                raw_h = raw_shape[0]
                raw_w = raw_shape[1]
                new_h = (raw_h + 31) // 32 * 32 + 1
                new_w = (raw_w + 31) // 32 * 32 + 1
                input_image = np.zeros((new_h, new_w, 3), dtype=np.uint8)
                input_image[:, :] = config.DATASET.MEAN
                input_image[:raw_h, :raw_w, :] = raw_image

                image, _ = transforms(input_image, None)
                image = image.unsqueeze(0).to(self.device)

                # network
                start_time = time.time()
                out_dict = self.model(image)
                torch.cuda.synchronize(self.device)
                net_time.update(time.time() - start_time)

                # post-processing
                start_time = time.time()
                semantic_pred = get_semantic_segmentation(out_dict['semantic'])

                panoptic_pred, center_pred = get_panoptic_segmentation(
                    semantic_pred,
                    out_dict['center'],
                    out_dict['offset'],
                    thing_list=self.meta_dataset.thing_list,
                    label_divisor=self.meta_dataset.label_divisor,
                    stuff_area=config.POST_PROCESSING.STUFF_AREA,
                    void_label=(self.meta_dataset.label_divisor *
                                self.meta_dataset.ignore_label),
                    threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
                    nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
                    top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
                    foreground_mask=None)
                torch.cuda.synchronize(self.device)
                post_time.update(time.time() - start_time)

                self.logger.info(
                    '[{}/{}]\t'
                    'Network Time: {net_time.val:.3f}s ({net_time.avg:.3f}s)\t'
                    'Post-processing Time: {post_time.val:.3f}s ({post_time.avg:.3f}s)\t'
                    .format(index,
                            total,
                            net_time=net_time,
                            post_time=post_time))

                # save predictions
                #semantic_pred = semantic_pred.squeeze(0).cpu().numpy()
                panoptic_pred = panoptic_pred.squeeze(0).cpu().numpy()

                # crop predictions
                #semantic_pred = semantic_pred[:raw_h, :raw_w]
                panoptic_pred = panoptic_pred[:raw_h, :raw_w]

                frame = creat_panoptic_annotation(
                    panoptic_pred,
                    label_divisor=self.meta_dataset.label_divisor,
                    colormap=self.meta_dataset.create_label_colormap(),
                    image=raw_image)
        except Exception:
            self.logger.exception("Exception during demo:")
            raise
        finally:
            self.logger.info("Demo finished.")
        return frame
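
The padding used above rounds each spatial dimension up to a multiple of 32 plus one pixel before running the network. A small arithmetic check with an assumed 1080x1920 frame:

raw_h, raw_w = 1080, 1920              # assumed frame size
new_h = (raw_h + 31) // 32 * 32 + 1    # ceil(1080 / 32) * 32 + 1 = 34 * 32 + 1 = 1089
new_w = (raw_w + 31) // 32 * 32 + 1    # 1920 is already a multiple of 32       -> 1921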
Example #10
def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore,
             visualize):
    net.eval()

    val_loss = AverageMeter()
    inputs_all, gts_all, predictions_all = [], [], []

    for vi, data in enumerate(val_loader):
        inputs, gts = data
        N = inputs.size(0)
        inputs = inputs.cuda()
        gts = gts.cuda()

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = net(inputs)
            val_loss.update(criterion(outputs, gts).item() / N, N)
        predictions = outputs.max(1)[1].squeeze_(1).squeeze_(
            0).cpu().numpy()

        if random.random() > train_args['val_img_sample_rate']:
            inputs_all.append(None)
        else:
            inputs_all.append(inputs.data.squeeze_(0).cpu())
        gts_all.append(gts.data.squeeze_(0).cpu().numpy())
        predictions_all.append(predictions)

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all,
                                              num_classes)

    if mean_iu > train_args['best_record']['mean_iu']:
        train_args['best_record']['val_loss'] = val_loss.avg
        train_args['best_record']['epoch'] = epoch
        train_args['best_record']['acc'] = acc
        train_args['best_record']['acc_cls'] = acc_cls
        train_args['best_record']['mean_iu'] = mean_iu
        train_args['best_record']['fwavacc'] = fwavacc
        snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
            epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc,
            optimizer.param_groups[1]['lr'])
        torch.save(net.state_dict(),
                   os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(
            optimizer.state_dict(),
            os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        if train_args['val_save_to_img_file']:
            to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
            check_mkdir(to_save_dir)

        val_visual = []
        for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)):
            if data[0] is None:
                continue
            input_pil = restore(data[0])
            gt_pil = colorize_mask(data[1])
            predictions_pil = colorize_mask(data[2])
            if train_args['val_save_to_img_file']:
                input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx))
                predictions_pil.save(
                    os.path.join(to_save_dir, '%d_prediction.png' % idx))
                gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
            val_visual.extend([
                visualize(input_pil.convert('RGB')),
                visualize(gt_pil.convert('RGB')),
                visualize(predictions_pil.convert('RGB'))
            ])
        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)
        #writer.add_image(snapshot_name, val_visual)

    print(
        '--------------------------------------------------------------------')
    print(
        '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]'
        % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print(
        'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]'
        % (train_args['best_record']['val_loss'],
           train_args['best_record']['acc'],
           train_args['best_record']['acc_cls'],
           train_args['best_record']['mean_iu'],
           train_args['best_record']['fwavacc'],
           train_args['best_record']['epoch']))

    print(
        '--------------------------------------------------------------------')

    #writer.add_scalar('val_loss', val_loss.avg, epoch)
    #writer.add_scalar('acc', acc, epoch)
    #writer.add_scalar('acc_cls', acc_cls, epoch)
    #writer.add_scalar('mean_iu', mean_iu, epoch)
    #writer.add_scalar('fwavacc', fwavacc, epoch)
    #writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch)

    return val_loss.avg
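
The `evaluate` helper used by this validation loop is not shown. The usual confusion-matrix implementation of its four metrics (overall accuracy, mean per-class accuracy, mean IoU, frequency-weighted IoU) looks roughly like the sketch below; this is an assumption about its behavior, not necessarily the project's exact code.

import numpy as np


def _fast_hist(label_pred, label_true, num_classes):
    # Confusion matrix accumulated over valid (non-ignored) pixels.
    mask = (label_true >= 0) & (label_true < num_classes)
    return np.bincount(
        num_classes * label_true[mask].astype(int) + label_pred[mask],
        minlength=num_classes ** 2).reshape(num_classes, num_classes)


def evaluate(predictions, gts, num_classes):
    hist = np.zeros((num_classes, num_classes))
    for lp, lt in zip(predictions, gts):
        hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc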