Example #1
def main():
    """Training for softmax classifier only.
  """
    # Retrieve experiment configurations.
    args = parse_args('Training for softmax classifier only.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = ListTagClassifierDataset(
        data_dir=args.data_dir,
        data_list=args.data_list,
        img_mean=config.network.pixel_means,
        img_std=config.network.pixel_stds,
        size=config.train.crop_size,
        random_crop=config.train.random_crop,
        random_scale=config.train.random_scale,
        random_mirror=config.train.random_mirror,
        random_grayscale=True,
        random_blur=True,
        training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'softmax_classifier':
        prediction_model = softmax_classifier(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(
            config.network.pretrained)['embedding_model'],
                                        resume=True)
    else:
        raise ValueError('Pre-trained model is required.')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)
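    # gather_output=False keeps the per-GPU outputs separate; they are
    # gathered manually with scatter_gather inside the training loop.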

    embedding_model.eval()
    prediction_model.train()
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}

    # Start training.
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Re-initialize the data iterator when the remaining data is not
        # enough for one more multi-GPU step.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
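        # The backbone runs in eval mode under no_grad, so only the softmax
        # classifier receives gradient updates.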
        with torch.no_grad():
            embeddings = embedding_model(*zip(image_batch, label_batch))

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in ['sem_ann_loss']:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)
Example #2
def main():
    """Inference for semantic segmentation.
  """
    # Retrieve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
    os.makedirs(semantic_dir, exist_ok=True)
    os.makedirs(semantic_rgb_dir, exist_ok=True)

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            image_batch, label_batch, _ = test_dataset[data_index]
            image_h, image_w = image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(
                image_batch,
                label_batch,
                scales=[0.5, 0.75, 1, 1.25, 1.5],
                is_flip=True)
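            # Multi-scale and horizontal-flip test-time augmentation; the
            # per-scale logits are accumulated below and fused after the loop.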

            semantic_logits = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Compute the end index of each sliding-window patch;
                # np.linspace spaces them evenly to tile the padded image.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)

                # Create placeholders for full-resolution logits and counts.
                semantic_logit = torch.FloatTensor(
                    1, config.dataset.num_classes, pad_image_h,
                    pad_image_w).zero_().to("cuda:0")
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model(crop_image_batch,
                                                          resize_as_input=True)
                        crop_outputs = prediction_model(crop_embeddings)
                        semantic_logit[..., sh:eh, sw:ew] += crop_outputs[
                            'semantic_logit'].to("cuda:0")
                        counts[..., sh:eh, sw:ew] += 1
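                # Average the logits over overlapping patch regions.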
                semantic_logit /= counts
                semantic_logit = semantic_logit[
                    ..., :resize_image_h, :resize_image_w]
                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h, image_w),
                                               mode='bilinear')
                semantic_logit = F.softmax(semantic_logit, dim=1)
                semantic_logit = semantic_logit.data.cpu().numpy().astype(
                    np.float32)
                if data_info['is_flip']:
                    semantic_logit = semantic_logit[..., ::-1]
                semantic_logits.append(semantic_logit)

            # Save semantic predictions.
            semantic_logits = np.concatenate(semantic_logits, axis=0)
            semantic_logits = np.sum(semantic_logits, axis=0)
            semantic_pred = np.argmax(semantic_logits,
                                      axis=0).astype(np.uint8)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)

            # Drop references and empty the CUDA cache to free GPU memory.
            crop_embeddings = {}
            crop_outputs = {}
            torch.cuda.empty_cache()
Example #3
def main():
    """Generate pseudo labels by softmax classifier.
  """
    # Retrieve experiment configurations.
    args = parse_args('Generate pseudo labels by softmax classifier.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ),
                                  dtype=bool)
            label_tags[lab_tags] = True
            label_tags = torch.from_numpy(label_tags).cuda()

            # Create a multi-scale image pyramid.
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.75, 1],
                                                       is_flip=True)

            affs = []
            semantic_probs = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                embeddings = embedding_model(image_batch, resize_as_input=True)
                outputs = prediction_model(embeddings)

                embs = embeddings[
                    'embedding'][:, :, :resize_image_h, :resize_image_w]
                semantic_logit = outputs['semantic_logit'][
                    ..., :resize_image_h, :resize_image_w]
                if data_info['is_flip']:
                    embs = torch.flip(embs, dims=[3])
                    semantic_logit = torch.flip(semantic_logit, dims=[3])
                embs = F.interpolate(embs,
                                     size=(image_h // 8, image_w // 8),
                                     mode='bilinear')
                embs = embs / torch.norm(embs, dim=1, keepdim=True)
                embs_flat = embs.view(embs.shape[1], -1)
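                # Embeddings are L2-normalized, so the Gram matrix holds
                # cosine similarities; exp(5 * (sim - 1)) maps each entry
                # to an affinity in (0, 1].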
                aff = torch.matmul(embs_flat.t(),
                                   embs_flat).mul_(5).add_(-5).exp_()
                affs.append(aff)

                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h // 8,
                                                     image_w // 8),
                                               mode='bilinear')
                #semantic_prob = F.softmax(semantic_logit, dim=1)
                #semantic_probs.append(semantic_prob)
                semantic_probs.append(semantic_logit)

            cat_semantic_probs = torch.cat(semantic_probs, dim=0)
            #semantic_probs, _ = torch.max(cat_semantic_probs, dim=0)
            #semantic_probs[0] = torch.min(cat_semantic_probs[:, 0, :, :], dim=0)[0]
            semantic_probs = torch.mean(cat_semantic_probs, dim=0)
            semantic_probs = F.softmax(semantic_probs, dim=0)

            # Normalize the CAM per class by its maximum probability
            # (21 is hard-coded; presumably config.dataset.num_classes
            # for PASCAL VOC).
            max_prob = torch.max(semantic_probs.view(21, -1), dim=1)[0]
            cam_full_arr = semantic_probs / max_prob.view(21, 1, 1)

            cam_shape = cam_full_arr.shape[-2:]
            label_tags = (~label_tags).view(-1, 1,
                                            1).expand(-1, cam_shape[0],
                                                      cam_shape[1])
            cam_full_arr = cam_full_arr.masked_fill(label_tags, 0)
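            # TH (background threshold) and WALK_STEPS are assumed to be
            # module-level constants defined outside this snippet.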
            if TH is not None:
                cam_full_arr[0] = TH

            aff = torch.mean(torch.stack(affs, dim=0), dim=0)

            # Start random walk: sharpen the pairwise affinities by
            # exponentiation.
            aff_mat = aff**20

            trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
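            # Each squaring doubles the walk length, so the loop computes
            # trans_mat ** (2 ** WALK_STEPS) over the affinity graph.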
            for _ in range(WALK_STEPS):
                trans_mat = torch.matmul(trans_mat, trans_mat)

            cam_vec = cam_full_arr.view(21, -1)
            cam_rw = torch.matmul(cam_vec, trans_mat)
            cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

            cam_rw = cam_rw.data.cpu().numpy()
            cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                                dsize=(image_w, image_h),
                                interpolation=cv2.INTER_LINEAR)
            cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

            # CRF
            #image = image_batch['image'].data.cpu().numpy().astype(np.float32)
            #image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
            #image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            #image += np.reshape(config.network.pixel_means, (1, 1, 3))
            #image = image * 255
            #image = image.astype(np.uint8)
            #cam_rw = postprocessor(image, cam_rw.transpose(2,0,1))

            #cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

            # Save semantic predictions.
            semantic_pred = cam_rw_pred

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
Example #4
def main():
    """Inference for semantic segmentation.
  """
    # Retrieve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
    os.makedirs(semantic_dir, exist_ok=True)
    os.makedirs(semantic_rgb_dir, exist_ok=True)

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Start inference.
    for data_index in range(len(test_dataset)):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, _, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Resize the input image.
        if config.test.image_size > 0:
            image_batch['image'] = transforms.resize_with_interpolation(
                image_batch['image'].transpose(1, 2, 0),
                config.test.image_size,
                method='bilinear').transpose(2, 0, 1)
        resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

        # Crop and pad the input image.
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).cuda()
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        # Compute the end index of each sliding-window patch; np.linspace
        # spaces them evenly to tile the padded image.
        stride_h, stride_w = config.test.stride
        crop_h, crop_w = config.test.crop_size
        npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1
        patch_ind_h = np.linspace(crop_h,
                                  pad_image_h,
                                  npatches_h,
                                  dtype=np.int32)
        patch_ind_w = np.linspace(crop_w,
                                  pad_image_w,
                                  npatches_w,
                                  dtype=np.int32)

        # Create placeholders for full-resolution outputs.
        outputs = {}
        with torch.no_grad():
            for ind_h in patch_ind_h:
                for ind_w in patch_ind_w:
                    sh, eh = ind_h - crop_h, ind_h
                    sw, ew = ind_w - crop_w, ind_w
                    crop_image_batch = {
                        k: v[:, :, sh:eh, sw:ew]
                        for k, v in image_batch.items()
                    }

                    # Feed-forward.
                    crop_embeddings = embedding_model(crop_image_batch,
                                                      resize_as_input=True)
                    crop_outputs = prediction_model(crop_embeddings)

                    # Accumulate per-patch outputs into full-resolution
                    # buffers; overlapping regions are summed.
                    for name, crop_out in crop_outputs.items():
                        if crop_out is not None:
                            if name not in outputs:
                                output_shape = list(crop_out.shape)
                                output_shape[-2:] = pad_image_h, pad_image_w
                                outputs[name] = torch.zeros(
                                    output_shape, dtype=crop_out.dtype).cuda()
                            outputs[name][..., sh:eh, sw:ew] += crop_out

        # Save semantic predictions.
        semantic_logits = outputs.get('semantic_logit', None)
        if semantic_logits is not None:
            semantic_prob = F.softmax(semantic_logits, dim=1)
            semantic_prob = semantic_prob[
                0, :, :resize_image_h, :resize_image_w]
            semantic_prob = semantic_prob.data.cpu().numpy().astype(np.float32)

            # DenseCRF post-processing.
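            # Undo the input normalization to recover the uint8 RGB image
            # that the DenseCRF uses as pairwise guidance.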
            image = image_batch['image'][0].data.cpu().numpy().astype(
                np.float32)
            image = image.transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)
            image = image[:resize_image_h, :resize_image_w, :]

            semantic_prob = postprocessor(image, semantic_prob)

            #semantic_pred = torch.argmax(semantic_logits, 1)
            semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)
            #semantic_pred = (semantic_pred.view(pad_image_h, pad_image_w)
            #                              .cpu()
            #                              .data
            #                              .numpy()
            #                              .astype(np.uint8))
            #semantic_pred = semantic_pred[:resize_image_h, :resize_image_w]
            semantic_pred = cv2.resize(semantic_pred, (image_w, image_h),
                                       interpolation=cv2.INTER_NEAREST)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)

            # Drop references and empty the CUDA cache to free GPU memory.
            outputs = {}
            crop_embeddings = {}
            crop_outputs = {}
            torch.cuda.empty_cache()
Example #5
def main():
    """Training for pixel-wise embeddings by pixel-segment
  contrastive learning loss.
  """
    # Retrieve experiment configurations.
    args = parse_args('Training for pixel-wise embeddings.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = ListTagDataset(data_dir=args.data_dir,
                                   data_list=args.data_list,
                                   img_mean=config.network.pixel_means,
                                   img_std=config.network.pixel_stds,
                                   size=config.train.crop_size,
                                   random_crop=config.train.random_crop,
                                   random_scale=config.train.random_scale,
                                   random_mirror=config.train.random_mirror,
                                   training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config).cuda()
    elif config.network.prediction_types == 'softmax_classifier':
        prediction_model = softmax_classifier(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    # Use synchronized batch normalization.
    if config.network.use_syncbn:
        embedding_model = convert_model(embedding_model).cuda()
        prediction_model = convert_model(prediction_model).cuda()

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.train.resume:
        model_path = model_path_template.format(curr_iter)
        print('Resume training from {:s}'.format(model_path))
        embedding_model.load_state_dict(
            torch.load(model_path)['embedding_model'], resume=True)
        prediction_model.load_state_dict(
            torch.load(model_path)['prediction_model'], resume=True)
        optimizer.load_state_dict(
            torch.load(optimizer_path_template.format(curr_iter)))
    elif config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(config.network.pretrained))
    else:
        print('Training from scratch')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)
    if config.network.use_syncbn:
        patch_replication_callback(embedding_model)
        patch_replication_callback(prediction_model)

    for module in embedding_model.modules():
        if isinstance(module, _BatchNorm) or isinstance(module, _ConvNd):
            print(module.training, module)
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}
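    # Memory banks keep prototypes from recent iterations; they are attached
    # to each label batch below so the loss can reference them.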

    # Start training.
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Re-initialize the data iterator when the remaining data is not
        # enough for one more multi-GPU step.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
        embeddings = embedding_model(*zip(image_batch, label_batch))

        # Synchronize cluster indices and compute prototypes.
        c_inds = [emb['cluster_index'] for emb in embeddings]
        cb_inds = [emb['cluster_batch_index'] for emb in embeddings]
        cs_labs = [emb['cluster_semantic_label'] for emb in embeddings]
        ci_labs = [emb['cluster_instance_label'] for emb in embeddings]
        c_embs = [emb['cluster_embedding'] for emb in embeddings]
        c_embs_with_loc = [
            emb['cluster_embedding_with_loc'] for emb in embeddings
        ]
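        # Gather per-GPU clusters onto the last device and compute segment
        # prototypes over the full multi-GPU batch.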
        (prototypes, prototypes_with_loc, prototype_semantic_labels,
         prototype_instance_labels, prototype_batch_indices,
         cluster_indices) = (
             model_utils.gather_clustering_and_update_prototypes(
                 c_embs, c_embs_with_loc, c_inds, cb_inds, cs_labs, ci_labs,
                 'cuda:{:d}'.format(num_gpus - 1)))

        for i in range(len(label_batch)):
            label_batch[i]['prototype'] = prototypes[i]
            label_batch[i]['prototype_with_loc'] = prototypes_with_loc[i]
            label_batch[i][
                'prototype_semantic_label'] = prototype_semantic_labels[i]
            label_batch[i][
                'prototype_instance_label'] = prototype_instance_labels[i]
            label_batch[i]['prototype_batch_index'] = prototype_batch_indices[
                i]
            embeddings[i]['cluster_index'] = cluster_indices[i]

        semantic_tags = model_utils.gather_and_update_datas(
            [lab['semantic_tag'] for lab in label_batch],
            'cuda:{:d}'.format(num_gpus - 1))
        for i in range(len(label_batch)):
            label_batch[i]['semantic_tag'] = semantic_tags[i]
            label_batch[i]['prototype_semantic_tag'] = torch.index_select(
                semantic_tags[i], 0, label_batch[i]['prototype_batch_index'])

        # Add memory bank to label batch.
        for k in memory_banks.keys():
            for i in range(len(label_batch)):
                assert (label_batch[i].get(k, None) is None)
                label_batch[i][k] = [m.to(gpu_ids[i]) for m in memory_banks[k]]

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss', 'feat_aff_loss'
        ]:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Write to tensorboard summary.
        writer = (summary_writer if curr_iter %
                  config.train.tensorboard_step == 0 else None)
        if writer is not None:
            summary_vis = []
            summary_val = {}
            # Gather labels to cpu.
            cpu_label_batch = scatter_gather.gather(label_batch, -1)
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['semantic_label'], color_map))
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['instance_label'], color_map))

            # Gather outputs to cpu.
            vis_names = ['embedding']
            cpu_embeddings = scatter_gather.gather(
                [{k: emb.get(k, None)
                  for k in vis_names} for emb in embeddings], -1)
            for vis_name in vis_names:
                if cpu_embeddings.get(vis_name, None) is not None:
                    summary_vis.append(
                        vis_utils.embedding_to_rgb(cpu_embeddings[vis_name],
                                                   'pca'))

            val_names = [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss',
                'feat_aff_loss', 'accuracy'
            ]
            for val_name in val_names:
                if outputs.get(val_name, None) is not None:
                    summary_val[val_name] = outputs[val_name].mean().to('cpu')

            vis_utils.write_image_to_tensorboard(summary_writer, summary_vis,
                                                 summary_vis[-1].shape[-2:],
                                                 curr_iter)
            vis_utils.write_scalars_to_tensorboard(summary_writer, summary_val,
                                                   curr_iter)

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Update memory banks.
        with torch.no_grad():
            for k in label_batch[0].keys():
                if 'prototype' in k and 'memory' not in k:
                    memory = label_batch[0][k].clone().detach()
                    memory_key = 'memory_' + k
                    if memory_key not in memory_banks.keys():
                        memory_banks[memory_key] = []
                    memory_banks[memory_key].append(memory)
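                    # Keep only the most recent memory_bank_size snapshots
                    # (FIFO: drop the oldest).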
                    if len(memory_banks[memory_key]
                           ) > config.train.memory_bank_size:
                        memory_banks[memory_key] = memory_banks[memory_key][1:]

            # Update batch labels.
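            # Shift the stored batch indices every iteration so memory
            # prototypes never collide with indices from the current batch.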
            for k in ['memory_prototype_batch_index']:
                memory_labels = memory_banks.get(k, None)
                if memory_labels is not None:
                    for i, memory_label in enumerate(memory_labels):
                        memory_labels[i] += config.train.batch_size * num_gpus

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)