def __init__(self, n_image_features, vocab_size, embedding_dim,
                 hidden_size, bound_idx, max_sentence_length, vl_loss_weight,
                 bound_weight, should_train_cnn, n_rsa_samples, use_gpu, K,
                 use_distractors_in_sender, pop_size):
        super().__init__()

        self.use_gpu = use_gpu
        self.bound_token_idx = bound_idx
        self.max_sentence_length = max_sentence_length
        self.vocab_size = vocab_size
        self.vl_loss_weight = vl_loss_weight  # lambda
        self.bound_weight = bound_weight  # alpha
        self.should_train_cnn = should_train_cnn
        self.n_rsa_samples = n_rsa_samples
        self.use_distractors_in_sender = use_distractors_in_sender
        self.pop_size = pop_size

        if self.should_train_cnn:
            self.cnn = CNN(n_image_features)
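        # If distractors are shown to the sender, its input presumably
        # concatenates the target with the K distractor images, hence the
        # (K + 1) factor below.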
        if self.use_distractors_in_sender:
            sender_image_features = (K + 1) * n_image_features
        else:
            sender_image_features = n_image_features
        self.senders = nn.ModuleList([
            Sender(sender_image_features, vocab_size, embedding_dim,
                   hidden_size, bound_idx, max_sentence_length, vl_loss_weight,
                   bound_weight, use_gpu) for _ in range(pop_size)
        ])
        self.receivers = nn.ModuleList([
            Receiver(n_image_features, vocab_size, embedding_dim, hidden_size,
                     use_gpu) for _ in range(pop_size)
        ])
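        # shuffle_pair() (defined elsewhere in this class) presumably draws a
        # random sender/receiver pairing from the population for the next game.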
        self.shuffle_pair()
    def __init__(self, n_image_features, vocab_size, embedding_dim,
                 hidden_size, bound_idx, max_sentence_length, vl_loss_weight,
                 bound_weight, should_train_cnn, n_rsa_samples, use_gpu):
        super().__init__()

        self.use_gpu = use_gpu
        self.bound_token_idx = bound_idx
        self.max_sentence_length = max_sentence_length
        self.vocab_size = vocab_size
        self.vl_loss_weight = vl_loss_weight  # lambda
        self.bound_weight = bound_weight  # alpha
        self.should_train_cnn = should_train_cnn
        self.n_rsa_samples = n_rsa_samples

        if self.should_train_cnn:
            self.cnn = CNN(n_image_features)

        self.sender = Sender(n_image_features, vocab_size, embedding_dim,
                             hidden_size, bound_idx, max_sentence_length,
                             vl_loss_weight, bound_weight, use_gpu)
        self.receiver = Receiver(n_image_features, vocab_size, embedding_dim,
                                 hidden_size, use_gpu)
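
# The __init__ above is the single-pair variant of the population model shown
# first: one Sender and one Receiver instead of pop_size-long ModuleLists, with
# the same bound_weight (alpha) and vl_loss_weight (lambda) hyperparameters.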
print("loading pretrained cnn")
# Load pretrained CNN if necessary
if not should_train_visual and not use_symbolic_input and shapes_dataset is not None:
    cnn_model_id = cnn_model_file_name.split('/')[-1]

    features_folder_name = 'data/shapes/{}_{}'.format(shapes_dataset,
                                                      cnn_model_id)

    # Check if the features were already extracted with this CNN
    if not os.path.exists(features_folder_name):
        # Load CNN from dumped model
        state = torch.load(cnn_model_file_name,
                           map_location=lambda storage, location: storage)
        # Keep only the CNN parameters and strip the 'cnn.' prefix (4 chars)
        # from their keys so they match the standalone CNN's state dict
        cnn_state = {k[4:]: v for k, v in state.items() if 'cnn' in k}
        trained_cnn = CNN(n_image_features)
        trained_cnn.load_state_dict(cnn_state)

        if use_gpu:
            trained_cnn = trained_cnn.cuda()

        print("=CNN state loaded=")
        print("Extracting features...")

        # Dump the features to then load them
        features_folder_name = save_features(trained_cnn, shapes_dataset,
                                             cnn_model_id)

print("crating one hot metadata")
if not shapes_dataset is None:
    # Create onehot metadata if not created yet
# Load metadata
train_metadata, valid_metadata, test_metadata, noise_metadata = load_shapes_classdata(
    shapes_dataset)

print("loaded metadata")
print("loading data")
# Load data
train_data, valid_data, test_data, noise_data = load_images(
    'shapes/{}'.format(shapes_dataset), BATCH_SIZE, K)
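# K presumably sets the number of distractor images per target; the population
# model above sizes its sender input as (K + 1) * n_image_features when
# use_distractors_in_sender is enabled.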

print("data loaded")
# Settings

print("creating model")
cnnmodel = CNN(n_image_features)

import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, cnn, n_out_features, out_classes):
        super(MyModel, self).__init__()
        self.cnn = cnn
        self.fc = nn.Linear(n_out_features, out_classes)

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        # A softmax here is unnecessary when training with nn.CrossEntropyLoss,
        # which expects raw logits; for probabilities at inference time,
        # torch.softmax(x, dim=1) could be applied instead.
        # x = nn.Softmax(x)
        return x
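

# Usage sketch (assumptions: `images` and `labels` come from one of the batches
# loaded above, out_classes=10 is illustrative, and CNN(n_image_features)
# outputs n_image_features-dimensional vectors, matching MyModel's fc layer).
classifier = MyModel(cnnmodel, n_image_features, out_classes=10)
logits = classifier(images)                            # raw, unnormalized scores
loss = nn.CrossEntropyLoss()(logits, labels)           # applies log-softmax internally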