Example #1
0
    def train_epoch(self, train_db, epoch):
        """Run one optimisation pass over `train_db` and return the step losses.

        Args:
            train_db: Database object wrapped by `synthesis_loader`.
            epoch: Epoch number; only used in the progress log.

        Returns:
            `np.array` built from the flattened loss vector recorded at
            every step of the epoch.
        """
        data_iter = DataLoader(synthesis_loader(train_db),
                               batch_size=self.cfg.batch_size,
                               shuffle=True,
                               num_workers=self.cfg.num_workers,
                               pin_memory=True)

        # With cuda + parallel the sub-modules are reached through `.module`.
        wrapped = self.cfg.cuda and self.cfg.parallel
        core_net = self.net.module if wrapped else self.net

        # Put the whole model in train mode first, then switch only the loss
        # network back to eval mode. The order matters.
        self.net.train()
        core_net.lossnet.eval()

        step_losses = []
        for step, batch in enumerate(data_iter):
            # ---- Batched data ----
            proposals, gt_images, gt_labels = self.batch_data(batch)
            # presumably channels-last -> channels-first; confirm with batch_data
            gt_images = gt_images.permute(0, 3, 1, 2)

            if self.cfg.weighted_synthesis:
                # Per-pixel weights: detached copy of proposal channel -4,
                # remapped through 0.5 * (1 + w).
                weights = 0.5 * (1.0 + proposals[:, :, :, -4].clone().detach())
            else:
                weights = None

            # ---- Train one step ----
            self.net.zero_grad()
            outputs = self.net(proposals, True, gt_images)
            syn_images, syn_labels, syn_feats, gt_feats = outputs
            loss, losses = self.compute_loss(syn_images, gt_images,
                                             syn_feats, gt_feats,
                                             syn_labels, gt_labels,
                                             weights)
            loss.backward()
            self.optimizer.step()

            # ---- Collect info ----
            step_losses.append(losses.cpu().data.numpy().flatten())

            # ---- Print info ----
            if step % self.cfg.log_per_steps != 0:
                continue
            # Running means over every step seen so far in this epoch.
            history = np.stack(step_losses, 0)
            print('Epoch %03d, iter %07d:' % (epoch, step))
            print(*[np.mean(history[:, col]) for col in range(0, 3)])
            print(*[np.mean(history[:, col]) for col in range(3, 8)])
            print('-------------------------')

        return np.array(step_losses)
Example #2
0
def test_syn_model(config):
    """Smoke-test `SynthesisModel` on a single batch and print output sizes.

    Builds the model from `config`, prints `get_n_params` for it, then draws
    one batch (batch_size=1, unshuffled) from the 'train'/'2017' `coco`
    database via `synthesis_loader` and runs a forward pass with the
    ground-truth image supplied.

    Args:
        config: Project configuration; `config.num_workers` sets the number
            of loader workers.
    """
    synthesizer = SynthesisModel(config)
    print(get_n_params(synthesizer))

    db = coco(config, 'train', '2017')
    loader = DataLoader(synthesis_loader(db),
                        batch_size=1,
                        shuffle=False,
                        num_workers=config.num_workers)

    for batched in loader:
        x = batched['input_vol'].float()
        # presumably channels-last -> channels-first for the network
        y = batched['gt_image'].float().permute(0, 3, 1, 2)

        image, label, syn_feats, gt_feats = synthesizer(x, True, y)
        print(image.size(), label.size())
        for v in syn_feats:
            print(v.size())
        print('------------')
        for v in gt_feats:
            print(v.size())
        # One batch is enough for a shape check.
        break
Example #3
0
def test_syn_encoder(config):
    """Smoke-test `SynthesisEncoder` on a single batch and print output sizes.

    Builds the encoder from `config`, prints `get_n_params` for it, then
    draws one batch (batch_size=1, unshuffled) from the 'train'/'2017'
    `coco` database via `synthesis_loader`, encodes its `input_vol` and
    prints the size of every returned tensor.

    Args:
        config: Project configuration; `config.num_workers` sets the number
            of loader workers.
    """
    img_encoder = SynthesisEncoder(config)
    print(get_n_params(img_encoder))

    db = coco(config, 'train', '2017')
    loader = DataLoader(synthesis_loader(db),
                        batch_size=1,
                        shuffle=False,
                        num_workers=config.num_workers)

    for batched in loader:
        x = batched['input_vol'].float()
        for feat in img_encoder(x):
            print(feat.size())
        # One batch is enough for a shape check.
        break
Example #4
0
def test_perceptual_loss_network(config):
    """Smoke-test `VGG19LossNetwork` on a single ground-truth image.

    Builds the loss network in eval mode, prints `get_n_params` for it, then
    draws one batch (batch_size=1, unshuffled) from the 'train'/'2017'
    `coco` database via `synthesis_loader`, feeds `gt_image` through the
    network and prints the size of every returned tensor.

    Args:
        config: Project configuration; `config.num_workers` sets the number
            of loader workers.
    """
    img_encoder = VGG19LossNetwork(config).eval()
    print(get_n_params(img_encoder))

    db = coco(config, 'train', '2017')
    loader = DataLoader(synthesis_loader(db),
                        batch_size=1,
                        shuffle=False,
                        num_workers=config.num_workers)

    for batched in loader:
        x = batched['gt_image'].float()
        # presumably channels-last -> channels-first for the network
        for feat in img_encoder(x.permute(0, 3, 1, 2)):
            print(feat.size())
        # One batch is enough for a shape check.
        break
Example #5
0
def test_syn_decoder(config):
    """Smoke-test the `SynthesisEncoder` -> `SynthesisDecoder` pipeline.

    Builds both modules from `config`, prints `get_n_params` for each, then
    draws one batch (batch_size=1, unshuffled) from the 'train'/'2017'
    `coco` database via `synthesis_loader`, encodes `input_vol`, decodes the
    encoder outputs and prints the sizes of the decoded image and label.

    Args:
        config: Project configuration; `config.num_workers` sets the number
            of loader workers.
    """
    img_encoder = SynthesisEncoder(config)
    img_decoder = SynthesisDecoder(config)
    print(get_n_params(img_encoder))
    print(get_n_params(img_decoder))

    db = coco(config, 'train', '2017')
    loader = DataLoader(synthesis_loader(db),
                        batch_size=1,
                        shuffle=False,
                        num_workers=config.num_workers)

    for batched in loader:
        x = batched['input_vol'].float()
        # Explicit 7-way unpack doubles as a check that the encoder yields
        # exactly seven outputs before they are handed to the decoder.
        x0, x1, x2, x3, x4, x5, x6 = img_encoder(x)
        image, label = img_decoder((x0, x1, x2, x3, x4, x5, x6))
        print(image.size(), label.size())
        # One batch is enough for a shape check.
        break
def test_syn_dataloader(config):
    """Dump visual checks of the synthesis data loader as PNG mosaics.

    Iterates at most two batches of the 'train'/'2017' `coco` database
    wrapped by `synthesis_loader`, preprocesses the input volume (color or
    one-hot, depending on `config.use_color_volume`), prints tensor sizes
    and, for every sample, saves a 2x3 mosaic
    (proposal | mask | person / gt_color | gt_label | other) under
    `<config.model_dir>/test_syn_dataloader`. The elapsed time is printed
    at the end.

    Args:
        config: Project configuration; reads `model_dir`, `batch_size`,
            `num_workers` and `use_color_volume`.
    """
    db = coco(config, 'train', '2017')

    syn_loader = synthesis_loader(db)
    output_dir = osp.join(config.model_dir, 'test_syn_dataloader')
    maybe_create(output_dir)

    loader = DataLoader(syn_loader,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    # Loop-invariant setup: select the non-interactive backend once.
    plt.switch_backend('agg')
    num_classes = len(db.classes)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['input_vol'].float()
        y = batched['gt_image'].float()
        z = batched['gt_label'].float()

        if config.use_color_volume:
            x = batch_color_volumn_preprocess(x, num_classes)
        else:
            x = batch_onehot_volumn_preprocess(x, num_classes)
        print('input_vol', x.size())
        print('gt_image', y.size())
        print('gt_label', z.size())

        x = (x - 128.0).permute(0, 3, 1, 2)
        x = tensors_to_vols(x)

        # Work on plain numpy arrays from here on so the np.repeat /
        # np.concatenate calls below do not depend on implicit
        # tensor -> array conversion.
        gt_images = y.cpu().numpy()
        gt_labels = z.cpu().numpy()

        for i in range(x.shape[0]):
            image_idx = batched['image_index'][i]
            # A collated index is a tensor scalar; str() on it would yield
            # 'tensor(123)' and corrupt the file name. Unwrap it to a plain
            # Python number first.
            if hasattr(image_idx, 'item'):
                image_idx = image_idx.item()
            name = '%03d_' % i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')

            if config.use_color_volume:
                # Channel triplets picked out of the color volume.
                proposal = x[i, :, :, 12:15]
                mask = x[i, :, :, :3]
                person = x[i, :, :, 9:12]
                other = x[i, :, :, 15:18]
            else:
                proposal = x[i, :, :, -3:]
                # Single-channel planes are replicated to three channels so
                # they can be tiled next to the color patches.
                mask = np.repeat(x[i, :, :, -4][..., None], 3, -1)
                person = np.repeat(x[i, :, :, 3][..., None], 3, -1)
                other = np.repeat(x[i, :, :, 5][..., None], 3, -1)

            gt_color = gt_images[i]
            gt_label = np.repeat(gt_labels[i][..., None], 3, -1)

            r1 = np.concatenate((proposal, mask, person), 1)
            r2 = np.concatenate((gt_color, gt_label, other), 1)
            out = np.concatenate((r1, r2), 0).astype(np.uint8)

            fig = plt.figure(figsize=(32, 32))
            plt.imshow(out[:, :, :])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        # Two batches are enough for a visual check.
        if cnt == 1:
            break
    print("Time", time() - start)
Example #7
0
    def sample_for_vis(self, epoch, test_db, N, random_or_not=False):
        """Render synthesized-vs-ground-truth figures for up to `N` samples.

        For every sample a 2x2 figure (synthesized image, ground-truth image,
        decoded synthesized label map, decoded ground-truth label map) is
        saved to `<cfg.model_dir>/<epoch:03d>/vis`.

        Args:
            epoch: Epoch number; used for the output directory and the log.
            test_db: Database wrapped by `synthesis_loader`; also provides
                `decode_semantic_map` and `scenedb`.
            N: Maximum number of samples to render. The effective limit is
                `min(N, len(test_db.scenedb))`; no figure is written when
                that limit is zero or negative.
            random_or_not: Forwarded to the loader's `shuffle` argument.
        """
        ##############################################################
        # Output prefix
        ##############################################################
        plt.switch_backend('agg')
        output_dir = osp.join(self.cfg.model_dir, '%03d' % epoch, 'vis')
        maybe_create(output_dir)

        ##############################################################
        # Main loop
        ##############################################################
        syn_db = synthesis_loader(test_db)
        loader = DataLoader(syn_db,
                            batch_size=self.cfg.batch_size,
                            shuffle=random_or_not,
                            pin_memory=True)

        max_cnt = min(N, len(test_db.scenedb))

        self.net.eval()
        if max_cnt <= 0:
            # Nothing was requested (or the database is empty).
            return

        for cnt, batched in enumerate(loader):
            ##################################################################
            ## Batched data
            ##################################################################
            proposals, gt_images, gt_labels = self.batch_data(batched)
            image_indices = batched['image_index'].cpu().detach().numpy()

            ##################################################################
            ## Forward pass (no gradients, no ground-truth image supplied)
            ##################################################################
            with torch.no_grad():
                synthesized_images, synthesized_labels, _, _ = \
                    self.net(proposals, False, None)

            # Render only as many samples of this batch as are still needed
            # so that the total never exceeds max_cnt.
            remaining = max_cnt - cnt * self.cfg.batch_size
            n_render = min(synthesized_images.size(0), remaining)

            for i in range(n_render):
                synthesized_image = \
                    synthesized_images[i].detach().cpu().numpy()
                # Move the leading axis last for imshow.
                synthesized_image = synthesized_image.transpose((1, 2, 0))
                gt_image = gt_images[i].detach().cpu().numpy()

                # Index of the maximum along dim 0 (second element of the
                # (values, indices) pair returned by torch.max).
                synthesized_label = torch.max(synthesized_labels[i], 0)[-1]
                synthesized_label = synthesized_label.detach().cpu().numpy()
                synthesized_label = test_db.decode_semantic_map(
                    synthesized_label)
                gt_label = gt_labels[i].detach().cpu().numpy()
                gt_label = test_db.decode_semantic_map(gt_label)

                # 2x2 grid, filled row by row.
                panels = (synthesized_image, gt_image,
                          synthesized_label, gt_label)
                fig = plt.figure(figsize=(32, 32))
                for slot, panel in enumerate(panels, 1):
                    plt.subplot(2, 2, slot)
                    plt.imshow(clamp_array(panel, 0, 255).astype(np.uint8))
                    plt.axis('off')

                image_idx = image_indices[i]
                name = '%03d_' % cnt + str(image_idx).zfill(12)
                out_path = osp.join(output_dir, name + '.png')

                fig.savefig(out_path, bbox_inches='tight')
                plt.close(fig)
                print('sampling: %d, %d, %d' % (epoch, cnt, i))

            if (cnt + 1) * self.cfg.batch_size >= max_cnt:
                break