def train(self, data, train_args):
    # Args
    save_logs = train_args.get(SAVE_LOGS, False)
    early_stop = train_args.get(EARLY_STOP, False)
    val_split = train_args.get(VAL_SPLIT, 0.2)
    batch_size = train_args.get(BATCH_SIZE, 64)
    epochs = train_args.get(NUM_EPOCHS, 25)
    verbose = train_args.get(VERBOSE, 2)

    # Input data
    texts = data[TEXT_ONE_IDX]
    other_texts = data[TEXT_TWO_IDX]
    labels = data[LABEL_IDX]

    # Do sequence padding
    texts = pad_sequences(texts, maxlen=self.seq_len, dtype='float16', truncating='post')
    other_texts = pad_sequences(other_texts, maxlen=self.seq_len, dtype='float16', truncating='post')
    labels = to_categorical(labels, num_classes=3)

    class_weights = get_class_weights(labels)
    log("Calculated class weights")
    log(class_weights)

    # Init Tensorboard and early stopping callbacks
    callbacks = []
    if save_logs:
        callbacks.append(TensorBoard(log_dir=get_tb_logdir(self.log_name)))
    if early_stop:
        callbacks.append(EarlyStopping(monitor='val_loss', mode='min', verbose=1,
                                       patience=3, min_delta=0.003))

    return self.model.fit([texts, other_texts], labels,
                          batch_size=batch_size,
                          epochs=epochs,
                          verbose=verbose,
                          validation_split=val_split,
                          class_weight=class_weights,
                          callbacks=callbacks)
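# --- Illustrative sketch (not part of the original codebase): what the Keras
# utilities used in train() do to variable-length inputs, assuming the
# TensorFlow/Keras implementations of pad_sequences and to_categorical.
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

# Two "texts" of word vectors with different lengths (2-d embeddings for brevity)
texts = [np.ones((3, 2), dtype='float16'), np.ones((6, 2), dtype='float16')]
padded = pad_sequences(texts, maxlen=4, dtype='float16', truncating='post')
print(padded.shape)  # (2, 4, 2): short sequences are zero-padded, long ones truncated

labels = to_categorical([0, 2], num_classes=3)
print(labels)  # [[1. 0. 0.], [0. 0. 1.]]

# model.fit expects class_weight as a dict {class_index: weight}, e.g. (made-up values):
class_weights = {0: 2.5, 1: 1.0, 2: 1.3}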
def eval_predictions(y_true, y_pred, classes, print_results=False):
    # Plot confusion matrix
    plot_confusion_matrix(y_true=y_true, y_pred=y_pred, normalize=True, classes=classes)

    # TODO: Precision, recall, return results in a dict
    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
    f1_score_micro = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    f1_score_macro = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
    f1_score_weighted = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')

    if print_results:
        log("Prediction Evaluation", header=True)
        log(f"Accuracy: {accuracy}")
        log(f"F1 Score (Macro): {f1_score_macro}")
        log(f"F1 Score (Micro): {f1_score_micro}")
        log(f"F1 Score (Weighted): {f1_score_weighted}")
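# --- Illustrative sketch (not part of the original code) of the sklearn metrics
# used in eval_predictions(), on a tiny 3-class example.
from sklearn.metrics import accuracy_score, f1_score

y_true = [0, 1, 2, 2, 1]
y_pred = [0, 1, 1, 2, 1]
print(accuracy_score(y_true, y_pred))                # 0.8
print(f1_score(y_true, y_pred, average='micro'))     # global TP/FP/FN counts -> 0.8 here
print(f1_score(y_true, y_pred, average='macro'))     # unweighted mean of per-class F1
print(f1_score(y_true, y_pred, average='weighted'))  # per-class F1 weighted by support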
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        log("Normalized confusion matrix")
    else:
        log('Confusion matrix, without normalization')
    log(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes,
           yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    plt.show()
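# --- Illustrative sketch (not from the original code): the row normalization
# applied above turns raw counts into per-true-class proportions.
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]
cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[1 1 0]
#  [0 2 1]
#  [0 0 1]]
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(cm_norm)  # each row now sums to 1.0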
def __init__(self, uids, list_of_txt, other_list_of_txt, list_of_labels, vectorizer,
             max_seq_len, max_label_bias=None, list_of_cred=None):
    creds = [None] * len(list_of_labels)
    if list_of_cred is not None:
        creds = list_of_cred

    self.data = pd.DataFrame(data={
        CLAIM_ID_IDX: uids,
        TEXT_ONE_IDX: list_of_txt,
        TEXT_TWO_IDX: other_list_of_txt,
        CRED_IDX: creds,
        LABEL_IDX: list_of_labels
    })
    self.vectorizer = vectorizer
    self.max_seq_len = max_seq_len

    log(f"FNCData label counts: \n{self.data[LABEL_IDX].value_counts()}")
    if max_label_bias is not None:
        self.data = balance_classes(self.data, max_label_bias)
        log(f"FNCData labels balanced with max bias {max_label_bias}. " +
            f"New label counts: \n{self.data[LABEL_IDX].value_counts()}")
    log("FNCData Initialized")
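# --- Hypothetical sketch only: the balance_classes(df, max_label_bias) helper used
# above is not shown in this listing. One plausible implementation downsamples each
# label so no class exceeds max_label_bias times the smallest class count; names and
# behaviour here are assumptions, not the original implementation.
import pandas as pd

def balance_classes_sketch(df, max_label_bias, label_col):
    min_count = df[label_col].value_counts().min()
    cap = int(min_count * max_label_bias)
    # Keep at most `cap` rows per label, sampled without replacement
    return (df.groupby(label_col, group_keys=False)
              .apply(lambda grp: grp.sample(n=min(len(grp), cap), random_state=0)))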
def __init__(self, path, binary):
    start_time = time.time()
    self.model = gensim.models.KeyedVectors.load_word2vec_format(
        path, unicode_errors='ignore', binary=binary)
    log(f"Gensim vectors loaded in {time.time() - start_time}s")
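# --- Illustrative sketch (not from the original code) of loading and querying a
# gensim KeyedVectors object as done above; 'vectors.vec' is a placeholder path.
import gensim

kv = gensim.models.KeyedVectors.load_word2vec_format(
    'vectors.vec', unicode_errors='ignore', binary=False)
if 'claim' in kv:
    vec = kv['claim']  # embedding vector for an in-vocabulary word (300-d for the fastText file above)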
    plot_keras_history(history, True)
    eval_predictions(y_true=y_val_true, y_pred=y_val_pred,
                     classes=['disagree (0)', 'discuss (1)', 'agree (2)'],
                     print_results=True)

    # See which indices were not predicted correctly; return (idx, predicted) for future processing
    incorrect_idx = []
    for i, (y_pred, y_true) in enumerate(zip(y_val_pred, y_val_true)):
        if y_pred != y_true:
            incorrect_idx.append((i, y_pred))
    return incorrect_idx


checkpoint_time = time.time()

log("Loading Preprocessed Data", header=True)
v = GensimVectorizer(path='./preprocessing/assets/300d.commoncrawl.fasttext.vec', binary=False)
data = load_preprocessed(pkl_path='./data/processed/train_data.pkl',
                         fnc_pkl_path='./data/processed/train_data_fnc.pkl',
                         vectorizer=v, max_seq_len=500, max_label_bias=1.5)
test_data = load_preprocessed(pkl_path='./data/processed/test_data.pkl',
                              vectorizer=v, max_seq_len=500)

# json_df, articles_df = load_raw_data('./data/json_data.pkl', './data/articles_data.pkl')
# data = preprocess(json_df, articles_df, vectorizer=v, max_seq_len=500)
# data.data.to_pickle('./data/train_data_individual.pkl')
def build_train_eval(train_df, test_df):
    # Build model
    model_args = {
        model.SEQ_LEN: 500,
        model.EMB_DIM: 300,
        model.CONV_KERNEL_SIZE: 2,
        model.DENSE_UNITS: 1024,
        model.CONV_UNITS: 256,
        model.LSTM_UNITS: 128
    }
    nn = model.CredCLSTMWithDense(model_args)

    # Train model
    train_args = {
        model.BATCH_SIZE: 128,
        model.NUM_EPOCHS: 30,
        model.EARLY_STOP: True
    }
    history = nn.train(data=train_df, train_args=train_args)

    # Evaluate
    y_val_true = test_df[LABEL_IDX]
    y_val_pred = nn.predict(test_df, predict_args={})
    y_val_pred = categorical_to_idx(y_val_pred)  # Has claims

    plot_keras_history(history, True)

    log("Evaluating Raw Results", header=True)
    eval_predictions(y_true=y_val_true, y_pred=y_val_pred,
                     classes=['disagree (0)', 'discuss (1)', 'agree (2)'],
                     print_results=True)

    log("Evaluating Processed Results", header=True)
    # TODO: Temporary solution to incorporate claim ID to verify prediction
    # Save first so we don't lose all the information
    pd.DataFrame(data={
        'claim': test_df[TEXT_ONE_IDX],
        'pred': y_val_pred,
        'true': y_val_true
    }).to_pickle('./raw_pred_true.pkl')
    log("Saved raw predictions")

    claim_ids = list(pd.read_pickle(
        './data/processed/test_data_individual_claimid_credible.pkl')['claim_id'])

    # Key is claim ID, value is list of predictions
    pred_dict = dict()
    true_dict = dict()
    for idx, claim_id in enumerate(claim_ids):
        pred = y_val_pred[idx]
        true = y_val_true[idx]
        if claim_id in pred_dict:
            pred_dict[claim_id].append(pred)
            true_dict[claim_id].append(true)
        else:
            pred_dict[claim_id] = [pred]
            true_dict[claim_id] = [true]

    # Get keys
    dict_keys = list(pred_dict.keys())

    # Iterate over keys, get mean of lists
    import numpy as np
    true = []
    pred = []
    for key in dict_keys:
        true.append(int(np.mean(true_dict[key])))
        pred.append(int(round(float(np.mean(pred_dict[key])))))

    # Save
    processed_results_df = pd.DataFrame(data={
        'claim_id': dict_keys,
        'pred': pred,
        'true_label': true
    })
    processed_results_df.to_pickle('./processed_pred_true_id_final.pkl')

    eval_predictions(y_true=true, y_pred=pred,
                     classes=['disagree (0)', 'discuss (1)', 'agree (2)'],
                     print_results=True)
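# --- Illustrative sketch (not from the original code): the per-claim aggregation
# above expressed with a pandas groupby. Per-article predictions sharing a claim_id
# are averaged and rounded back to a class index; true labels are assumed identical
# within a claim, so their mean is just that label.
import numpy as np
import pandas as pd

per_article = pd.DataFrame({
    'claim_id': [10, 10, 10, 11],
    'pred':     [2, 1, 2, 0],
    'true':     [2, 2, 2, 0],
})
per_claim = per_article.groupby('claim_id').agg(
    pred=('pred', lambda p: int(round(float(np.mean(p))))),
    true_label=('true', lambda t: int(np.mean(t))),
).reset_index()
print(per_claim)
#    claim_id  pred  true_label
# 0        10     2           2
# 1        11     0           0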
def preprocess_nn(json_df, articles_df, vectorizer, max_seq_len, spellcheck=None,
                  use_ngrams=True, credibility_model=None):
    """
    Given the raw FNC data, return the claim uids plus 3 parallel lists: text,
    other_text (supporting info), and labels.
        - Claims are appended with the claimant
        - Articles are concatenated and the max_seq_len # of most relevant words are
          appended into supporting info
        - Labels are passed through as-is
    If credibility_model is not None, an additional 'credibilities' list of 1's and 0's
    is returned, evaluated from the non-processed article text.
    """
    # Raw data
    uids = list(json_df[const.PKL_CLAIM_ID])
    claims = json_df[const.PKL_CLAIM]
    claimants = json_df[const.PKL_CLAIMANT]
    labels = json_df[const.PKL_LABEL]
    related_articles = json_df[const.PKL_RELATED_ARTICLES]

    # Processed data
    processed_claims = []
    supporting_info = []
    credibilities = []
    final_labels = []

    start_time = time.time()  # Used for tracking only

    # Loop through all the claims and their article IDs
    for j, (str_claim, str_claimant, article_ids, label) in enumerate(
            zip(claims, claimants, related_articles, labels)):

        # Tracking use only
        if j % 1000 == 0 and j != 0:
            now = time.time()
            log(f"Processing claim {j} | Last 1000 claims took {now - start_time} seconds")
            start_time = now

        '''
        Process claim: final claim = claimant + claim
            - Convert numbers to string representation
            - Take out all non-alphanumeric characters
            - Keep case - may be important
        '''
        claim = str_claimant + ' ' + str_claim
        claim = clean_txt(claim, spellcheck)

        '''
        Process articles
            - Get all articles from the dataframe by ID
            - Get relevant info from the article, truncated to max_seq_len # of words
        '''
        # Get list of article bodies from the dataframe
        article_ids = [str(article_id) for article_id in article_ids]  # Need to look up by string
        # Get the articles with the given article IDs and only extract the text column
        articles = articles_df.loc[articles_df[const.PKL_ARTICLE_ID].isin(article_ids),
                                   const.PKL_ARTICLE_TXT]

        # If using credibility, we separate the articles (one output row per article)
        if credibility_model is not None:
            for article in articles:
                credibility = get_credibility(article, credibility_model)
                support_txt = get_relevant_info(claim, [article], vectorizer, max_seq_len,
                                                spellcheck, use_ngrams)
                # Add to list
                credibilities.append(credibility)
                processed_claims.append(claim)
                supporting_info.append(support_txt)
                final_labels.append(label)
        else:
            # If we are not using a credibility model, construct support text from all articles
            support_txt = get_relevant_info(claim, articles, vectorizer, max_seq_len,
                                            spellcheck, use_ngrams)
            # Add to list
            processed_claims.append(claim)
            supporting_info.append(support_txt)
            final_labels.append(label)

    # Return what's appropriate
    if credibility_model is not None:
        return uids, processed_claims, supporting_info, final_labels, credibilities
    else:
        return uids, processed_claims, supporting_info, final_labels
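# --- Illustrative sketch (not from the original code) of the .loc/.isin lookup used
# above to pull article bodies for a claim; the column names here are placeholders.
import pandas as pd

articles_df = pd.DataFrame({
    'article_id': ['1', '2', '3'],
    'article_text': ['text one', 'text two', 'text three'],
})
article_ids = [str(i) for i in [1, 3]]  # IDs arrive as ints, lookup is by string
articles = articles_df.loc[articles_df['article_id'].isin(article_ids), 'article_text']
print(list(articles))  # ['text one', 'text three']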