示例#1
0
    def test_overfit(self):
        """Overfit a vanilla RNN captioning model on a 50-sample COCO subset.

        With so little data the training loss should collapse toward zero,
        which sanity-checks the forward/backward passes of the RNN.
        """
        print("\n======== TestCaptioningRNN.test_overfit:")

        # A tiny training set lets the model memorize the captions quickly.
        overfit_data = coco_utils.load_coco_data(max_train=50)

        rnn_model = captioning_rnn.CaptioningRNN(
            word_to_idx=overfit_data['word_to_idx'],
            input_dim=overfit_data['train_features'].shape[1],
            cell_type='rnn',
            hidden_dim=512,
            wordvec_dim=256)

        solver = captioning_solver.CaptioningSolver(
            rnn_model,
            overfit_data,
            update_rule='adam',
            optim_config={'learning_rate': 5e-3},
            lr_decay=0.95,
            num_epochs=50,
            batch_size=25,
            print_every=1,
            verbose=self.verbose)

        solver.train()

        # Optionally visualize the optimization trajectory.
        if self.draw_figures:
            plt.plot(solver.loss_history)
            plt.xlabel('Iteration')
            plt.ylabel('Loss')
            plt.title("Training loss history")
            plt.show()

        print("======== TestCaptioningRNN.test_overfit: <END> ")
示例#2
0
def ex_caption_rnn(verbose=False, plot_figures=False):
    """Train an RNN captioning model on a 50k-image COCO subset.

    After training, sample captions at test time and (optionally) plot
    the loss history.

    Args:
        verbose: forwarded to the solver to control progress printing.
        plot_figures: if True, show sampled captions and the loss curve.
    """
    # A larger subset than the unit tests use, for a more realistic run.
    data = coco_utils.load_coco_data(max_train=50000)

    model = captioning_rnn.CaptioningRNN(
        word_to_idx=data['word_to_idx'],
        input_dim=data['train_features'].shape[1],
        cell_type='rnn',
        hidden_dim=512,
        wordvec_dim=256)

    solver = captioning_solver.CaptioningSolver(
        model,
        data,
        update_rule='adam',
        optim_config={'learning_rate': 5e-3},
        lr_decay=0.95,
        num_epochs=50,
        batch_size=25,
        print_every=100,
        verbose=verbose)

    solver.train()

    # Sample captions from the trained model.
    test_time_sampling(data, model, plot_figures=plot_figures)

    # Plot the loss...
    if plot_figures:
        plt.plot(solver.loss_history, 'o')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.title('Captioning RNN Loss')
        plt.show()
示例#3
0
def load_data(verbose=False):
    """Load PCA-reduced COCO captioning data.

    Args:
        verbose: if True, print each entry's name plus its type and either
            its shape/dtype (ndarrays) or its length (everything else).

    Returns:
        The data dict produced by ``coco_utils.load_coco_data``.
    """
    data = coco_utils.load_coco_data(pca_features=True)
    if verbose:
        for k, v in data.items():
            # isinstance() instead of type() ==; and the original format
            # string ('%s : %s, %s, %%d') had only three conversion
            # specifiers (the doubled %% is a literal '%d') for four
            # arguments, which raised TypeError on every ndarray entry.
            if isinstance(v, np.ndarray):
                print('%s : %s, %s, %s' % (k, type(v), v.shape, v.dtype))
            else:
                print('%s : %s (%s)' % (k, type(v), len(v)))

    return data
示例#4
0
    def test_sampling(self):
        """Overfit a small RNN captioning model, then sample captions.

        Samples from both the train and val splits, decoding the model's
        output alongside the ground-truth captions. Figure display is
        gated on ``self.draw_figures`` for consistency with the other
        tests (previously this test opened figure windows unconditionally,
        which blocks headless runs).
        """
        print("\n======== TestCaptioningRNN.test_sampling:")

        small_data = coco_utils.load_coco_data(max_train=50)

        small_rnn_model = captioning_rnn.CaptioningRNN(
            cell_type='rnn',
            word_to_idx=small_data['word_to_idx'],
            input_dim=small_data['train_features'].shape[1],
            hidden_dim=512,
            wordvec_dim=256)
        solv = captioning_solver.CaptioningSolver(
            small_rnn_model,
            small_data,
            update_rule='adam',
            num_epochs=50,
            batch_size=25,
            optim_config={'learning_rate': 5e-3},
            lr_decay=0.95,
            verbose=self.verbose,
            print_every=1)
        # Train the model
        solv.train()

        for split in ['train', 'val']:
            minibatch = coco_utils.sample_coco_minibatch(small_data,
                                                         split=split,
                                                         batch_size=8)
            gt_captions, features, urls = minibatch
            gt_captions = coco_utils.decode_captions(gt_captions,
                                                     small_data['idx_to_word'])

            # sample() returns word indices; cast before decoding to text.
            sample_captions = small_rnn_model.sample(features)
            sample_captions = sample_captions.astype(np.int32)
            sample_captions = coco_utils.decode_captions(
                sample_captions, small_data['idx_to_word'])

            # Only display images when figures were requested, matching
            # the self.draw_figures convention used by the sibling tests.
            if self.draw_figures:
                for gt_capt, samp_capt, url in zip(gt_captions,
                                                   sample_captions, urls):
                    plt.imshow(image_utils.image_from_url(url))
                    plt.title('Split : %s\n%5s\nGT :%s' %
                              (split, samp_capt, gt_capt))
                    plt.axis('off')
                    plt.show()

        print("======== TestCaptioningRNN.test_sampling: <END> ")
示例#5
0
    def test_overfit_model(self):
        """Overfit an LSTM captioning model on a 50-sample COCO subset.

        Mirrors the RNN overfitting test but uses the LSTM cell type and
        float32 parameters.
        """
        print("\n======== TestCaptioningLSTM.test_overfit_model:")

        # Tiny dataset: loss should drop sharply if the LSTM is correct.
        data = coco_utils.load_coco_data(max_train=50)

        lstm_model = captioning_rnn.CaptioningRNN(
            word_to_idx=data['word_to_idx'],
            input_dim=data['train_features'].shape[1],
            cell_type='lstm',
            hidden_dim=512,
            wordvec_dim=256,
            dtype=np.float32)

        solver = captioning_solver.CaptioningSolver(
            lstm_model,
            data,
            update_rule='adam',
            optim_config={'learning_rate': 5e-3},
            lr_decay=0.95,
            num_epochs=50,
            batch_size=25,
            print_every=10,
            verbose=self.verbose)

        # Train
        solver.train()

        # Optionally plot the training loss curve.
        if self.draw_figures:
            plt.plot(solver.loss_history)
            plt.xlabel('Iteration')
            plt.ylabel('Loss')
            plt.title('Training loss history')
            plt.show()

        print("\n======== TestCaptioningLSTM.test_overfit_model: <END>")