예제 #1
0
def overfit_small_data():
    """
    Overfit a vanilla-RNN captioning model on a tiny COCO subset.

    Similar to the Solver class that we used to train image classification models on the
    previous assignment, on this assignment we use a CaptioningSolver class to train
    image captioning models. Open the file cs231n/captioning_solver.py and read through
    the CaptioningSolver class; it should look very familiar.

    Once you have familiarized yourself with the API, run the following to make sure your
    model overfits a small sample of 50 training examples. You should see a final loss
    of less than 0.1.

    After training, the loss history is plotted, and for both the 'train' and
    'val' splits two examples are sampled and displayed with the generated and
    ground-truth captions in the figure title.
    """
    np.random.seed(231)

    small_data = load_coco_data(max_train=50)

    # Build the model from the very dataset the solver trains on, so that the
    # vocabulary and feature dimension always agree with the training data
    # (previously these were read from an unrelated module-level `data`).
    small_rnn_model = CaptioningRNN(
        cell_type='rnn',
        word_to_idx=small_data['word_to_idx'],
        input_dim=small_data['train_features'].shape[1],
        hidden_dim=512,
        wordvec_dim=256,
    )
    small_rnn_solver = CaptioningSolver(
        small_rnn_model,
        small_data,
        update_rule='adam',
        num_epochs=50,
        batch_size=25,
        optim_config={
            'learning_rate': 5e-3,
        },
        lr_decay=0.95,
        verbose=True,
        print_every=10,
    )
    small_rnn_solver.train()

    # Plot the training losses
    plt.plot(small_rnn_solver.loss_history)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training loss history')
    plt.show()

    # Qualitative check: compare sampled captions with ground truth.
    idx_to_word = small_data['idx_to_word']
    for split in ['train', 'val']:
        gt_captions, features, urls = sample_coco_minibatch(small_data,
                                                            split=split,
                                                            batch_size=2)
        gt_captions = decode_captions(gt_captions, idx_to_word)

        sample_captions = small_rnn_model.sample(features)
        sample_captions = decode_captions(sample_captions, idx_to_word)

        for gt_caption, sample_caption, url in zip(gt_captions,
                                                   sample_captions, urls):
            # Images are fetched by URL, so this step needs network access.
            plt.imshow(image_from_url(url))
            plt.title('%s\n%s\nGT:%s' % (split, sample_caption, gt_caption))
            plt.axis('off')
            plt.show()
예제 #2
0
def overfit_lstm_captioning_model():
    """Overfit an LSTM captioning model on a tiny COCO subset.

    Loads 50 training examples, trains a CaptioningRNN with
    ``cell_type='lstm'`` using CaptioningSolver, plots the loss history, and
    then shows two sampled captions per split ('train' and 'val') next to
    their ground-truth captions.

    You should see a final loss less than 0.5.
    """
    np.random.seed(231)

    small_data = load_coco_data(max_train=50)

    # Build the model from the very dataset the solver trains on, so that the
    # vocabulary and feature dimension always agree with the training data
    # (previously these were read from an unrelated module-level `data`).
    small_lstm_model = CaptioningRNN(
        cell_type='lstm',
        word_to_idx=small_data['word_to_idx'],
        input_dim=small_data['train_features'].shape[1],
        hidden_dim=512,
        wordvec_dim=256,
        dtype=np.float32,
    )

    small_lstm_solver = CaptioningSolver(
        small_lstm_model,
        small_data,
        update_rule='adam',
        num_epochs=50,
        batch_size=25,
        optim_config={
            'learning_rate': 5e-3,
        },
        lr_decay=0.995,
        verbose=True,
        print_every=10,
    )

    small_lstm_solver.train()

    # Plot the training losses
    plt.plot(small_lstm_solver.loss_history)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training loss history')
    plt.show()

    # Qualitative check: compare sampled captions with ground truth.
    idx_to_word = small_data['idx_to_word']
    for split in ['train', 'val']:
        minibatch = sample_coco_minibatch(small_data,
                                          split=split,
                                          batch_size=2)
        gt_captions, features, urls = minibatch
        gt_captions = decode_captions(gt_captions, idx_to_word)

        sample_captions = small_lstm_model.sample(features)
        sample_captions = decode_captions(sample_captions, idx_to_word)

        for gt_caption, sample_caption, url in zip(gt_captions,
                                                   sample_captions, urls):
            # Images are fetched by URL, so this step needs network access.
            plt.imshow(image_from_url(url))
            plt.title('%s\n%s\nGT:%s' % (split, sample_caption, gt_caption))
            plt.axis('off')
            plt.show()
예제 #3
0
def demo(data, model):
    """Display sampled captions beside ground truth for train and val images.

    For each of the 'train' and 'val' splits, two examples are drawn with
    ``sample_coco_minibatch``; captions are generated by ``model.sample`` and
    every image is shown with the split name, the sampled caption and the
    ground-truth caption in its title.

    Args:
        data: Dataset mapping. Reads ``data['word_to_idx']`` (which must hold
            the '<START>', '<END>' and '<NULL>' tokens) and
            ``data['idx_to_word']``.
        model: Object exposing ``sample(features, start, end, null)``.
    """
    vocab = data['word_to_idx']
    start_token = vocab['<START>']
    end_token = vocab['<END>']
    null_token = vocab['<NULL>']

    for split in ('train', 'val'):
        truth, features, urls = sample_coco_minibatch(
            data, split=split, batch_size=2)
        truth_text = decode_captions(truth, data['idx_to_word'])

        sampled = model.sample(features, start_token, end_token, null_token)
        sampled_text = decode_captions(sampled, data['idx_to_word'])

        for reference, generated, url in zip(truth_text, sampled_text, urls):
            image = image_from_url(url)
            title = '%s\n%s\nGT:%s' % (split, generated, reference)
            plt.imshow(image)
            plt.title(title)
            plt.axis('off')
            plt.show()
예제 #4
0
#           batch_size=25,
#           optim_config={
#             'learning_rate': 5e-3,
#           },
#           lr_decay=0.995,
#           verbose=True, print_every=10,
#         )
#
#small_lstm_solver.train()
#
## Plot the training losses
#plt.plot(small_lstm_solver.loss_history)
#plt.xlabel('Iteration')
#plt.ylabel('Loss')
#plt.title('Training loss history')
#plt.show()

# Qualitative check of the LSTM model: for each split, sample two examples,
# generate captions, and show each image with both captions in the title.
for split in ('train', 'val'):
    minibatch = sample_coco_minibatch(
        small_data, split=split, batch_size=2)
    gt_captions, features, urls = minibatch
    gt_captions = decode_captions(gt_captions, data['idx_to_word'])

    sample_captions = decode_captions(
        small_lstm_model.sample(features), data['idx_to_word'])

    captioned_images = zip(gt_captions, sample_captions, urls)
    for gt_caption, sample_caption, url in captioned_images:
        # The image is downloaded from its URL at display time.
        plt.imshow(image_from_url(url))
        plt.title('%s\n%s\nGT:%s' % (split, sample_caption, gt_caption))
        plt.axis('off')
        plt.show()
예제 #5
0
# ## Look at the data
# It is always a good idea to look at examples from the dataset before working with it.
# 
# You can use the `sample_coco_minibatch` function from the file `cs231n/coco_utils.py` to sample minibatches of data from the data structure returned from `load_coco_data`. Run the following to sample a small minibatch of training data and show the images and their captions. Running it multiple times and looking at the results helps you to get a sense of the dataset.
# 
# Note that we decode the captions using the `decode_captions` function and that we download the images on-the-fly using their Flickr URL, so **you must be connected to the internet to view images**.

# In[ ]:

# Sample a minibatch and show the images and captions

batch_size = 3

# Draw a few captions with their image URLs and display each image with its
# decoded caption as the figure title.
captions, features, urls = sample_coco_minibatch(data, batch_size=batch_size)
for i, (caption, url) in enumerate(zip(captions, urls)):
    image = image_from_url(url)  # fetched over the network
    plt.imshow(image)
    caption_str = decode_captions(caption, data['idx_to_word'])
    plt.title(caption_str)
    plt.axis('off')
    plt.show()


# # Recurrent Neural Networks
# As discussed in lecture, we will use recurrent neural network (RNN) language models for image captioning. 
# The file `cs231n/rnn_layers.py` contains implementations of different layer types that are needed for recurrent 
# neural networks, and the file `cs231n/classifiers/rnn.py` uses these layers to implement an image captioning model.
# 
# We will first implement different types of RNN layers in `cs231n/rnn_layers.py`.

# # Vanilla RNN: step forward
# Open the file `cs231n/rnn_layers.py`. This file implements the forward and backward passes for different 
예제 #6
0
# Default image rendering: no interpolation smoothing, grayscale colormap.
plt.rcParams.update({
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})


def rel_error(x, y):
    """Return the maximum elementwise relative error between x and y.

    Each element's error is ``|x - y| / max(1e-8, |x| + |y|)``; the floor of
    1e-8 on the denominator avoids dividing by zero when both values are zero.
    """
    difference = np.abs(x - y)
    magnitude = np.maximum(1e-8, np.abs(x) + np.abs(y))
    return np.max(difference / magnitude)


data = load_coco_data(pca_features=True)

# Debug probe: list every training caption attached to one image, then stop.
# NOTE(review): presumably `train_image_idxs` holds, per caption row, an index
# into `train_urls` -- confirm against the data loader.
img_num = 13945
matches_image = np.isin(data['train_image_idxs'], img_num)
caption_rows = np.where(matches_image)[0]
print(caption_rows.tolist())

image = image_from_url(data['train_urls'][img_num])
plt.imshow(image)
plt.axis('off')
plt.show()

for row in caption_rows:
    print(decode_captions(data['train_captions'][row], data['idx_to_word']))

# Stop here on purpose; nothing below this line runs while the probe is active.
exit()

# for k, v in data.items():
#     if type(v) == np.ndarray:
#         print(k, type(v), v.shape, v.dtype)
#     else:
#         print(k, type(v), len(v))

# Sanity check for temporal softmax loss