import json
import os

import torch
import torchvision.transforms as transforms

# DecoderWithAttention, Encoder, caption_image_beam_search, visualize_att,
# device and the hyperparameters used below are assumed to be defined or
# imported elsewhere in this file.


def main(imgurl):
    # Load word map (word2ix)
    with open('input_files/WORDMAP.json', 'r') as j:
        word_map = json.load(j)
    rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

    # Load model
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout)
    # No optimizers are needed here: this function only runs inference
    # with the pretrained weights loaded below.
    encoder = Encoder()

    # Load the pretrained weights, mapping to the active device so that
    # GPU-trained checkpoints also load on CPU-only machines
    decoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_decoder.pth.tar',
                   map_location=device))
    encoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_encoder.pth.tar',
                   map_location=device))

    decoder = decoder.to(device)
    decoder.eval()
    encoder = encoder.to(device)
    encoder.eval()

    # Encode, decode with attention and beam search
    seq, alphas = caption_image_beam_search(encoder,
                                            decoder,
                                            imgurl,
                                            word_map,
                                            beam_size=5)
    alphas = torch.FloatTensor(alphas)

    # Decode the index sequence to words, dropping <start> and <end>
    words = [rev_word_map[ind] for ind in seq]
    caption = ' '.join(words[1:-1])
    print(caption)

    # Visualize the attention weights of the best sequence
    visualize_att(imgurl, seq, alphas, rev_word_map)
    return caption
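
# Example invocation of main(); the image path below is hypothetical:
#
#     main('input_files/example.jpg')
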
# Read word map
word_map_file = os.path.join('/scratch/scratch2/adsue/caption_dataset',
                             'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
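
# The word map is expected to map tokens to integer indices, including the
# special tokens '<start>', '<end>', '<unk>' and '<pad>' used in decoding.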

decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(word_map),
                               dropout=dropout)

decoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/decoder_dict.pkl',
               map_location=device))
decoder = decoder.to(device)
decoder.eval()

encoder = Encoder()
encoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/encoder_dict.pkl',
               map_location=device))
encoder = encoder.to(device)
encoder.eval()
##########################################################################################################################

imsize = 256

# Resize, then crop to the model's expected square input size
image_transform = transforms.Compose(
    [transforms.Resize(int(imsize * 76 / 64)),
     transforms.RandomCrop(imsize)])

# Convert to a tensor and normalize. The mean/std values are an assumption:
# the standard ImageNet statistics commonly used with pretrained CNN encoders.
norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])
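
# A short sketch of how these transforms feed the encoder. encode_image is a
# hypothetical helper, not part of the original pipeline; the encoder's output
# shape depends on the Encoder definition.
from PIL import Image


def encode_image(path):
    """Load an image file and return its encoder feature map."""
    img = Image.open(path).convert('RGB')
    # PIL image -> resized/cropped PIL image -> normalized tensor, batch dim 1
    img = norm(image_transform(img)).unsqueeze(0).to(device)
    with torch.no_grad():
        return encoder(img)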