Example 1

import copy
import os

import pytorch_lightning as pl

# Project imports; module paths as in the upstream ViLT repo. MoveMosCKPT is
# a fork-specific callback assumed to be defined alongside this script.
from vilt.datamodules.multitask_datamodule import MTDataModule
from vilt.modules import ViLTransformerSS


def main(_config):
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = MTDataModule(_config, dist=True)
    print(_config)

    model = ViLTransformerSS(_config)
    exp_name = _config["exp_name"]
    os.makedirs(_config["log_dir"], exist_ok=True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=exp_name,
    )

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")

    if _config["huawei_target_dir"] is not None:
        moveckpt_callback = MoveMosCKPT(_config["huawei_flag"],
                                        _config["huawei_target_dir"])
        callbacks = [checkpoint_callback, lr_callback, moveckpt_callback]
    else:
        callbacks = [checkpoint_callback, lr_callback]

    num_gpus = (_config["num_gpus"] if isinstance(_config["num_gpus"], int)
                else len(_config["num_gpus"]))
    grad_steps = _config["batch_size"] // (_config["per_gpu_batchsize"] *
                                           num_gpus * _config["num_nodes"])
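    # Worked example with illustrative numbers (not from any real config):
    # batch_size=4096, per_gpu_batchsize=64, num_gpus=8, num_nodes=1
    # -> grad_steps = 4096 // (64 * 8 * 1) = 8, i.e. eight micro-batches
    # are accumulated per optimizer step to reach the global batch size.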

    print(f'N_GPUS: {num_gpus}, grad_steps: {grad_steps}')

    max_steps = _config["max_steps"]  # None means no step limit for pl.Trainer

    if _config["proxy_dataset_debug"] is True:
        trainer = pl.Trainer(
            gpus=_config["num_gpus"],
            num_nodes=_config["num_nodes"],
            precision=_config["precision"],
            accelerator="ddp",
            benchmark=True,
            deterministic=True,
            max_epochs=_config["max_epoch"] if max_steps is None else 1000,
            max_steps=max_steps,
            callbacks=callbacks,
            logger=logger,
            prepare_data_per_node=False,
            replace_sampler_ddp=False,
            accumulate_grad_batches=grad_steps,
            log_every_n_steps=10,
            flush_logs_every_n_steps=10,
            progress_bar_refresh_rate=5,
            resume_from_checkpoint=_config["resume_from"],
            weights_summary="top",
            fast_dev_run=_config["fast_dev_run"],
            val_check_interval=_config["val_check_interval"],
            limit_train_batches=1,
            limit_val_batches=1)
    else:
        trainer = pl.Trainer(
            gpus=_config["num_gpus"],
            num_nodes=_config["num_nodes"],
            precision=_config["precision"],
            accelerator="ddp",
            benchmark=True,
            deterministic=True,
            max_epochs=_config["max_epoch"] if max_steps is None else 1000,
            max_steps=max_steps,
            callbacks=callbacks,
            logger=logger,
            prepare_data_per_node=False,
            replace_sampler_ddp=False,
            accumulate_grad_batches=grad_steps,
            log_every_n_steps=10,
            flush_logs_every_n_steps=10,
            progress_bar_refresh_rate=50,
            resume_from_checkpoint=_config["resume_from"],
            weights_summary="top",
            fast_dev_run=_config["fast_dev_run"],
            val_check_interval=_config["val_check_interval"],
        )

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)
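
How this entry point gets called is outside the listing: in the upstream ViLT repository, run.py wires main through Sacred. A minimal sketch of that wiring, assuming the usual vilt.config experiment object (the decorator and the "with"-override syntax come from Sacred, not from this listing):

from vilt.config import ex

@ex.automain
def main(_config):
    ...  # body as in Example 1 above

# Config overrides then flow in through Sacred on the command line, e.g.:
#   python run.py with num_gpus=8 num_nodes=1 per_gpu_batchsize=64
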
Example 2

import copy
import io
import json
import urllib.request

import gradio as gr  # written against the legacy gr.inputs/gr.outputs API
import numpy as np
import requests
import torch
from PIL import Image

# Project imports; module paths as in the upstream ViLT repo.
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer
from vilt.modules import ViLTransformerSS
from vilt.transforms import pixelbert_transform


def main(_config):
    _config = copy.deepcopy(_config)

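    # Per-task loss weights; only the VQA head is active for this demo.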
    loss_names = {
        "itm": 0,
        "mlm": 0,
        "mpp": 0,
        "vqa": 1,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

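    # Fetch the VQA answer-id -> answer-string vocabulary used to decode
    # the classifier output.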
    with urllib.request.urlopen(
            "https://dl.dropboxusercontent.com/s/otya4i5sagt4f5p/vqa_dict.json"
    ) as url:
        id2ans = json.loads(url.read().decode())

    _config.update({
        "loss_names": loss_names,
    })

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    def infer(url, text):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            # Bad URL or undecodable image.
            return False

        batch = {"text": [text], "image": [img]}

        with torch.no_grad():
            encoded = tokenizer(batch["text"])
            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
            # "out" rather than "infer", so the function's own name
            # is not shadowed.
            out = model.infer(batch)
            vqa_logits = model.vqa_classifier(out["cls_feats"])

        answer = id2ans[str(vqa_logits.argmax().item())]

        return [np.array(image), answer]

    inputs = [
        gr.inputs.Textbox(
            label="Url of an image.",
            lines=5,
        ),
        gr.inputs.Textbox(label="Question", lines=5),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="Answer"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "What is the color of the flower?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_1.png",
                "What is the mustache made of?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_2.png",
                "How many slices of pizza are there?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_3.png",
                "Does it appear to be rainy?",
            ],
        ],
    )

    interface.launch(debug=True)
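
A hedged usage sketch: assuming this demo is saved as demo_vqa.py, is driven by Sacred like the training script, and a fine-tuned VQA checkpoint exists locally (the file and checkpoint names here are illustrative), it would be launched along these lines:

#   python demo_vqa.py with num_gpus=0 load_path="weights/vilt_vqa.ckpt" test_only=True
# num_gpus=0 runs on CPU; set num_gpus=1 to use the first GPU instead.
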
Example 3

import copy
import os

import pytorch_lightning as pl

# Project imports; module paths as in the upstream ViLT repo.
from vilt.datamodules.multitask_datamodule import MTDataModule
from vilt.modules import ViLTransformerSS


def main(_config):
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = MTDataModule(_config, dist=True)

    model = ViLTransformerSS(_config)
    exp_name = _config["exp_name"]

    os.makedirs(_config["log_dir"], exist_ok=True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    # [:-5] strips the ".ckpt" extension from the checkpoint filename.
    run_name = (f'{exp_name}_seed{_config["seed"]}'
                f'_from_{_config["load_path"].split("/")[-1][:-5]}')
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=run_name,
    )

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lr_callback]

    num_gpus = (_config["num_gpus"] if isinstance(_config["num_gpus"], int)
                else len(_config["num_gpus"]))

    grad_steps = _config["batch_size"] // (_config["per_gpu_batchsize"] *
                                           num_gpus * _config["num_nodes"])

    max_steps = _config["max_steps"]  # None means no step limit for pl.Trainer

    trainer = pl.Trainer(
        gpus=_config["num_gpus"],
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        accelerator="ddp",
        benchmark=True,
        deterministic=True,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        prepare_data_per_node=False,
        replace_sampler_ddp=False,
        accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        flush_logs_every_n_steps=10,
        resume_from_checkpoint=_config["resume_from"],
        weights_summary="top",
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)
Example 4

import copy
import io
import re

import gradio as gr  # legacy gr.inputs/gr.outputs API
import numpy as np
import requests
import torch
from PIL import Image

# Project imports; module paths as in the upstream ViLT repo.
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer
from vilt.modules import ViLTransformerSS
from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform


def main(_config):
    _config = copy.deepcopy(_config)

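    # Per-task loss weights; a fractional MLM weight builds the MLM head
    # without registering MLM as an active training task.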
    loss_names = {
        "itm": 0,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    _config.update({
        "loss_names": loss_names,
    })

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    def infer(url, mp_text, hidx):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            # Bad URL or undecodable image.
            return False

        batch = {"text": [""], "image": [None]}
        tl = len(re.findall(r"\[MASK\]", mp_text))
        inferred_token = [mp_text]
        batch["image"][0] = img

        with torch.no_grad():
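            # Iterative mask filling: each pass scores every position,
            # commits the single most confident [MASK] prediction,
            # re-tokenizes, and repeats until all masks are filled.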
            for i in range(tl):
                batch["text"] = inferred_token
                encoded = tokenizer(inferred_token)
                batch["text_ids"] = torch.tensor(
                    encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(
                    encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(
                    encoded["attention_mask"]).to(device)
                encoded = encoded["input_ids"][0][1:-1]
                out = model(batch)
                mlm_logits = model.mlm_score(out["text_feats"])[0, 1:-1]
                mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
                # Only consider positions that still hold [MASK]
                # (103 is BERT's [MASK] token id).
                mlm_values[torch.tensor(encoded) != 103] = 0
                select = mlm_values.argmax().item()
                encoded[select] = mlm_ids[select].item()
                inferred_token = [tokenizer.decode(encoded)]

        selected_token = ""
        encoded = tokenizer(inferred_token)

        if 0 < hidx < len(encoded["input_ids"][0][:-1]):
            with torch.no_grad():
                batch["text"] = inferred_token
                batch["text_ids"] = torch.tensor(
                    encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(
                    encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(
                    encoded["attention_mask"]).to(device)
                out = model(batch)
                txt_emb, img_emb = out["text_feats"], out["image_feats"]
                txt_mask, img_mask = (
                    out["text_masks"].bool(),
                    out["image_masks"].bool(),
                )
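                # Exclude [CLS], [SEP], and the image CLS token from the
                # alignment by masking them out.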
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask

                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                txt_len = (txt_pad.size(1) -
                           txt_pad.sum(dim=1, keepdim=False)).to(
                               dtype=cost.dtype)
                img_len = (img_pad.size(1) -
                           img_pad.sum(dim=1, keepdim=False)).to(
                               dtype=cost.dtype)
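                # Approximate an optimal-transport alignment between text
                # tokens and image patches with IPOT; the trailing constants
                # are the solver's hyper-parameters.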
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )

                plan = T[0]
                plan_single = plan * len(txt_emb)
                cost_ = plan_single.t()

                cost_ = cost_[hidx][1:].cpu()

                patch_index, (H, W) = out["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    h, w = pidx[0].item(), pidx[1].item()
                    heatmap[h, w] = cost_[i]

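                # Standardize, keep roughly the patches at least one std
                # above the mean (clip to [1, 3]), then min-max scale to
                # [0, 1] for the alpha overlay.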
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() -
                                                       heatmap.min())

                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST)
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                image = image_rgba

                selected_token = tokenizer.convert_ids_to_tokens(
                    encoded["input_ids"][0][hidx])

        return [np.array(image), inferred_token[0], selected_token]

    inputs = [
        gr.inputs.Textbox(
            label="Url of an image.",
            lines=5,
        ),
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.",
                          lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="description"),
        gr.outputs.Textbox(label="selected token"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
                0,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                4,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                11,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                15,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                18,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
                0,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                5,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                8,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                11,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                15,
            ],
        ],
    )

    interface.launch(debug=True)
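
For reference, a minimal sketch of what the imported alignment helper computes, assuming the standard batched cosine-distance definition (not a verbatim copy of vilt.modules.objectives.cost_matrix_cosine):

import torch
import torch.nn.functional as F

def cost_matrix_cosine(x, y, eps=1e-5):
    # x: (B, Lx, D) text features; y: (B, Ly, D) image features.
    # Returns (B, Lx, Ly) = 1 - cosine similarity, the transport cost
    # that ipot() then approximately minimizes.
    x = F.normalize(x, p=2, dim=-1, eps=eps)
    y = F.normalize(y, p=2, dim=-1, eps=eps)
    return 1 - x.matmul(y.transpose(-2, -1))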