def testMultiExtremalPerturbationWithSmoothMask():
    """Smoke-run ``multi_extremal_perturbation`` with a smoothed mask.

    Wraps the pretrained ``mmbt.hateful_memes.images`` MMBT model in
    ``MMBTGridHMInterfaceOnlyImage`` together with a fixed caption, moves it to
    CUDA device 0 when available (CPU otherwise), converts the image at a fixed
    URL to a tensor via the wrapper's ``imageToTensor``, and runs the
    perturbation for target class 0 with ``contrastive_reward``,
    ``smooth=0.5``, a single area of 0.12 and 200 iterations, with
    ``debug=True`` and ``show_text_result=True``.

    Nothing is asserted: the run only has to complete and return exactly two
    values (the tuple unpack below raises otherwise).

    NOTE(review): the image is fetched from a remote URL, so this needs
    network access.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    caption = "How I want to say hello to Asian people"
    wrapped = MMBTGridHMInterfaceOnlyImage(
        MMBT.from_pretrained("mmbt.hateful_memes.images"), caption
    ).to(device)

    meme_url = "https://img.17qq.com/images/ghhngkfnkwy.jpeg"
    image = wrapped.imageToTensor(meme_url)
    # NOTE: the original author suggests commenting out this device move if it
    # raises a device-related error.
    image = image.to(device)

    # Prepend a size-1 leading dimension (presumably the batch axis).
    batch = torch.unsqueeze(image, 0)

    _out, out = multi_extremal_perturbation(
        wrapped,
        batch,
        meme_url,
        caption,
        0,
        reward_func=contrastive_reward,
        debug=True,
        max_iter=200,
        areas=[0.12],
        smooth=0.5,
        show_text_result=True,
    )
def main():
    """Interactively run ``multi_extremal_perturbation`` on a user image and text.

    Prompts for an image path and a text, loads the pretrained
    ``mmbt.hateful_memes.images`` MMBT model onto CUDA device 0 when available
    (CPU otherwise), converts the image with ``image2tensor``, places it on the
    same device as the model, and runs ``multi_extremal_perturbation`` for
    target class 0 (non-hateful) with ``max_iter=50`` and a single area of
    0.12.

    Returns:
        ``(output_tensor, txt_summary, text_explanation)``: the third, fourth
        and fifth of the five values returned by
        ``multi_extremal_perturbation``.
    """
    image_path = input("enter your image path : ")
    text = input("enter your text : ")

    # Resolve the device once so the model and its input are guaranteed to agree.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = MMBT.from_pretrained("mmbt.hateful_memes.images")
    # nn.Module.to moves parameters in place and returns the module itself;
    # reassigning is harmless and keeps this correct if a copy were returned.
    model = model.to(device)

    # Tensor.to is NOT in-place: it returns a tensor on the target device, so
    # the result must be reassigned. Without this the input can stay on CPU
    # while the model sits on CUDA, causing a device mismatch in the forward
    # pass. It is a no-op when the tensor is already on ``device``.
    # NOTE(review): unlike the smooth-mask test in this file, no unsqueeze is
    # applied and MMBT is not wrapped in MMBTGridHMInterfaceOnlyImage; confirm
    # image2tensor / multi_extremal_perturbation account for that.
    image_tensor = image2tensor(image_path).to(device)

    _mask, _hist, output_tensor, txt_summary, text_explanation = (
        multi_extremal_perturbation(
            model,
            image_tensor,
            image_path,
            text,
            0,  # target class: 0 = non-hateful, 1 = hateful
            max_iter=50,
            areas=[0.12],
        )
    )
    return output_tensor, txt_summary, text_explanation