def main(_config):
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = MTDataModule(_config, dist=True)
    print(_config)

    model = ViLTransformerSS(_config)
    exp_name = f'{_config["exp_name"]}'

    os.makedirs(_config["log_dir"], exist_ok=True)
    # Keep the single best checkpoint by the aggregated validation metric,
    # plus the most recent one.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=exp_name,
    )
    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")

    callbacks = [checkpoint_callback, lr_callback]
    if _config["huawei_target_dir"] is not None:
        callbacks.append(
            MoveMosCKPT(_config["huawei_flag"], _config["huawei_target_dir"]))

    num_gpus = (_config["num_gpus"] if isinstance(_config["num_gpus"], int)
                else len(_config["num_gpus"]))

    # Accumulate gradients so that per_gpu_batchsize * num_gpus * num_nodes
    # * grad_steps equals the configured effective batch size.
    grad_steps = _config["batch_size"] // (_config["per_gpu_batchsize"] *
                                           num_gpus * _config["num_nodes"])
    print(f"N_GPUS: {num_gpus}, grad_steps: {grad_steps}")

    max_steps = _config["max_steps"]

    trainer_kwargs = dict(
        gpus=_config["num_gpus"],
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        accelerator="ddp",
        benchmark=True,
        deterministic=True,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        prepare_data_per_node=False,
        replace_sampler_ddp=False,
        accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        flush_logs_every_n_steps=10,
        progress_bar_refresh_rate=50,
        resume_from_checkpoint=_config["resume_from"],
        weights_summary="top",
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )
    if _config["proxy_dataset_debug"]:
        # Debug run: refresh the progress bar more often and restrict each
        # epoch to a single train/val batch.
        trainer_kwargs.update(
            progress_bar_refresh_rate=5,
            limit_train_batches=1,
            limit_val_batches=1,
        )
    trainer = pl.Trainer(**trainer_kwargs)

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)
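# A worked example of the gradient-accumulation arithmetic above; the numbers
# are illustrative assumptions, not values from any shipped config. With an
# effective batch size of 4096, 64 samples per GPU, 8 GPUs, and a single node:
#
#     grad_steps = 4096 // (64 * 8 * 1)  # == 8
#
# so each optimizer step accumulates 8 per-GPU batches and sees
# 64 * 8 * 1 * 8 = 4096 samples, matching _config["batch_size"].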
def main(_config):
    _config = copy.deepcopy(_config)

    # VQA-only inference: every other objective head is zeroed out.
    loss_names = {
        "itm": 0,
        "mlm": 0,
        "mpp": 0,
        "vqa": 1,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    # Mapping from VQAv2 answer ids to answer strings.
    with urllib.request.urlopen(
            "https://dl.dropboxusercontent.com/s/otya4i5sagt4f5p/vqa_dict.json"
    ) as url:
        id2ans = json.loads(url.read().decode())

    _config.update({"loss_names": loss_names})

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    def infer(url, text):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            return False

        batch = {"text": [text], "image": [img]}

        with torch.no_grad():
            encoded = tokenizer(batch["text"])
            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
            out = model.infer(batch)
            vqa_logits = model.vqa_classifier(out["cls_feats"])

        answer = id2ans[str(vqa_logits.argmax().item())]
        return [np.array(image), answer]

    inputs = [
        gr.inputs.Textbox(label="Url of an image.", lines=5),
        gr.inputs.Textbox(label="Question", lines=5),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="Answer"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "What is the color of the flower?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_1.png",
                "What is the mustache made of?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_2.png",
                "How many slices of pizza are there?",
            ],
            [
                "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_3.png",
                "Does it appear to be rainy?",
            ],
        ],
    )

    interface.launch(debug=True)
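# The Gradio interface above is only a wrapper; as a sketch, assuming a
# VQA-finetuned checkpoint has been loaded into `model`, the `infer` closure
# could be invoked directly inside main (the URL below is one of the bundled
# demo examples, not a required input):
#
#     image_arr, answer = infer(
#         "https://computing.ece.vt.edu/~harsh/visualAttention/ProjectWebpage/Figures/vqa_2.png",
#         "How many slices of pizza are there?",
#     )
#     print(answer)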
def main(_config):
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = MTDataModule(_config, dist=True)

    model = ViLTransformerSS(_config)
    exp_name = f'{_config["exp_name"]}'

    os.makedirs(_config["log_dir"], exist_ok=True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
    )
    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lr_callback]

    num_gpus = (_config["num_gpus"] if isinstance(_config["num_gpus"], int)
                else len(_config["num_gpus"]))

    grad_steps = _config["batch_size"] // (_config["per_gpu_batchsize"] *
                                           num_gpus * _config["num_nodes"])

    max_steps = _config["max_steps"]

    trainer = pl.Trainer(
        gpus=_config["num_gpus"],
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        accelerator="ddp",
        benchmark=True,
        deterministic=True,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        prepare_data_per_node=False,
        replace_sampler_ddp=False,
        accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        flush_logs_every_n_steps=10,
        resume_from_checkpoint=_config["resume_from"],
        weights_summary="top",
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)
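# Example of the logger naming scheme above; the concrete paths are
# illustrative assumptions. With exp_name "finetune_vqa", seed 0, and
# load_path "weights/vilt_200k_mlm_itm.ckpt", the expression
# _config["load_path"].split("/")[-1][:-5] drops the directory and the
# ".ckpt" suffix, so the run is logged under
# "finetune_vqa_seed0_from_vilt_200k_mlm_itm". Note that [:-5] assumes the
# checkpoint filename ends in the five-character suffix ".ckpt".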
def main(_config):
    _config = copy.deepcopy(_config)

    # MLM-only inference: only the masked-language-modeling head is active.
    loss_names = {
        "itm": 0,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])
    _config.update({"loss_names": loss_names})

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    def infer(url, mp_text, hidx):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            return False

        batch = {"text": [""], "image": [None]}
        tl = len(re.findall(r"\[MASK\]", mp_text))
        inferred_token = [mp_text]
        batch["image"][0] = img

        # Greedily fill one [MASK] per pass: pick the mask position with the
        # most confident prediction, substitute it, and re-run the model.
        with torch.no_grad():
            for _ in range(tl):
                batch["text"] = inferred_token
                encoded = tokenizer(inferred_token)
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                encoded = encoded["input_ids"][0][1:-1]
                out = model(batch)
                mlm_logits = model.mlm_score(out["text_feats"])[0, 1:-1]
                mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
                # 103 is the [MASK] token id in the BERT vocabulary; zero out
                # confidences at non-mask positions.
                mlm_values[torch.tensor(encoded) != 103] = 0
                select = mlm_values.argmax().item()
                encoded[select] = mlm_ids[select].item()
                inferred_token = [tokenizer.decode(encoded)]

        selected_token = ""
        encoded = tokenizer(inferred_token)

        if 0 < hidx < len(encoded["input_ids"][0][:-1]):
            with torch.no_grad():
                batch["text"] = inferred_token
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                out = model(batch)
                txt_emb, img_emb = out["text_feats"], out["image_feats"]
                txt_mask, img_mask = (
                    out["text_masks"].bool(),
                    out["image_masks"].bool(),
                )
                # Mask out [SEP] (last real text token) and the leading
                # [CLS]/class positions of both modalities before transport.
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask

                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                txt_len = (txt_pad.size(1) -
                           txt_pad.sum(dim=1, keepdim=False)).to(dtype=cost.dtype)
                img_len = (img_pad.size(1) -
                           img_pad.sum(dim=1, keepdim=False)).to(dtype=cost.dtype)

                # Inexact proximal point optimal transport between text
                # tokens and image patches.
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )

                plan = T[0]
                plan_single = plan * len(txt_emb)
                cost_ = plan_single.t()
                cost_ = cost_[hidx][1:].cpu()

                # Scatter per-patch transport weights into an (H, W) heatmap.
                patch_index, (H, W) = out["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    h, w = pidx[0].item(), pidx[1].item()
                    heatmap[h, w] = cost_[i]

                # Standardize, clip to [1, 3] standard deviations, and
                # rescale to [0, 1] for use as an alpha mask.
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST)
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                image = image_rgba

                selected_token = tokenizer.convert_ids_to_tokens(
                    encoded["input_ids"][0][hidx])

        return [np.array(image), inferred_token[0], selected_token]

    inputs = [
        gr.inputs.Textbox(label="Url of an image.", lines=5),
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.",
                          lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="Description"),
        gr.outputs.Textbox(label="Selected token"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
                0,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                4,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                11,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                15,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                18,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
                0,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                5,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                8,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                11,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                15,
            ],
        ],
    )

    interface.launch(debug=True)
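# `cost_matrix_cosine` and `ipot` are imported from elsewhere in the repo; a
# minimal sketch of what the cosine cost matrix computes, assuming both inputs
# are (batch, seq_len, dim) float tensors (a sketch only, not the repo's
# implementation):
#
#     def cost_matrix_cosine_sketch(x, y, eps=1e-5):
#         """cost[b, i, j] = 1 - cosine_similarity(x[b, i], y[b, j])."""
#         x = x / x.norm(dim=-1, keepdim=True).clamp(min=eps)
#         y = y / y.norm(dim=-1, keepdim=True).clamp(min=eps)
#         return 1 - torch.bmm(x, y.transpose(1, 2))
#
# The `ipot` call then finds a transport plan T over this cost; row `hidx` of
# the scaled, transposed plan gives per-patch weights for the chosen text
# token, which the code above scatters into the (H, W) heatmap.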