Example #1
def inference_coco(encoder_file: str, decoder_file: str, embed_size: int,
                   hidden_size: int, from_cpu: bool) -> None:
    """
    Displays an original image from the COCO test dataset and prints its associated caption.

    encoder_file:   Name of the encoder to load.
    decoder_file:   Name of the decoder to load.
    embed_size:     Word embedding size for the encoder.
    hidden_size:    Size of the LSTM hidden layer.
    from_cpu:       Whether the model has been saved on CPU.
    """
    # Define transform
    transform_test = transforms.Compose([
        transforms.Resize(256),  # smaller edge of image resized to 256
        transforms.RandomCrop(224),  # get 224x224 crop from random location
        transforms.ToTensor(),  # convert the PIL Image to a tensor
        transforms.Normalize(
            (0.485, 0.456, 0.406),  # normalize image for pre-trained model
            (0.229, 0.224, 0.225))
    ])

    # Device to use for inference
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the data loader.
    data_loader = get_loader(transform=transform_test, mode='test')

    # Obtain sample image
    _, image = next(iter(data_loader))

    # The size of the vocabulary.
    vocab_size = len(data_loader.dataset.vocab)

    # Initialize the encoder and decoder, and set each to inference mode.
    encoder = EncoderCNN(embed_size)
    encoder.eval()
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    decoder.eval()

    # Load the trained weights.
    if from_cpu:
        encoder.load_state_dict(
            torch.load(os.path.join('./models', encoder_file),
                       map_location='cpu'))
        decoder.load_state_dict(
            torch.load(os.path.join('./models', decoder_file),
                       map_location='cpu'))
    else:
        encoder.load_state_dict(
            torch.load(os.path.join('./models', encoder_file)))
        decoder.load_state_dict(
            torch.load(os.path.join('./models', decoder_file)))

    # Move models to GPU if CUDA is available.
    encoder.to(device)
    decoder.to(device)

    get_prediction(encoder, decoder, data_loader, device)
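
The get_prediction helper called above is not shown in this example. A minimal sketch of what it might look like, assuming the test loader yields (original_image, preprocessed_tensor) pairs and the decoder exposes a sample method returning a list of word indices (both assumptions taken from the code above):

def get_prediction(encoder, decoder, data_loader, device):
    # Hypothetical helper: caption a single test image and print the result.
    # (Displaying the original image is omitted here.)
    _, image = next(iter(data_loader))
    image = image.to(device)
    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)  # assumed to return a list of word ids
    vocab = data_loader.dataset.vocab
    words = []
    for idx in output:
        word = vocab.idx2word[idx]
        if word == '<end>':
            break
        if word != '<start>':
            words.append(word)
    print(' '.join(words))
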
Example #2
def main():
    st.title('Image Captioning App')
    st.markdown(STYLE, unsafe_allow_html=True)

    file = st.file_uploader("Upload file", type=["png", "jpg", "jpeg"])
    show_file = st.empty()

    if not file:
        show_file.info("Please upload a file of type: " +
                       ", ".join(["png", "jpg", "jpeg"]))
        return

    content = file.getvalue()

    show_file.image(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder_file = 'encoder-5-batch-128-hidden-256-epochs-5.pkl'
    decoder_file = 'decoder-5-batch-128-hidden-256-epochs-5.pkl'

    embed_size = 300
    hidden_size = 256

    vocab_size, word2idx, idx2word = get_vocab()

    encoder = EncoderCNN(embed_size)
    encoder.eval()
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    decoder.eval()

    encoder.load_state_dict(torch.load(os.path.join('./models', encoder_file)))
    decoder.load_state_dict(torch.load(os.path.join('./models', decoder_file)))

    encoder.to(device)
    decoder.to(device)

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    PIL_image = Image.open(file).convert('RGB')
    orig_image = np.array(PIL_image)
    image = transform_test(PIL_image)
    image = image.to(device).unsqueeze(0)
    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)

    sentence = clean_sentence(output, idx2word)
    st.info("Generated caption --> " + sentence)

    file.close()
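
clean_sentence is used above but not defined in this snippet. A minimal sketch, assuming the decoder output is a flat list of word indices and idx2word maps indices to tokens including the <start>/<end> markers:

def clean_sentence(output, idx2word):
    # Hypothetical helper: turn word indices into a readable caption.
    words = []
    for idx in output:
        word = idx2word[idx]
        if word == '<end>':
            break
        if word != '<start>':
            words.append(word)
    return ' '.join(words)
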
Example #3
        def generatecaption(image):
            # Image preprocessing
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])

            # Load vocabulary wrapper
            with open('/root/ImageCaptioning/data/vocab.pkl', 'rb') as f:
                vocab = pickle.load(f)

            # Build models
            encoder = EncoderCNN(256).eval()  # eval mode (batchnorm uses moving mean/variance)
            decoder = DecoderRNN(256, 512, len(vocab), 1)
            encoder = encoder.to(device)
            decoder = decoder.to(device)

            # Load the trained model parameters
            encoder.load_state_dict(torch.load('models/encoder-5-3000.pkl', map_location='cpu'))
            decoder.load_state_dict(torch.load('models/decoder-5-3000.pkl', map_location='cpu'))

            encoder.eval()
            decoder.eval()
            # Prepare an image
            image = load_image(image, transform)
            image_tensor = image.to(device)

            # Generate a caption from the image
            feature = encoder(image_tensor)
            sampled_ids = decoder.sample(feature)
            sampled_ids = sampled_ids[0].cpu().numpy()          # (1, max_seq_length) -> (max_seq_length)

            # Convert word_ids to words
            sampled_caption = []
            for word_id in sampled_ids:
                word = vocab.idx2word[word_id]
                sampled_caption.append(word)
                if word == '<end>':
                    break
            self.sentence = ' '.join(sampled_caption)

            # Print out the image and the generated caption


            self.Entry1.delete(0, END)
            self.Entry1.insert(0,self.sentence[7:-5])
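
load_image is referenced in this and later examples but not shown. A minimal sketch, assuming it opens the file with PIL, resizes it to the encoder's input size, and applies the given transform with a batch dimension added:

from PIL import Image

def load_image(image_path, transform=None):
    # Hypothetical helper: load an RGB image and return a (1, C, H, W) tensor.
    image = Image.open(image_path).convert('RGB')
    image = image.resize((224, 224), Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image
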
Example #4
def get_text_caption(image):

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = EncoderCNN(
        args.embed_size, args.model_type,
        args.mode)  # eval mode (batchnorm uses moving mean/variance)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    encoder.load_state_dict(
        torch.load(args.model_path + "_" + args.model_type + "/encoder.pt"))
    encoder.eval()
    decoder.load_state_dict(
        torch.load(args.model_path + "_" + args.model_type + "/decoder.pt"))
    decoder.eval()

    # Prepare an image
    image_tensor = image.to(device)

    # Generate a caption from the image
    feature = encoder(image_tensor)

    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)
    print(sampled_ids)

    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    return (sentence.split("<start> ")[1].split(" <end>")[0]
            [:-2].capitalize().replace(" , ", ", "))
Example #5
class Annotator():
    def __init__(self):
        self.transform = transforms.Compose([ 
            transforms.Resize(256),                          # smaller edge of image resized to 256
            transforms.CenterCrop(224),                      # get 224x224 crop from the center
            transforms.ToTensor(),                           # convert the PIL Image to a tensor
            transforms.Normalize((0.485, 0.456, 0.406),      # normalize image for pre-trained model
                                 (0.229, 0.224, 0.225))])
        
        # Load checkpoint with best model
        self.checkpoint = torch.load(os.path.join('./models', 'best-model.pkl'), 'cpu')
        # Specify values for embed_size and hidden_size - we use the same values as in training step
        self.embed_size = 512
        self.hidden_size = 512

        # Get the vocabulary and its size
        self.vocab = Vocabulary(None, './vocab.pkl', "<start>", "<end>", "<unk>", "<pad>", "", "", True)
        self.vocab_size = len(self.vocab)

        # Initialize the encoder and decoder, and set each to inference mode
        self.encoder = EncoderCNN(self.embed_size)
        self.encoder.eval()
        self.decoder = DecoderRNN(self.embed_size, self.hidden_size, self.vocab_size)
        self.decoder.eval()

        # Load the pre-trained weights
        self.encoder.load_state_dict(self.checkpoint['encoder'])
        self.decoder.load_state_dict(self.checkpoint['decoder'])

        # Move models to GPU if CUDA is available.
        #if torch.cuda.is_available():
        #   encoder.cuda()
        #   decoder.cuda()

    def annotate(self, image):
        transformed = self.transform(image).unsqueeze(0)
        features = self.encoder(transformed).unsqueeze(1)

        # Pass the embedded image features through the model to get a predicted caption.
        output = self.decoder.sample_beam_search(features)
        print('example output:', output)
        sentence = clean_sentence(output[0], self.vocab)
        print('example sentence:', sentence)
        return sentence
Example #6
    def get_caption(self, img_tensor):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(device)
        print("running")

        # Models
        encoder_file = 'legit_model/encoder_1.pkl'
        decoder_file = 'legit_model/decoder_1.pkl'

        # Embed and hidden
        embed_size = 512
        hidden_size = 512

        # The size of the vocabulary.
        vocab_size = 8856

        # Initialize the encoder and decoder, and set each to inference mode.
        encoder = EncoderCNN(embed_size)
        encoder.eval()

        decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
        decoder.eval()

        # Load the trained weights.
        encoder.load_state_dict(
            torch.load(os.path.join('./models', encoder_file)))
        decoder.load_state_dict(
            torch.load(os.path.join('./models', decoder_file)))

        # Move models to GPU if CUDA is available.
        encoder.to(device)
        decoder.to(device)

        img_d = img_tensor.to(device)

        # Obtain the embedded image features.
        features = encoder(img_d).unsqueeze(1)

        # Pass the embedded image features through the model to get a predicted caption.
        img_output = decoder.sample(features)

        sentence = self.clean_sentence(img_output)

        return sentence
Example #7
def initialize():
    checkpoint = torch.load(os.path.join('./models', 'best-model.pkl'), map_location=torch.device('cpu'))

    embed_size = 256
    hidden_size = 512

    with open('./vocab.pkl', "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)

    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    encoder.eval()
    decoder.eval()

    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    return encoder, decoder, vocab
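
A possible way to use the objects returned by initialize(); the transform, the image path, and the decoder's sample method are assumptions carried over from the other examples:

encoder, decoder, vocab = initialize()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # hypothetical file
features = encoder(image).unsqueeze(1)
word_ids = decoder.sample(features)
words = []
for idx in word_ids:
    word = vocab.idx2word[int(idx)]
    if word == '<end>':
        break
    if word != '<start>':
        words.append(word)
print(' '.join(words))
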
Example #8
def evaluate(encoder_model_path, decoder_model_path):
    transformation = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    VAL_LOADER_UNIQUE = {
        'root': config.VAL_IMG_PATH,
        'json': config.VAL_JSON_PATH,
        'batch_size': 16,
        'shuffle': False,
        'transform': transformation,
        'num_workers': 4
    }
    val_loader_unique = get_loader_unique(**VAL_LOADER_UNIQUE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = CNNfull(2048)
    encoder.to(device)
    decoder = DecoderRNN(2048, 300, 512, vocab_size)
    decoder.to(device)
    encoder.load_state_dict(torch.load(encoder_model_path))
    decoder.load_state_dict(torch.load(decoder_model_path))
    encoder.eval()
    decoder.eval()

    bleu2, bleu3, bleu4, meteor = val_epoch(val_loader_unique,
                                            device,
                                            encoder,
                                            decoder,
                                            vocab,
                                            0,
                                            enc_scheduler=None,
                                            dec_scheduler=None,
                                            view_val_captions=False)
    print(f'Bleu2 score:{bleu2}')
    print(f'Bleu3 score:{bleu3}')
    print(f'Bleu4 score:{bleu4}')
    print(f'Meteor score:{meteor}')
Example #9
def get_model(device,vocab_size):
    # model weights file
    encoder_file = "models/encoder-3.pkl" 
    decoder_file = "models/decoder-3.pkl"

    embed_size = 512
    hidden_size = 512

    # Initialize the encoder and decoder, and set each to inference mode.
    encoder = EncoderCNN(embed_size)
    encoder.eval()
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    decoder.eval()

    # Load the trained weights.
    #print(torch.load(encoder_file))
    encoder.load_state_dict(torch.load(encoder_file))
    decoder.load_state_dict(torch.load(decoder_file))

    # Move models to GPU if CUDA is available.
    encoder.to(device)
    decoder.to(device)

    return encoder,decoder
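
A usage sketch for get_model; vocab, transform_test, and clean_sentence are assumed to be defined elsewhere in the project:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder, decoder = get_model(device, len(vocab))

image = transform_test(Image.open('photo.jpg').convert('RGB'))  # hypothetical file
image = image.to(device).unsqueeze(0)
features = encoder(image).unsqueeze(1)
output = decoder.sample(features)
print(clean_sentence(output, vocab.idx2word))
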
Example #10
    if checkpoints:
        for cp in checkpoints:
            name, num = cp[:-4].split('_')
            num = int(num)
            if name == model_name and model_idx == num:
                state_dict = torch.load(
                    'checkpoint/{}_{}.tar'.format(model_name, num))
                encoder.load_state_dict(state_dict['encoder_state_dict'])
                decoder.load_state_dict(state_dict['decoder_state_dict'])
                #optimizer.load_state_dict(state_dict['optimizer_state_dict'])
                print('model_{}_{} is being used'.format(name,state_dict['epoch']))
                break 

    # test
    decoder.eval()
    encoder.eval()

    with torch.no_grad():
        all_ref = []
        all_pred = []
        #print('to device finish')
        for i, (images, batch_captions) in enumerate(BLEU4loader):
            if i >= 40:
                continue
            all_ref.extend(batch_captions)
            images = images.to(device)
            #all_ref.extend(batch_captions)
            
            # Generate a caption from the image
            feature = encoder(images)
Example #11
def _main():
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="(optional) path to photograph, for which a caption will be generated", nargs = "?")
    parser.add_argument("--host", help="(optional) host to start a webserver on. Default: 0.0.0.0", nargs = "?", default = "0.0.0.0")
    parser.add_argument("--port", help="(optional) port to start a webserver on. http://hostname:port/query", nargs = "?", type = int, default = 1985)
    parser.add_argument("--verbose", "-v", help="print verbose query information", action="store_true")
   
    global _args
    _args = parser.parse_args()

    if not _args.filename and not _args.port:
        parser.print_help()
        sys.exit(-1)

    global _device
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("PyTorch device = ", _device)

    # Load the vocabulary dictionary
    vocab_threshold = None
    vocab_file = "./vocab.pkl"
    start_word = "<start>"
    end_word   = "<end>"
    unk_word   = "<unk>"
    load_existing_vocab = True
    #annotations_file = "/opt/cocoapi/annotations/captions_train2014.json"
    annotations_file = None

    print("Loading vocabulary...")
    global _vocab
    _vocab = Vocabulary(vocab_threshold, vocab_file, start_word, end_word, unk_word, annotations_file, load_existing_vocab)
    vocab_size = len(_vocab)
    print("Vocabulary contains %d words" % vocab_size)

    # Load pre-trained models: 
    # encoder (Resnet + embedding layers)
    # decoder (LSTM)
    global _encoder
    global _decoder
    encoder_path = os.path.join("./models/", _encoder_file)
    decoder_path = os.path.join("./models/", _decoder_file)
    print("Loading ", encoder_path)
    _encoder = EncoderCNN(_embed_size)
    _encoder.load_state_dict(torch.load(encoder_path))
    _encoder.eval()
    _encoder.to(_device)

    print("Loading ", decoder_path)
    _decoder = DecoderRNN(_embed_size, _hidden_size, vocab_size, _num_layers)
    _decoder.load_state_dict(torch.load(decoder_path))
    _decoder.eval()
    _decoder.to(_device)

    # Caption the photo, or start a web server if no photo specified
    if _args.filename:
        _get_prediction_from_file(_args.filename)
    else:
        global _app
        global _api

        _app = Flask(__name__)
        _api = Api(_app)

        _api.add_resource(ImageCaptionResource,
                "/v1/caption",
                "/v1/caption/")
        _app.run(host = _args.host, port = _args.port)
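
The ImageCaptionResource registered with Flask-RESTful above is not included in this snippet. A minimal sketch, assuming the image arrives as a multipart field named "image" and reusing the module-level globals set up in _main(); the preprocessing transform _transform is also an assumption:

from flask import request
from flask_restful import Resource

class ImageCaptionResource(Resource):
    # Hypothetical resource: accept an uploaded image and return its caption as JSON.
    def post(self):
        uploaded = request.files["image"]
        image = Image.open(uploaded).convert("RGB")
        tensor = _transform(image).unsqueeze(0).to(_device)
        features = _encoder(tensor).unsqueeze(1)
        word_ids = _decoder.sample(features)
        words = []
        for idx in word_ids:
            word = _vocab.idx2word[int(idx)]
            if word == '<end>':
                break
            if word != '<start>':
                words.append(word)
        return {'caption': ' '.join(words)}
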
Example #12
def main():

    #write predicted caption
    if not os.path.exists(args['generate_caption_path']):
        os.makedirs(args['generate_caption_path'])

    caption_string = os.path.join(args['generate_caption_path'], "caption_ncrt_class5.txt")   
    #mode = "a" if os.path.exists(caption_string) else "w"
    fp = open(caption_string, "w+")
    
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Normalize((0.9638, 0.9638, 0.9638), 
                             (0.1861, 0.1861, 0.1861))])
    
    # Load vocabulary wrapper
    with open(args['vocab_path'], 'rb') as f:
        vocab = pickle.load(f)

    # Build Models
    encoder = EncoderCNN(args['embed_size'])
    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
    decoder = DecoderRNN(args['embed_size'], args['hidden_size'], 
                         len(vocab), args['num_layers'], max_seq_length=50)
    decoder.eval()
    

    # Load the trained model parameters
    encoder.load_state_dict(torch.load(args['encoder_path']))
    decoder.load_state_dict(torch.load(args['decoder_path']))
    
    # If use gpu
    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()
    
    # Prepare Image
    image_dir = args['image_path']
    images = os.listdir(image_dir)
    i = 1
    for image_id in images:
        #print('i->',i)
        #i = i+1  
        if not image_id.endswith('.jpg'):
            continue
        image = os.path.join(image_dir, image_id)
        image = load_image(image, transform)
        image_tensor = image.cuda()
        
        # Generate caption from image
        try:
            feature, cnn_features = encoder(image_tensor)
            sampled_ids = decoder.sample(feature, cnn_features)
            sampled_ids = sampled_ids.cpu().data.numpy()
        except:
            continue
        #print('image_ids->',image_id)      
        # Decode word_ids to words
        sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(sampled_caption)
        print ('i->', i, image_id + '\t' + sentence)
        fp.write(image_id)
        fp.write('\t')
        fp.write(sentence)
        if i < 398:
            fp.write("\n")
        i = i+1         
        
    fp.close()
Example #13
def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # make sure vocab exists
    word2idx, idx2word = util.read_vocab_pickle(args.vocab_path)
    # will be used in embedder
    # bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_max_seq_len = 100

    # # create data loader. the data will be in decreasing order of length
    # data_loader = get_poem_poem_dataset(args.batch_size, shuffle=True, num_workers=args.num_workers, json_obj=unim,
    #                                     max_seq_len=bert_max_seq_len, word2idx=word2idx, tokenizer=bert_tokenizer)

    # init encode & decode model
    encoder = PoemImageEmbedModel(device)
    encoder = DataParallel(encoder)
    encoder.load_state_dict(torch.load(args.encoder_path))
    encoder = encoder.module.img_embedder.to(device)

    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(word2idx),
                         device).to(device)
    decoder = DataParallel(decoder)
    decoder.load_state_dict(torch.load(args.load))
    decoder = decoder.to(device)
    decoder.eval()

    with open('data/multim_poem.json') as f, open(
            'data/unim_poem.json') as unif:
        multim = json.load(f)
        unim = json.load(unif)

    with open('data/poem_features.pkl', 'rb') as f:
        poem_features = pickle.load(f)

    with open('data/img_features.pkl', 'rb') as f:
        img_features = pickle.load(f)

    word2idx, idx2word = util.read_vocab_pickle(args.vocab_path)

    examples = [
        img_features[3], img_features[10], img_features[11], img_features[12],
        img_features[13], img_features[14], img_features[15], img_features[16],
        img_features[17], img_features[18]
    ]
    for i, feature in enumerate(examples):
        print(i)
        feature = torch.tensor(feature).unsqueeze(0).to(device)
        sample_ids = decoder.module.sample_beamsearch(feature,
                                                      args.beamsize,
                                                      args.k,
                                                      temperature=args.temp)
        result = []
        for word_idx in sample_ids:
            word = idx2word[word_idx.item()]
            if word == ';':
                word = ';\n'
            elif word == '<EOS>':
                break
            elif word == '<SOS>':
                continue
            result.append(word)
        print(" ".join(result))
        print()

    test_images = glob2.glob('data/test_image_random/*.jp*g')
    test_images.sort()
    for test_image in test_images:
        print('img', test_image)
        sample_ids = util.generate_from_one_img_lstm(test_image, device,
                                                     encoder, decoder,
                                                     args.beamsize, args.k,
                                                     args.temp)
        result = []
        for word_idx in sample_ids:
            word = idx2word[word_idx.item()]
            if word == ';':
                word = ';\n'
            elif word == '<EOS>':
                break
            elif word == '<SOS>':
                continue
            result.append(word)
        print(" ".join(result))
        print()
Example #14
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    #Load vocab_list for uniskip
    vocab_list = pd.read_csv("./data/vocab_list.csv", header=None)
    vocab_list = vocab_list.values.tolist()[0]

    #Build data loader
    data_loader = get_loader(args.image_dir,
                             args.img_embeddings_dir,
                             args.data_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models
    #im_encoder = preprocess_get_model.model()
    attention = T_Att()
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers, args.dropout)
    uniskip = UniSkip('./data/skip-thoughts', vocab_list)
    decoder.eval()

    if torch.cuda.is_available():
        #im_encoder.cuda()
        attention.cuda()
        decoder.cuda()
        uniskip.cuda()

    attention.load_state_dict(torch.load(args.attention_path))
    decoder.load_state_dict(torch.load(args.decoder_path))

    for i, (images, captions, cap_lengths, qa, qa_lengths,
            vocab_words) in enumerate(data_loader):

        #         # Set mini-batch dataset
        img_embeddings = to_var(images.data, volatile=True)
        captions = to_var(captions)
        #         qa = to_var(qa)
        #         targets = pack_padded_sequence(qa, qa_lengths, batch_first=True)[0]

        #         # Forward, Backward and Optimize
        #         decoder.zero_grad()
        #         attention.zero_grad()
        #         #features = encoder(images)

        #img_embeddings = im_encoder(images)
        #uniskip = UniSkip('/Users/tushar/Downloads/code/data/skip-thoughts', vocab_list)
        cap_embeddings = uniskip(captions, cap_lengths)
        cap_embeddings = cap_embeddings.data
        img_embeddings = img_embeddings.data
        ctx_vec = attention(img_embeddings, cap_embeddings)
        outputs = decoder.sample(ctx_vec)
        output_ids = outputs.cpu().data.numpy()
        qa = qa.numpy()
        qa = qa[0]

        #     predicted_q = []
        #     predicted_a = []
        sample = []
        #     flag = -1
        for word_id in output_ids:
            word = vocab.idx2word[word_id]
            sample.append(word)
        #    if word == '<end>':
        #        if flag == -1:
        #            predicted_q = sample
        #            sample = []
        #            flag = 0
        #        else:
        #            predicted_a = sample
        # predicted_q = ' '.join(predicted_q[1:])
        # predicted_a = ' '.join(predicted_a[1:])
        sample = ' '.join(sample)
        actual = []
        # print("predicted q was : " + predicted_q)
        for word_id in qa:
            word = vocab.idx2word[word_id]
            actual.append(word)
        actual = ' '.join(actual)
        #print(im_id)
        print("actual_qa : " + actual + " | predicted_qa : " + sample)
Example #15
def main():
    ####################################################
    # config
    ####################################################
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = {}
    config['dataset'] = 'COCO'
    config['vocab_word2idx_path'] = './vocab/save/' + 'COCO' + '/vocab/' + 'thre5_word2idx.pkl'
    config['vocab_idx2word_path'] = './vocab/save/' + 'COCO' + '/vocab/' + 'thre5_idx2word.pkl'
    config['vocab_idx_path'] = './vocab/save/' + 'COCO' + '/vocab/' + 'thre5_idx.pkl'
    config['crop_size'] = 224
    config['images_root'] = './data/COCO/train2014_resized'
    config['json_file_path_train'] = './data/COCO/annotations/captions_mini100.json'
    config['json_file_path_val'] = './data/COCO/annotations/captions_val2014.json'
    config['batch_size'] = 128
    config['embed_size'] = 256
    config['hidden_size'] = 512
    config['learning_rate'] = 1e-4
    config['epoch_num'] = 20
    config['save_step'] = 10
    config['model_save_root'] = './save/'

    config['encoder_path'] = './save/'
    config['decoder_path'] = './save/'

    ####################################################
    # load vocabulary
    ####################################################
    vocab = Vocabulary()
    with open(config['vocab_word2idx_path'], 'rb') as f:
        vocab.word2idx = pickle.load(f)
    with open(config['vocab_idx2word_path'], 'rb') as f:
        vocab.idx2word = pickle.load(f)
    with open(config['vocab_idx_path'], 'rb') as f:
        vocab.idx = pickle.load(f)

    ####################################################
    # create data_loader
    ####################################################
    normalize = {
        'Flickr8k': [(0.4580, 0.4464, 0.4032), (0.2318, 0.2229, 0.2269)],
        'Flickr30k': None,
        'COCO': [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]
    }

    transform = transforms.Compose([
        transforms.RandomCrop(config['crop_size']),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(normalize[config['dataset']][0],
                             normalize[config['dataset']][1])
    ])

    loader_train = get_loader(dataset_name=config['dataset'],
                              images_root=config['images_root'],
                              json_file_path=config['json_file_path_train'],
                              vocab=vocab,
                              transform=transform,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              is_train=True)
    loader_val = get_loader(dataset_name=config['dataset'],
                            images_root=config['images_root'],
                            json_file_path=config['json_file_path_val'],
                            vocab=vocab,
                            transform=transform,
                            batch_size=1,
                            shuffle=False,
                            is_val=True)

    ####################################################
    # create model
    ####################################################
    encoder = EncoderCNN(config['embed_size'])
    decoder = DecoderRNN(config['embed_size'], config['hidden_size'],
                         len(vocab), 1)
    encoder.load_state_dict(torch.load(config['encoder_path']))
    decoder.load_state_dict(torch.load(config['decoder_path']))
    encoder.to(device)
    decoder.to(device)

    ####################################################
    # create trainer
    ####################################################
    raw_captions = []
    sampled_captions = []

    encoder.eval()
    decoder.eval()
    for i, (image, caption, length) in enumerate(tqdm(loader_val)):
        image = image.to(device)
        feature = encoder(image)
        sampled_ids = decoder.sample(feature)
        sampled_ids = sampled_ids[0].cpu().numpy()
        sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
            if word == '<END>':
                break
        raw_caption = [[vocab(int(token)) for token in list(caption[0])]]
        sampled_caption = sampled_caption[1:-1]  # delete <START> and <END>
        # if sampled_caption[-1] != '.':
        #     sampled_caption.append('.')
        raw_caption[0] = raw_caption[0][1:-1]  # delete <START> and <END>
        raw_captions.append(raw_caption)
        sampled_captions.append(sampled_caption)

    hypo = {}
    for i, caption in enumerate(sampled_captions):
        hypo[i] = [' '.join(caption)]
    ref = {}
    for i, caption in enumerate(raw_captions):
        ref[i] = [' '.join(caption[0])]

    final_scores = Bleu().compute_score(ref, hypo)
    print(final_scores[0])
Example #16
class ImageDescriptor():
    def __init__(self, args, encoder):
        assert args.mode in ('train', 'val', 'test')
        self.__args = args
        self.__mode = args.mode
        self.__attention_mechanism = args.attention
        self.__stats_manager = ImageDescriptorStatsManager()
        self.__validate_when_training = args.validate_when_training
        self.__history = []

        if not os.path.exists(args.model_dir):
            os.makedirs(args.model_dir)

        self.__config_path = os.path.join(
            args.model_dir, f'config-{args.encoder}{args.encoder_ver}.txt')

        # Device configuration
        self.__device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # training set vocab
        with open(args.vocab_path, 'rb') as f:
            self.__vocab = pickle.load(f)

        # validation set vocab
        with open(args.vocab_path.replace('train', 'val'), 'rb') as f:
            self.__vocab_val = pickle.load(f)

        # coco dataset
        self.__coco_train = CocoDataset(
            args.image_dir, args.caption_path, self.__vocab, args.crop_size)
        self.__coco_val = CocoDataset(
            args.image_dir, args.caption_path.replace('train', 'val'), self.__vocab_val, args.crop_size)

        # data loader
        self.__train_loader = torch.utils.data.DataLoader(dataset=self.__coco_train,
                                                          batch_size=args.batch_size,
                                                          shuffle=True,
                                                          num_workers=args.num_workers,
                                                          collate_fn=collate_fn)
        self.__val_loader = torch.utils.data.DataLoader(dataset=self.__coco_val,
                                                        batch_size=args.batch_size,
                                                        shuffle=False,
                                                        num_workers=args.num_workers,
                                                        collate_fn=collate_fn)
        # Build the models
        self.__encoder = encoder.to(self.__device)
        self.__decoder = DecoderRNN(args.embed_size, args.hidden_size,
                                    len(self.__vocab), args.num_layers, attention_mechanism=self.__attention_mechanism).to(self.__device)

        # Loss and optimizer
        self.__criterion = nn.CrossEntropyLoss()
        self.__params = (list(self.__decoder.parameters())
                         + list(self.__encoder.linear.parameters())
                         + list(self.__encoder.bn.parameters()))
        self.__optimizer = torch.optim.Adam(
            self.__params, lr=args.learning_rate)

        # Load checkpoint and check compatibility
        if os.path.isfile(self.__config_path):
            with open(self.__config_path, 'r') as f:
                content = f.read()[:-1]
            if content != repr(self):
                # save the error info
                with open('config.err', 'w') as f:
                    print(f'f.read():\n{content}', file=f)
                    print(f'repr(self):\n{repr(self)}', file=f)
                raise ValueError(
                    "Cannot create this experiment: "
                    "I found a checkpoint conflicting with the current setting.")
            self.load(file_name=args.checkpoint)
        else:
            self.save()

    def setting(self):
        '''
        Return the setting of the experiment.
        '''
        return {'Net': (self.__encoder, self.__decoder),
                'Optimizer': self.__optimizer,
                'BatchSize': self.__args.batch_size}

    @property
    def epoch(self):
        return len(self.__history)

    @property
    def history(self):
        return self.__history

    # @property
    # def mode(self):
    #     return self.__args.mode

    # @mode.setter
    # def mode(self, m):
    #     self.__args.mode = m

    def __repr__(self):
        '''
        Pretty printer showing the setting of the experiment. This is what
        is displayed when doing `print(experiment)`. This is also what is
        saved in the `config.txt` file.
        '''
        string = ''
        for key, val in self.setting().items():
            string += '{}({})\n'.format(key, val)
        return string

    def state_dict(self):
        '''
        Returns the current state of the model.
        '''
        return {'Net': (self.__encoder.state_dict(), self.__decoder.state_dict()),
                'Optimizer': self.__optimizer.state_dict(),
                'History': self.__history}

    def save(self):
        '''
        Saves the model on disk, i.e, create/update the last checkpoint.
        '''
        file_name = os.path.join(
            self.__args.model_dir, '{}{}-epoch-{}.ckpt'.format(self.__args.encoder, self.__args.encoder_ver, self.epoch))
        torch.save(self.state_dict(), file_name)
        with open(self.__config_path, 'w') as f:
            print(self, file=f)

        print(f'Save to {file_name}.')

    def load(self, file_name=None):
        '''
        Loads the model from the last checkpoint saved on disk.

        Args:
            file_name (str): path to the checkpoint file
        '''
        if not file_name:
            # find the latest .ckpt file
            try:
                file_name = max(
                    glob.iglob(os.path.join(self.__args.model_dir, '*.ckpt')), key=os.path.getctime)
                print(f'Load from {file_name}.')
            except:
                raise FileNotFoundError(
                    'No checkpoint file in the model directory.')
        else:
            file_name = os.path.join(self.__args.model_dir, file_name)
            print(f'Load from {file_name}.')

        try:
            checkpoint = torch.load(file_name, map_location=self.__device)
        except:
            raise FileNotFoundError(
                'Please check --checkpoint, the name of the file')

        self.load_state_dict(checkpoint)
        del checkpoint

    def load_state_dict(self, checkpoint):
        '''
        Loads the model from the input checkpoint.

        Args:
            checkpoint: an object saved with torch.save() from a file.
        '''
        self.__encoder.load_state_dict(checkpoint['Net'][0])
        self.__decoder.load_state_dict(checkpoint['Net'][1])
        self.__optimizer.load_state_dict(checkpoint['Optimizer'])
        self.__history = checkpoint['History']

        # The following loops are used to fix a bug that was
        # discussed here: https://github.com/pytorch/pytorch/issues/2830
        # (it is supposed to be fixed in recent PyTorch version)
        for state in self.__optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.__device)

    def train(self, plot_loss=None):
        '''
        Train the network using backpropagation based
        on the optimizer and the training set.

        Args:
            plot_loss (func, optional): if not None, should be a function taking a
                single argument being an experiment (meant to be `self`).
                Similar to a visitor pattern, this function is meant to inspect
                the current state of the experiment and display/plot/save
                statistics. For example, if the experiment is run from a
                Jupyter notebook, `plot` can be used to display the evolution
                of the loss with `matplotlib`. If the experiment is run on a
                server without display, `plot` can be used to show statistics
                on `stdout` or save statistics in a log file. (default: None)
        '''
        self.__encoder.train()
        self.__decoder.train()
        self.__stats_manager.init()
        total_step = len(self.__train_loader)
        start_epoch = self.epoch
        print("Start/Continue training from epoch {}".format(start_epoch))

        if plot_loss is not None:
            plot_loss(self)

        for epoch in range(start_epoch, self.__args.num_epochs):
            t_start = time.time()
            self.__stats_manager.init()
            for i, (images, captions, lengths) in enumerate(self.__train_loader):
                # Set mini-batch dataset
                if not self.__attention_mechanism:
                    images = images.to(self.__device)
                    captions = captions.to(self.__device)
                else:
                    with torch.no_grad():
                        images = images.to(self.__device)
                    captions = captions.to(self.__device)

                targets = pack_padded_sequence(
                    captions, lengths, batch_first=True)[0]

                # Forward, backward and optimize
                if not self.__attention_mechanism:
                    features = self.__encoder(images)
                    outputs = self.__decoder(features, captions, lengths)
                    self.__decoder.zero_grad()
                    self.__encoder.zero_grad()
                else:
                    self.__encoder.zero_grad()
                    self.__decoder.zero_grad()
                    features, cnn_features = self.__encoder(images)
                    outputs = self.__decoder(
                        features, captions, lengths, cnn_features=cnn_features)
                loss = self.__criterion(outputs, targets)

                loss.backward()
                self.__optimizer.step()
                with torch.no_grad():
                    self.__stats_manager.accumulate(
                        loss=loss.item(), perplexity=np.exp(loss.item()))

                # Print log info each iteration
                if i % self.__args.log_step == 0:
                    print('[Training] Epoch: {}/{} | Step: {}/{} | Loss: {:.4f} | Perplexity: {:5.4f}'
                          .format(epoch+1, self.__args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

            if not self.__validate_when_training:
                self.__history.append(self.__stats_manager.summarize())
                print("Epoch {} | Time: {:.2f}s\nTraining Loss: {:.6f} | Training Perplexity: {:.6f}".format(
                    self.epoch, time.time() - t_start, self.__history[-1]['loss'], self.__history[-1]['perplexity']))
            else:
                self.__history.append(
                    (self.__stats_manager.summarize(), self.evaluate()))
                print("Epoch {} | Time: {:.2f}s\nTraining Loss: {:.6f} | Training Perplexity: {:.6f}\nEvaluation Loss: {:.6f} | Evaluation Perplexity: {:.6f}".format(
                    self.epoch, time.time() - t_start,
                    self.__history[-1][0]['loss'], self.__history[-1][0]['perplexity'],
                    self.__history[-1][1]['loss'], self.__history[-1][1]['perplexity']))

            # Save the model checkpoints
            self.save()

            if plot_loss is not None:
                plot_loss(self)

        print("Finish training for {} epochs".format(self.__args.num_epochs))

    def evaluate(self, print_info=False):
        '''
        Evaluates the experiment, i.e., forward propagates the validation set
        through the network and returns the statistics computed by the stats
        manager.

        Args:
            print_info (bool): print the results of loss and perplexity
        '''
        self.__stats_manager.init()
        self.__encoder.eval()
        self.__decoder.eval()
        total_step = len(self.__val_loader)
        with torch.no_grad():
            for i, (images, captions, lengths) in enumerate(self.__val_loader):
                images = images.to(self.__device)
                captions = captions.to(self.__device)
                targets = pack_padded_sequence(
                    captions, lengths, batch_first=True)[0]

                # Forward
                if not self.__attention_mechanism:
                    features = self.__encoder(images)
                    outputs = self.__decoder(features, captions, lengths)
                else:
                    features, cnn_features = self.__encoder(images)
                    outputs = self.__decoder(
                        features, captions, lengths, cnn_features=cnn_features)
                loss = self.__criterion(outputs, targets)
                self.__stats_manager.accumulate(
                    loss=loss.item(), perplexity=np.exp(loss.item()))
                if i % self.__args.log_step == 0:
                    print('[Validation] Step: {}/{} | Loss: {:.4f} | Perplexity: {:5.4f}'
                          .format(i, total_step, loss.item(), np.exp(loss.item())))

        summarize = self.__stats_manager.summarize()
        if print_info:
            print(
                f'[Validation] Average loss for this epoch is {summarize["loss"]:.6f}')
            print(
                f'[Validation] Average perplexity for this epoch is {summarize["perplexity"]:.6f}\n')
        self.__encoder.train()
        self.__decoder.train()
        return summarize

    def mode(self, mode=None):
        '''
        Get the current mode or change mode.

        Args:
            mode (str): 'train' or 'eval' mode
        '''
        if not mode:
            return self.__mode
        self.__mode = mode

    def __load_image(self, image):
        '''
        Preprocess the given PIL image for evaluation (resize, tensorize, normalize).

        Args:
            image (PIL Image): image
        '''
        image = image.resize([224, 224], Image.LANCZOS)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        image = transform(image).unsqueeze(0)

        return image

    def test(self, image_path=None, plot=False):
        '''
        Evaluate the model by generating the caption for the
        corresponding image at `image_path`.

        Note: This function does not provide a BLEU score.

        Args:
            image_path (str): file path of the evaluation image
            plot (bool): plot or not
        '''
        self.__encoder.eval()
        self.__decoder.eval()

        with torch.no_grad():
            if not image_path:
                image_path = self.__args.image_path

            image = Image.open(image_path)

            # only process RGB images
            if np.array(image).ndim == 3:
                img = self.__load_image(image).to(self.__device)

                # generate a caption
                if not self.__attention_mechanism:
                    feature = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature)
                    sampled_ids = sampled_ids[0].cpu().numpy()
                else:
                    feature, cnn_features = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature, cnn_features)
                    sampled_ids = sampled_ids.cpu().data.numpy()

                # Convert word_ids to words
                sampled_caption = []
                for word_id in sampled_ids:
                    word = self.__vocab.idx2word[word_id]
                    sampled_caption.append(word)
                    if word == '<end>':
                        break
                sentence = ' '.join(sampled_caption[1:-1])

                # Print out the image and the generated caption
                print(sentence)

                if plot:
                    image = Image.open(image_path)
                    plt.imshow(np.asarray(image))
            else:
                print('Non-RGB images are not supported.')
        self.__encoder.train()
        self.__decoder.train()

    def coco_image(self, idx, ds='val'):
        '''
        Access the image_id (which is part of the file name)
        and the corresponding image caption at index `idx` in the COCO dataset.

        Note: For jupyter notebook

        Args:
            idx (int): index of COCO dataset

        Returns:
            (dict)
        '''
        assert ds in ('train', 'val')

        if ds == 'train':
            ann_id = self.__coco_train.ids[idx]
            return self.__coco_train.coco.anns[ann_id]
        else:
            ann_id = self.__coco_val.ids[idx]
            return self.__coco_val.coco.anns[ann_id]

    @property
    def len_of_train_set(self):
        '''
        Number of samples in the training set.
        '''
        return len(self.__coco_train)

    @property
    def len_of_val_set(self):
        return len(self.__coco_val)

    def bleu_score(self, idx, ds='val', plot=False, show_caption=False):
        '''
        Evaluate the BLEU score for index `idx` in COCO dataset.

        Note: For jupyter notebook

        Args:
            idx (int): index
            ds (str): training or validation dataset
            plot (bool): plot the image or not

        Returns:
            score (float): bleu score
        '''
        assert ds in ('train', 'val')
        self.__encoder.eval()
        self.__decoder.eval()

        with torch.no_grad():
            try:
                if ds == 'train':
                    ann_id = self.__coco_train.ids[idx]
                    coco_ann = self.__coco_train.coco.anns[ann_id]
                else:
                    ann_id = self.__coco_val.ids[idx]
                    coco_ann = self.__coco_val.coco.anns[ann_id]
            except:
                raise IndexError('Invalid index')

            image_id = coco_ann['image_id']

            image_id = str(image_id)
            if len(image_id) != 6:
                for _ in range(6 - len(image_id)):
                    image_id = '0' + image_id

            image_path = f'{self.__args.image_dir}/COCO_train2014_000000{image_id}.jpg'
            if ds == 'val':
                image_path = image_path.replace('train', 'val')

            coco_list = coco_ann['caption'].split()

            image = Image.open(image_path)

            if np.array(image).ndim == 3:
                img = self.__load_image(image).to(self.__device)

                # generate a caption
                if not self.__attention_mechanism:
                    feature = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature)
                    sampled_ids = sampled_ids[0].cpu().numpy()
                else:
                    feature, cnn_features = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature, cnn_features)
                    sampled_ids = sampled_ids.cpu().data.numpy()

                # Convert word_ids to words
                sampled_caption = []
                for word_id in sampled_ids:
                    word = self.__vocab.idx2word[word_id]
                    sampled_caption.append(word)
                    if word == '<end>':
                        break

                # strip punctuations and spacing
                sampled_list = [c for c in sampled_caption[1:-1]
                                if c not in punctuation]

                # sentence_bleu expects a list of reference token lists
                score = sentence_bleu([coco_list], sampled_list,
                                      smoothing_function=SmoothingFunction().method4)

                if plot:
                    plt.figure()
                    image = Image.open(image_path)
                    plt.imshow(np.asarray(image))
                    plt.title(f'score: {score}')
                    plt.xlabel(f'file: {image_path}')

                # Print out the generated caption
                if show_caption:
                    print(f'Sampled caption:\n{sampled_list}')
                    print(f'COCO caption:\n{coco_list}')

            else:
                print('Non-RGB images are not supported.')
                return

        return score
Example #17
vocab_size = len(vocab)

# Initialize the encoder and decoder. 
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

device = torch.device("cpu")
# encoder.to(device)
# decoder.to(device)

# Load the pretrained model
encoder.load_state_dict(torch.load(PRETRAINED_MODEL_PATH.format('encoder')))
decoder.load_state_dict(torch.load(PRETRAINED_MODEL_PATH.format('decoder')))

encoder.eval()
decoder.eval()

images, conv_images = next(iter(data_loader))
features = encoder(conv_images).unsqueeze(1)
output = decoder.sample(features, max_len=max_len)

word_list = []
for word_idx in output:
    if word_idx == vocab.word2idx[vocab.start_word]:
        continue
    if word_idx == vocab.word2idx[vocab.end_word]:
        break
    word_list.append(vocab.idx2word[word_idx])

print(' '.join(word_list))
plt.imshow(np.squeeze(images))
Example #18
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    print("length", len(vocab))

    # Build data loader
    # data_loader = get_loader(args.image_dir, args.caption_path, vocab,
    #                             args.dictionary, args.batch_size,
    #                             shuffle=True, num_workers=args.num_workers)
    data = input("Enter Topic: ")
    # Build the models
    #encoder = EncoderCNN(args.embed_size).to(device)
    dictionary = pd.read_csv(args.dictionary, header=0,encoding = 'unicode_escape',error_bad_lines=False)
    dictionary = list(dictionary['keys'])

    decoder = DecoderRNN(len(dictionary), args.hidden_size, len(vocab), args.num_layers).to(device)

    decoder.load_state_dict(torch.load(args.model_path, map_location=device))
    decoder.eval()


    # Train the models
    # total_step = len(data_loader)
    # for epoch in range(args.num_epochs):
    # for i, (array, captions, lengths) in enumerate(data_loader):
    array = torch.zeros((len(dictionary)))
    for val in data.split():
        # Set mini-batch dataset
        array[dictionary.index(val)] = 1
        # print("In sample", array)
    array = (array, )
    array = torch.stack(array, 0)
    array = array.to(device)
    # print("After", array)
    #captions = captions.to(device)
    # targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

    # Forward, backward and optimize
    #features = encoder(images)
    outputs = decoder.sample(array)

    count = 0
    sentence = ''
    for i in range(len(outputs)):
        sampled_ids = outputs[i].cpu().numpy()          # (1, max_seq_length) -> (max_seq_length)

        # Convert word_ids to words
        sampled_caption = []
        for word_id in sampled_ids:
            count += 1
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(sampled_caption)

        # Print out the image and the generated caption
    print (sentence)

    print(count)
Example #19
#====================================================================================================


def evaluate_sample(p, t):
    for i in range(len(t)):
        if t[i] != p[i]:
            return 0
        if t[i] == 17.:
            break
    return 1


#====================================================================================================
force_encoder_model.eval()
rgb_mask_encoder_model.eval()
decoder_model.eval()

correct = 0

for batch_idx, (rgb_mask_img, force_img,
                targets) in enumerate(tqdm(testloader)):

    rgb_mask_img = rgb_mask_img.to(device)
    force_img = force_img.to(device)
    targets = targets.type(torch.LongTensor).to(device)

    force_feature = force_encoder_model(force_img)
    rgb_mask_feature = rgb_mask_encoder_model(rgb_mask_img)
    predictions = decoder_model(force_feature, rgb_mask_feature)
    # test loss and accuracy
    pred = predictions.max(1, keepdim=True)[1]
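    # The snippet stops before `correct` is updated; a hedged continuation that puts
    # evaluate_sample to use, assuming pred and targets are (batch, seq_len) index arrays.
    for p, t in zip(pred.squeeze(1).cpu().numpy(), targets.cpu().numpy()):
        correct += evaluate_sample(p, t)

print('Exact-match accuracy: {:.4f}'.format(correct / len(testloader.dataset)))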
Example #20
def main(args):

    # The vocabulary was serialized with pickle.dump() in build_vocab,
    # so loading it here restores the vocab wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load vocabulary wrapper
    vocab_image = vocab
    # Build models
    encoder = EncoderGGNN(len(vocab_image), 256).to(
        device)  # eval mode (batchnorm uses moving mean/variance)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers).to(device)
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             args.relationship_path,
                             vocab,
                             vocab_image,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             ids=438430)
    criterion = nn.CrossEntropyLoss()

    # Test the models
    total_step = len(data_loader)  # number of batches
    print('There are {} batches in the test data set\n'.format(total_step))
    for epoch in range(1, args.num_epochs):
        encoder.load_state_dict(
            torch.load('./models/encoder-{}.ckpt'.format(epoch)))
        decoder.load_state_dict(
            torch.load('./models/decoder-{}.ckpt'.format(epoch)))
        with torch.no_grad():
            encoder.eval()
            decoder.eval()
            f_samp = codecs.open('./candidate.txt', 'w', encoding='utf-8')
            f_ref = codecs.open('./reference.txt', 'w', encoding='utf-8')
            sum_loss = 0.0
            for i, (images, lengths_images, captions, lengths,
                    adjmatrixs) in enumerate(data_loader):
                images = images.to(device)
                captions = captions.to(device)
                adjmatrixs = adjmatrixs.to(device)
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]
                features = encoder(images, adjmatrixs, lengths_images)
                outputs = decoder(features, captions, lengths)
                loss = criterion(outputs, targets)
                sum_loss += loss
                sampled_ids = decoder.sample(features)
                #print(sampled_ids.shape)  # torch.Size([512, 20])observe the shape[0] whether is the batch
                #print(type(sampled_ids))  # <class 'torch.Tensor'>
                #print(captions.shape)     # torch.Size([512, 12])
                #print(type(captions))     # <class 'torch.Tensor'>
                #samp_sentences = []
                #ref_sentences = []
                #print('Traslation')
                for j in range(len(sampled_ids)):  # len(sampled_ids) is batch_size
                    sampled_id = sampled_ids[j]
                    sampled_id = sampled_id.cpu().numpy()
                    sampled_caption = []
                    for word_id in sampled_id:  # word_id is a np.int64 scalar
                        word = vocab.idx2word[word_id]
                        sampled_caption.append(word)
                        if word == '<end>':
                            break
                    sentence = ' '.join(sampled_caption)
                    f_samp.write(sentence + ' . \n')
                    #samp_sentences.append(sentence)

                    caption_len = lengths[j]
                    caption = captions[j].cpu().numpy()
                    ref_caption = []
                    for l in range(caption_len):
                        word_id = caption[l]
                        word = vocab.idx2word[word_id]
                        ref_caption.append(word)
                    reference = ' '.join(ref_caption)
                    f_ref.write(reference + ' . \n')
                    #ref_sentences.append(reference)
                    #print the generated caption
        f_samp.close()
        f_ref.close()
        score, pn = BLEU('./candidate.txt', './reference.txt')
        sum_loss /= total_step
        print('loss is {:.6f}\tBLEU is {:.8f}\n'.format(sum_loss, score))
        writer.add_scalar('test/loss', sum_loss, epoch)
        writer.add_scalar('test/BLEU', score, epoch)
        # add_scalars (plural) is needed to log a dict of values
        writer.add_scalars('test/n-grams', {
            'bleu-1': pn[0],
            'bleu-2': pn[1],
            'bleu-3': pn[2],
            'bleu-4': pn[3]
        }, epoch)