parser.add_argument("--multimodal_fusion_hidden_size", type=int, default=512) parser.add_argument("--classification_hidden_size", type=int, default=512) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--learning_rate", type=float, default=0.001) parser.add_argument("--l2_reg", type=float, default=0.000005) parser.add_argument("--patience", type=int, default=3) args = parser.parse_args() start_logger(args.model_save_filename + ".train_log") atexit.register(stop_logger) print("-- Building vocabulary") embeddings, token2id, id2token = load_glove(args.vectors_filename, args.max_vocab, args.embeddings_size) label2id = {"neutral": 0, "entailment": 1, "contradiction": 2} id2label = {v: k for k, v in label2id.items()} num_tokens = len(token2id) num_labels = len(label2id) print("Number of tokens: {}".format(num_tokens)) print("Number of labels: {}".format(num_labels)) with open(args.model_save_filename + ".params", mode="w") as out_file: json.dump(vars(args), out_file) print("Params saved to: {}".format(args.model_save_filename + ".params")) with open(args.model_save_filename + ".index", mode="wb") as out_file: pickle.dump(
# Imports required by the loader below. `pad_sequences` is assumed to be the
# Keras helper (the keyword arguments match its signature); the project may
# ship its own equivalent.
import csv
from collections import Counter

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences


def load_e_vsnli_dataset_and_glove(nli_dataset_filename,
                                   label2id,
                                   vectors_filename,
                                   max_vocab,
                                   embeddings_size,
                                   buffer_size=None,
                                   padding_length=None,
                                   min_threshold=0,
                                   keep_neutrals=True):
    labels = []
    image_names = []
    pairIDs = []
    original_explanations = []
    original_premises = []
    original_hypotheses = []
    all_premise_tokens = []
    all_hypothesis_tokens = []
    all_explanation_tokens = []

    with open(nli_dataset_filename) as in_file:
        reader = csv.reader(in_file, delimiter="\t")
        next(reader, None)  # skip header

        for i, row in enumerate(reader):
            if buffer_size and i >= buffer_size:
                break

            label = row[0].strip()
            if not keep_neutrals and label == "neutral":
                continue

            premise_tokens = row[1].strip().split()
            hypothesis_tokens = row[2].strip().split()
            image = row[3].strip().split("#")[0]
            premise = row[4].strip()
            hypothesis = row[5].strip()
            pairID = row[6]
            explanation = row[7].strip()
            explanation_tokens = row[8].strip().split()

            labels.append(label2id[label])

            # Wrap explanations and hypotheses with sentence-boundary markers.
            explanation_tokens = ["<start>"] + explanation_tokens + ["<end>"]
            hypothesis_tokens = ["<start>"] + hypothesis_tokens + ["<end>"]

            all_premise_tokens.append(premise_tokens)
            all_hypothesis_tokens.append(hypothesis_tokens)
            all_explanation_tokens.append(explanation_tokens)
            image_names.append(image)
            pairIDs.append(pairID)
            original_premises.append(premise)
            original_hypotheses.append(hypothesis)
            original_explanations.append(explanation)

    labels = np.array(labels)
    pairIDs = np.array(pairIDs)

    # Optionally restrict the GloVe vocabulary to explanation words occurring
    # in at least `min_threshold` examples (set() gives per-example counts).
    if min_threshold:
        word_freq = Counter(x for xs in all_explanation_tokens for x in set(xs))
        word_freq = {x: word_freq[x]
                     for x in word_freq if word_freq[x] >= min_threshold}
    else:
        word_freq = None

    embeddings, token2id, id2token = load_glove(vectors_filename, max_vocab,
                                                embeddings_size, word_freq)

    # Map tokens to vocabulary ids, falling back to the unknown-word id.
    padded_premises = [[token2id.get(token, token2id["#unk#"])
                        for token in premise_tokens]
                       for premise_tokens in all_premise_tokens]
    padded_hypotheses = [[token2id.get(token, token2id["#unk#"])
                          for token in hypothesis_tokens]
                         for hypothesis_tokens in all_hypothesis_tokens]
    padded_explanations = [[token2id.get(token, token2id["#unk#"])
                            for token in explanation_tokens]
                           for explanation_tokens in all_explanation_tokens]

    max_length = max(len(pad_expl) for pad_expl in padded_explanations)
    if padding_length is None:
        padding_length = max_length

    # np.long was removed in NumPy >= 1.24; np.int64 is the equivalent dtype.
    padded_premises = pad_sequences(padded_premises,
                                    maxlen=padding_length,
                                    padding="post",
                                    value=token2id["#pad#"],
                                    dtype=np.int64)
    padded_hypotheses = pad_sequences(padded_hypotheses,
                                      maxlen=padding_length,
                                      padding="post",
                                      value=token2id["#pad#"],
                                      dtype=np.int64)
    padded_explanations = pad_sequences(padded_explanations,
                                        maxlen=padding_length,
                                        padding="post",
                                        value=token2id["#pad#"],
                                        dtype=np.int64)

    return (labels, padded_explanations, padded_premises, padded_hypotheses,
            image_names, original_explanations, original_premises,
            original_hypotheses, max_length, embeddings, token2id, id2token,
            pairIDs)
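# A hypothetical invocation of the loader above, for illustration only: the
# file names, vocabulary size, and threshold are placeholders, not values
# taken from the project.
(labels, padded_explanations, padded_premises, padded_hypotheses,
 image_names, original_explanations, original_premises,
 original_hypotheses, max_length, embeddings, token2id, id2token,
 pairIDs) = load_e_vsnli_dataset_and_glove(
     "e_vsnli_train.tsv",          # TSV with the nine columns read above
     {"neutral": 0, "entailment": 1, "contradiction": 2},
     "glove.840B.300d.txt",        # pre-trained GloVe vectors
     max_vocab=300000,
     embeddings_size=300,
     min_threshold=5,              # drop rare explanation words from the vocab
 )
print("explanations:", padded_explanations.shape,
      "embeddings:", embeddings.shape)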