Code Example #1
# Assumed imports for this snippet; SIZE, image_normalize and apply_transform
# are helpers defined elsewhere in the project.
import itertools

import nibabel as nb
import numpy as np
from skimage.transform import resize


def load_models(files, batch_size=1):
    # Every pairwise combination of two rotation angles in 90-degree steps.
    transformations = list(
        itertools.product(range(0, 360, 90), range(0, 360, 90)))

    # The first value yielded is the number of batches per epoch.
    size = len(transformations) * len(files)
    yield int(np.ceil(size / batch_size))

    images = np.zeros((batch_size, SIZE, SIZE, SIZE, 1), dtype="float32")
    masks = np.zeros((batch_size, SIZE, SIZE, SIZE, 1), dtype="float32")
    ip = 0
    while True:
        for image_filename, mask_filename in files:
            image = nb.load(str(image_filename)).get_fdata()
            mask = nb.load(str(mask_filename)).get_fdata()
            image = resize(image, (SIZE, SIZE, SIZE),
                           mode="constant",
                           anti_aliasing=True)
            mask = resize(mask, (SIZE, SIZE, SIZE),
                          mode="constant",
                          anti_aliasing=True)
            image = image_normalize(image)
            mask = image_normalize(mask)
            for rot1, rot2 in transformations:
                t_image = apply_transform(image, rot1, rot2)
                t_mask = apply_transform(mask, rot1, rot2)

                print(image_filename, rot1, rot2)

                images[ip] = t_image.reshape(SIZE, SIZE, SIZE, 1)
                masks[ip] = t_mask.reshape(SIZE, SIZE, SIZE, 1)
                ip += 1

                if ip == batch_size:
                    yield (images, masks)
                    ip = 0
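A minimal sketch of how this two-phase generator would be consumed (hypothetical: `train_files` and the Keras-style `model` are assumptions, not part of the original snippet):

# Hypothetical usage: the first next() retrieves the batch count yielded
# above; the generator itself then feeds a Keras-style training loop.
gen = load_models(train_files, batch_size=8)
steps_per_epoch = next(gen)
model.fit(gen, steps_per_epoch=steps_per_epoch, epochs=10)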
Code Example #2
def load_models_patches(files,
                        transformations,
                        patch_size=SIZE,
                        batch_size=BATCH_SIZE):
    for image_filename, mask_filename in files:
        image = nb.load(str(image_filename)).get_fdata()
        mask = nb.load(str(mask_filename)).get_fdata()
        image = image_normalize(image)
        mask = image_normalize(mask)
        # Apply one randomly chosen rotation pair per volume.
        rot1, rot2 = random.choice(transformations)
        t_image = apply_transform(image, rot1, rot2)
        t_mask = apply_transform(mask, rot1, rot2)

        print(image_filename, mask_filename, rot1, rot2, t_image.min(),
              t_image.max(), t_mask.min(), t_mask.max())

        for sub_image, sub_mask in gen_patches(t_image, t_mask, patch_size):
            yield (sub_image, sub_mask)
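Unlike Code Example #1, this generator yields one patch pair at a time; a hedged sketch of regrouping those pairs into fixed-size batches (the helper below is hypothetical, not from the original project):

import numpy as np

def batch_pairs(pair_gen, batch_size):
    # Collect (sub_image, sub_mask) pairs and emit them as stacked arrays.
    images, masks = [], []
    for sub_image, sub_mask in pair_gen:
        images.append(sub_image)
        masks.append(sub_mask)
        if len(images) == batch_size:
            yield np.stack(images), np.stack(masks)
            images, masks = [], []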
Code Example #3
    def __init__(self, img_path, plot_res=True, to_onnx=False):
        self.device = "cuda:0"
        self.model = shufflenet()
        self.model_name = "shufflenet.pth"
        self.model.load_state_dict(
            torch.load(os.path.join(model_folder,
                                    "pth/{}".format(self.model_name))))
        self.model.cuda()
        self.to_onnx = to_onnx
        self.input_tensor = input_dim_3to4(image_normalize(img_path))
        self.plot = plot_res
        self.img = cv2.imread(img_path)
Code Example #4
def test_decoder(config):
    transformer = image_normalize('background')
    db = abstract_scene(config, 'train', transform=transformer)

    text_encoder = TextEncoder(db)
    img_encoder = ImageEncoder(config)
    what_decoder = WhatDecoder(config)
    where_decoder = WhereDecoder(config)

    # print(where_decoder)

    loader = DataLoader(db,
                        batch_size=4,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_imgs = batched['background'].float()

        encoder_states = text_encoder(word_inds, word_lens)
        bg_feats = img_encoder(bg_imgs)

        prev_bgfs = bg_feats[:, 0].unsqueeze(1)
        prev_states = {}
        prev_states['bgfs'] = prev_bgfs
        what_outs = what_decoder(prev_states, encoder_states)

        print('------------------------------------------')
        print('obj_logits', what_outs['obj_logits'].size())
        print('rnn_outs', what_outs['rnn_outs'][0].size())
        print('hids', what_outs['hids'][0].size())
        print('attn_ctx', what_outs['attn_ctx'].size())
        print('attn_wei', what_outs['attn_wei'].size())

        obj_logits = what_outs['obj_logits']
        # print('obj_logits ', obj_logits.size())
        # Greedy decoding: take the argmax over the class logits (the +1.0
        # offset leaves the argmax unchanged).
        _, obj_inds = torch.max(obj_logits + 1.0, dim=-1)
        curr_fgfs = indices2onehots(obj_inds.cpu().data,
                                    config.output_cls_size)
        # curr_fgfs = curr_fgfs.unsqueeze(1)
        if config.cuda:
            curr_fgfs = curr_fgfs.cuda()
        curr_fgfs = curr_fgfs.float()
        what_outs['fgfs'] = curr_fgfs

        where_outs = where_decoder(what_outs, encoder_states)

        print('coord_logits ', where_outs['coord_logits'].size())
        print('attri_logits ', where_outs['attri_logits'].size())
        print('attn_ctx', where_outs['attn_ctx'].size())
        print('attn_wei', where_outs['attn_wei'].size())

        break
Code Example #5
def main():
    image_filename = args.input
    output_filename = args.output

    image = nb.load(image_filename)
    affine = image.affine
    image = image.get_fdata()
    image = image_normalize(image)

    nn_model = model.load_model()

    if args.use_gpu:
        mask = segment_on_gpu(image, nn_model)
    else:
        mask = segment_on_cpu(image, nn_model)

    if args.ret_prob:
        save_image(image_normalize(mask, 0, 1000), output_filename, affine)
    else:
        # Mask out everything the model classified as background before
        # saving the segmented image.
        image[mask < 0.5] = image.min()
        save_image(image_normalize(image, 0, 1000), output_filename, affine)
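`args` is not defined in this snippet; judging from the attributes accessed in main(), the surrounding script presumably builds it roughly as follows (flag names are inferred, so treat this as an assumption):

import argparse

# Hypothetical parser reconstructing the four attributes used above:
# args.input, args.output, args.use_gpu and args.ret_prob.
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", required=True)
parser.add_argument("-o", "--output", required=True)
parser.add_argument("--use-gpu", dest="use_gpu", action="store_true")
parser.add_argument("--ret-prob", dest="ret_prob", action="store_true")
args = parser.parse_args()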
Code Example #6
def gen_train_arrays(files, patch_size=SIZE, batch_size=BATCH_SIZE):
    transformations = list(
        itertools.product(range(0, 360, 15), range(0, 360, 15)))
    files_transforms_patches = gen_all_patches(files, transformations,
                                               patch_size, batch_size,
                                               NUM_PATCHES)
    size = len(files_transforms_patches)
    print(size)
    yield int(np.ceil(size / batch_size))
    images = np.zeros(shape=(batch_size, patch_size, patch_size, patch_size,
                             1),
                      dtype=np.float32)
    masks = np.zeros_like(images)
    yield get_proportion(files_transforms_patches, patch_size)
    ip = 0
    last_key = None
    for image_filename, mask_filename, rot1, rot2, patch in itertools.cycle(
            files_transforms_patches):
        # Reload and re-rotate only when the (file, rotation) pair changes;
        # comparing the filename alone would skip new rotations of a file
        # that appears consecutively with different angles.
        key = (image_filename, rot1, rot2)
        if key != last_key:
            image = nb.load(str(image_filename)).get_fdata()
            mask = nb.load(str(mask_filename)).get_fdata()
            image = image_normalize(image)
            mask = image_normalize(mask)
            image = apply_transform(image, rot1, rot2)
            mask = apply_transform(mask, rot1, rot2)
            last_key = key
            print(image_filename)

        images[ip] = get_image_patch(image, patch, patch_size).reshape(
            1, patch_size, patch_size, patch_size, 1)
        masks[ip] = get_image_patch(mask, patch, patch_size).reshape(
            1, patch_size, patch_size, patch_size, 1)
        ip += 1

        if ip == batch_size:
            yield images, masks
            ip = 0
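Because this generator yields three kinds of values in a fixed order, a consumer has to pull the two header values before iterating over batches; a minimal hypothetical sketch (`train_files` is an assumption):

# Hypothetical consumption: batch count first, class proportions second,
# then an endless stream of (images, masks) batches.
gen = gen_train_arrays(train_files)
n_batches = next(gen)
prop_bg, prop_fg = next(gen)
images, masks = next(gen)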
Code Example #7
def get_proportion(files_transforms_patches, patch_size=SIZE):
    sum_bg = 0.0
    sum_fg = 0.0
    last_key = None
    for image_filename, mask_filename, rot1, rot2, patch in files_transforms_patches:
        # Same caching fix as in Code Example #6: key on (file, rotation),
        # not on the filename alone.
        key = (image_filename, rot1, rot2)
        if key != last_key:
            mask = nb.load(str(mask_filename)).get_fdata()
            mask = image_normalize(mask)
            mask = apply_transform(mask, rot1, rot2)
            last_key = key

        _mask = get_image_patch(mask, patch, patch_size)
        sum_bg += (_mask < 0.5).sum()
        sum_fg += (_mask >= 0.5).sum()

    # Return the (background, foreground) voxel proportions; they sum to 1.
    return sum_bg / (sum_fg + sum_bg), sum_fg / (sum_bg + sum_fg)
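One plausible use of the returned proportions is to weight a voxel-wise loss inversely to class frequency; the mapping below is an illustration, not taken from the original code:

# Hypothetical class weighting: each class is weighted by the proportion of
# the other class, so the rarer foreground class gets the larger weight.
prop_bg, prop_fg = get_proportion(files_transforms_patches)
class_weights = {0: prop_fg, 1: prop_bg}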
Code Example #8
def test_img_encoder(config):
    transformer = image_normalize('background')
    db = abstract_scene(config, 'train', transform=transformer)
    img_encoder = ImageEncoder(config)
    print(get_n_params(img_encoder))

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        bg_imgs = batched['background'].float()
        y = img_encoder(bg_imgs)
        print(y.size())
        break
Code Example #9
    def predict(self, img_path):
        input_tensor = input_dim_3to4(image_normalize(img_path,
                                                      size=self.size))
        output = self.__test_model(input_tensor)
        # print(output)
        # Index of the highest-scoring class.
        scores = output[0].tolist()
        idx = scores.index(max(scores))
        print("Predicted image is {}".format(img_path.split("\\")[-1]))
        print("Predicted index is {}".format(idx))
        print("Predicted class is {}".format(imagenet_classes[idx]))
        print("The score is {}\n\n".format(output[0][idx]))

        if self.plot:
            img = cv2.imread(img_path)
            cv2.putText(img, self.model_name, (10, 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(img, imagenet_classes[idx], (10, 75),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(img, str(output[0][idx]), (10, 125),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.imshow("img", img)
            cv2.waitKey(1500)
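Code Examples #3 and #9 appear to be methods of the same classifier class; assuming so, usage would look roughly like this (the class name `Classifier` and the sample path are assumptions):

# Hypothetical driver, assuming __init__ (Code Example #3) and predict
# (Code Example #9) belong to one class, here called Classifier.
clf = Classifier("samples/cat.jpg", plot_res=True)
clf.predict("samples/cat.jpg")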
Code Example #10
def test_txt_encoder_abstract(config):
    transformer = image_normalize('background')
    db = abstract_scene(config, 'train', transform=transformer)
    net = TextEncoder(db)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        input_inds = batched['word_inds'].long()
        input_lens = batched['word_lens'].long()

        print('Checking the output shapes')
        out = net(input_inds, input_lens)
        out_rfts = out['rfts']
        out_embs = out['embs']
        out_msks = out['msks']
        out_hids = out['hids']
        print(out_rfts.size(), out_embs.size(), out_msks.size())
        if isinstance(out_hids[0], tuple):
            print(out_hids[0][0].size())
        else:
            print(out_hids[0].size())
        print('m: ', out_msks[-1])

        print('Checking the embedding')
        embeded = net.embedding(input_inds[:, 0, :])
        v1 = embeded[0, 0]
        idx = input_inds[0, 0, 0].data.item()
        v2 = db.lang_vocab.vectors[idx]
        diff = v2 - v1
        print('Diff: (should be zero)', torch.sum(diff.abs_()))

        break
Code Example #11
def test_topk(config):
    import os.path as osp
    from dataset import imdb
    from utils import image_normalize, maybe_create
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    transformer = image_normalize('background')
    db = imdb(config, split='test', transform=transformer)
    net = DrawModel(db)

    output_dir = osp.join(config.model_dir, 'test_topk')
    maybe_create(output_dir)

    pretrained_path = 'data/caches/supervised_abstract.pkl'
    assert osp.exists(pretrained_path)
    if config.cuda:
        states = torch.load(pretrained_path)
    else:
        states = torch.load(pretrained_path,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(states)

    plt.switch_backend('agg')

    for i in range(len(db)):
        entry = db[i]
        scene = db.scenedb[i]

        input_inds_np = entry['word_inds']
        input_lens_np = entry['word_lens']

        input_inds = torch.from_numpy(input_inds_np).long().unsqueeze(0)
        input_lens = torch.from_numpy(input_lens_np).long().unsqueeze(0)
        if config.cuda:
            input_inds = input_inds.cuda()
            input_lens = input_lens.cuda()

        net.eval()
        with torch.no_grad():
            env = net.topk_inference(input_inds, input_lens, config.beam_size,
                                     -1)
        frames = env.batch_redraw(return_sequence=True)
        gt_img = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
        for j in range(len(frames)):
            fig = plt.figure(figsize=(40, 20))
            title = entry['sentence']
            # title = title + '\n reward: %f, scores: %f, %f, %f, %f, %f'%(rews[j], *(scores[j]))
            plt.suptitle(title, fontsize=50)
            imgs = frames[j]
            for k in range(len(imgs)):
                plt.subplot(3, 4, k + 1)
                plt.imshow(imgs[k, :, :, ::-1])
                plt.axis('off')
            plt.subplot(3, 4, 12)
            plt.imshow(gt_img[:, :, ::-1])
            plt.axis('off')
            output_path = osp.join(output_dir, '%03d_%03d.png' % (i, j))
            fig.savefig(output_path, bbox_inches='tight')
            plt.close(fig)

        if i > 0:
            break
Code Example #12
def test_model(config):
    transformer = image_normalize('background')
    db = abstract_scene(config, 'val', transform=transformer)
    net = DrawModel(db)

    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'test_model')
    maybe_create(output_dir)

    pretrained_path = osp.join(
        '../data/caches/abstract_ckpts/supervised_abstract_top1.pkl')
    assert osp.exists(pretrained_path)
    if config.cuda:
        states = torch.load(pretrained_path)
    else:
        states = torch.load(pretrained_path,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(states)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    net.eval()
    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_images = batched['background'].float()

        fg_inds = batched['fg_inds'].long()
        gt_inds = batched['out_inds'].long()
        gt_msks = batched['out_msks'].float()
        hmaps = batched['hmaps'].float()

        fg_onehots = indices2onehots(fg_inds, config.output_cls_size)

        inf_outs = net.teacher_forcing(word_inds, word_lens, bg_images,
                                       fg_onehots, hmaps)
        print('teacher forcing')
        print('obj_logits ', inf_outs['obj_logits'].size())
        print('coord_logits ', inf_outs['coord_logits'].size())
        print('attri_logits ', inf_outs['attri_logits'].size())
        if config.what_attn:
            print('what_att_logits ', inf_outs['what_att_logits'].size())
        if config.where_attn > 0:
            print('where_att_logits ', inf_outs['where_att_logits'].size())
        print('----------------------')
        # inf_outs, env = net(word_inds, word_lens, -1, 0, 0, gt_inds)
        inf_outs, env = net(word_inds, word_lens, 0, 1, 0, None)
        print('scheduled sampling')
        print('obj_logits ', inf_outs['obj_logits'].size())
        print('coord_logits ', inf_outs['coord_logits'].size())
        print('attri_logits ', inf_outs['attri_logits'].size())
        if config.what_attn:
            print('what_att_logits ', inf_outs['what_att_logits'].size())
        if config.where_attn > 0:
            print('where_att_logits ', inf_outs['where_att_logits'].size())
        print('----------------------')

        pred_out_inds, pred_out_msks = env.get_batch_inds_and_masks()
        print('pred_out_inds', pred_out_inds[0, 1], pred_out_inds.shape)
        print('gt_inds', gt_inds[0, 1], gt_inds.size())
        print('pred_out_msks', pred_out_msks[0, 1], pred_out_msks.shape)
        print('gt_msks', gt_msks[0, 1], gt_msks.size())

        batch_frames = env.batch_redraw(True)
        scene_inds = batched['scene_idx']
        for i in range(len(scene_inds)):
            sid = scene_inds[i]
            entry = db[sid]
            name = osp.splitext(osp.basename(entry['color_path']))[0]
            imgs = batch_frames[i]
            out_path = osp.join(output_dir, name + '.png')
            fig = plt.figure(figsize=(16, 8))
            plt.suptitle(entry['sentence'])
            for j in range(len(imgs)):
                plt.subplot(4, 3, j + 1)
                plt.imshow(imgs[j, :, :, ::-1].astype(np.uint8))
                plt.axis('off')

            target = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
            plt.subplot(4, 3, 12)
            plt.imshow(target[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        break
Code Example #13
def test_simulator(config):
    plt.switch_backend('agg')

    output_dir = osp.join(config.model_dir, 'simulator')
    maybe_create(output_dir)

    transformer = image_normalize('background')
    db = abstract_scene(config, 'val', transform=transformer)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    env = simulator(db, config.batch_size)
    env.reset()

    for cnt, batched in enumerate(loader):
        out_inds = batched['out_inds'].long().numpy()
        gt_paths = batched['color_path']
        img_inds = batched['image_idx']
        sents = batched['sentence']

        sequences = []
        masks = []
        for i in range(out_inds.shape[1]):
            frames = env.batch_render_to_pytorch(out_inds[:, i])
            frames = tensor_to_img(frames)
            msks = env.batch_location_maps(out_inds[:, i, 3])
            for j in range(len(frames)):
                frame = frames[j]
                # cv2.resize expects dsize as (width, height), and the
                # interpolation mode must be passed by keyword; passed
                # positionally it would be taken as the dst argument.
                msk = cv2.resize(msks[j], (frame.shape[1], frame.shape[0]),
                                 interpolation=cv2.INTER_NEAREST)
                frames[j] = frame * (1.0 - msk[..., None])
            sequences.append(frames)
        seqs1 = np.stack(sequences, 1)
        print('seqs1', seqs1.shape)
        seqs2 = env.batch_redraw(return_sequence=True)

        seqs = seqs2
        for i in range(len(seqs)):
            imgs = seqs[i]
            image_idx = img_inds[i]
            name = '%03d_' % i + str(image_idx).zfill(9)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(gt_paths[i], cv2.IMREAD_COLOR)
            # color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(32, 16))
            plt.suptitle(sents[i])

            for j in range(len(imgs)):
                plt.subplot(3, 5, j + 1)
                plt.imshow(imgs[j, :, :, ::-1])
                plt.axis('off')

            plt.subplot(3, 5, 15)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        break
Code Example #14
def test_evaluator(config):
    transformer = image_normalize('background')
    db = abstract_scene(config, 'train', transform=transformer)
    output_dir = osp.join(db.cfg.model_dir, 'test_evaluator')
    maybe_create(output_dir)

    ev = evaluator(db)
    for i in range(0, len(db), 2):
        # print('--------------------------------------')
        entry_1 = db[i]
        entry_2 = db[i + 1]
        scene_1 = db.scenedb[i]
        scene_2 = db.scenedb[i + 1]
        name_1 = osp.splitext(osp.basename(entry_1['color_path']))[0]
        name_2 = osp.splitext(osp.basename(entry_2['color_path']))[0]

        graph_1 = scene_graph(db, scene_1, entry_1['out_inds'], True)
        graph_2 = scene_graph(db, scene_2, entry_2['out_inds'], False)

        color_1 = cv2.imread(entry_1['color_path'], cv2.IMREAD_COLOR)
        color_2 = cv2.imread(entry_2['color_path'], cv2.IMREAD_COLOR)

        color_1 = visualize_unigram(config, color_1, graph_1.unigrams,
                                    (225, 0, 0))
        color_2 = visualize_unigram(config, color_2, graph_2.unigrams,
                                    (225, 0, 0))
        color_1 = visualize_bigram(config, color_1, graph_1.bigrams,
                                   (0, 255, 255))
        color_2 = visualize_bigram(config, color_2, graph_2.bigrams,
                                   (0, 255, 255))

        scores = ev.evaluate_graph(graph_1, graph_2)

        color_1 = visualize_unigram(config, color_1, ev.common_pred_unigrams,
                                    (0, 225, 0))
        color_2 = visualize_unigram(config, color_2, ev.common_gt_unigrams,
                                    (0, 225, 0))
        color_1 = visualize_bigram(config, color_1, ev.common_pred_bigrams,
                                   (0, 0, 255))
        color_2 = visualize_bigram(config, color_2, ev.common_gt_bigrams,
                                   (0, 0, 255))

        info = eval_info(config, scores[None, ...])

        plt.switch_backend('agg')
        fig = plt.figure(figsize=(16, 10))
        title = entry_1['sentence'] + '\n' + entry_2['sentence'] + '\n'
        title += 'unigram f3: %f, bigram f3: %f, bigram sim: %f\n' % (
            info.unigram_F3()[0], info.bigram_F3()[0], info.bigram_coord()[0])
        title += 'pose: %f, expr: %f, scale: %f, flip: %f, coord: %f \n' % (
            info.pose()[0], info.expr()[0], info.scale()[0], info.flip()[0],
            info.unigram_coord()[0])

        plt.suptitle(title)
        plt.subplot(1, 2, 1)
        plt.imshow(color_1[:, :, ::-1])
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(color_2[:, :, ::-1])
        plt.axis('off')

        out_path = osp.join(output_dir, name_1 + '_' + name_2 + '.png')
        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)

        if i > 40:
            break
Code Example #15
def gen_image_patches(files, patch_size=SIZE, num_patches=NUM_PATCHES):
    image_filename, mask_filename = files
    original_image = nb.load(str(image_filename)).get_fdata()
    original_mask = nb.load(str(mask_filename)).get_fdata()

    transformations = list(
        itertools.product(range(0, 360, STEP_ROT), range(0, 360, STEP_ROT)))

    patches_files = []
    # Mirroring
    for m in range(4):
        print("m", m)
        if m == 0:
            image = original_image[:].copy()
            mask = original_mask[:].copy()
        elif m == 1:
            image = original_image[::-1].copy()
            mask = original_mask[::-1].copy()
        elif m == 2:
            image = original_image[:, ::-1, :].copy()
            mask = original_mask[:, ::-1, :].copy()
        elif m == 3:
            image = original_image[:, :, ::-1].copy()
            mask = original_mask[:, :, ::-1].copy()

        for n in range(4):
            print("r", n)
            if n == 0:
                rot1, rot2 = 0, 0
            else:
                rot1 = random.randint(1, 359)
                rot2 = random.randint(1, 359)

            patches_added = 0

            _image = apply_transform(image, rot1, rot2)
            _image = image_normalize(_image)

            _mask = apply_transform(mask, rot1, rot2)
            _mask = image_normalize(_mask)

            sz, sy, sx = _image.shape
            patches = list(
                itertools.product(
                    range(0, sz, patch_size - OVERLAP),
                    range(0, sy, patch_size - OVERLAP),
                    range(0, sx, patch_size - OVERLAP),
                ))
            random.shuffle(patches)

            for patch in patches:
                sub_image = get_image_patch(_image, patch, patch_size)
                sub_mask = get_image_patch(_mask, patch, patch_size)
                if sub_mask.any():
                    # Keep only patches whose mask contains some foreground.
                    tmp_filename = mktemp(suffix=".npz")
                    np.savez(tmp_filename, image=sub_image, mask=sub_mask)
                    del sub_image
                    del sub_mask
                    patches_files.append(tmp_filename)
                    patches_added += 1
                if patches_added == num_patches:
                    break

    return patches_files
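The function returns paths to temporary .npz files rather than the arrays themselves; reading them back would look like this (a sketch, with `image_file` and `mask_file` assumed):

import numpy as np

# Hypothetical read-back of the patch files written above.
for tmp_filename in gen_image_patches((image_file, mask_file)):
    data = np.load(tmp_filename)
    sub_image, sub_mask = data["image"], data["mask"]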