def train_multi(model, train_session_ids, val_session_ids):
    """Train `model` on all training sessions pooled together.

    Every epoch shuffles the pooled instances, runs one online pass
    (predict, then update), and evaluates on the validation sessions.
    Returns a pickled snapshot of the model from the epoch with the best
    validation overall accuracy.
    """
    train_data = [
        datum
        for session_id in train_session_ids
        for datum in dataset.get_session_data(
            session_id, max_instances=FLAGS.limit_instances)
    ]
    rng = random.Random(FLAGS.seed)
    best_model = None
    best_acc = None
    for epoch in range(FLAGS.multi_epochs):
        rng.shuffle(train_data)
        train_stats = defaultdict(list)
        progress = tqdm.tqdm(train_data, ncols=80, desc=f'epoch {epoch}')
        for state, language, target_output in progress:
            # Predict BEFORE updating so accuracy reflects online performance.
            guess = model.predict(state, language)
            train_stats['is_correct'].append(1 if guess == target_output else 0)
            train_stats['losses'].append(
                model.update(state, language, target_output))
            progress.set_postfix({
                'train_acc': np.mean(train_stats['is_correct']),
                'loss': np.mean(train_stats['losses'])
            })
        print('epoch {}\ttrain_overall_acc: {:.4f}\ttrain_loss: {:.4f}'.format(
            epoch,
            np.mean(train_stats['is_correct']),
            np.mean(train_stats['losses'])))
        val_stats = test_sessions(model, val_session_ids, name='val')
        # pprint.pprint(val_stats)
        val_overall_acc = np.mean(val_stats['is_correct'])
        val_avg_session_acc = np.mean(val_stats['session_accuracies'])
        print('epoch {}\tval_overall_acc: {:.4f}\tval_avg_session_acc: {:.4f}'.
              format(epoch, val_overall_acc, val_avg_session_acc))
        if best_acc is None or val_overall_acc > best_acc:
            print("new best in epoch {}".format(epoch))
            # Deep-copy via pickle so later epochs don't mutate the snapshot.
            best_model = pickle.loads(pickle.dumps(model))
            best_acc = val_overall_acc
    return best_model
def topKCandidatesAccuracyPlot(k, n):
    """Accumulate per-rank top-K success vectors over one session.

    For each instance in the first session, obtains a success vector from
    topKCandidatesPlot and sums the vectors element-wise into
    `topKAccuracy`, updating the model online after each instance
    (mirroring evaluate() in evaluate.py).

    Fix: the original `if topKAccuracy = []:` was a SyntaxError
    (assignment used where a comparison is required); replaced with a
    truthiness test. Unused locals (tuple conversions, `tup`, `x`) removed.
    """
    start_time = datetime.now()  # NOTE(review): currently unused; kept for parity with siblings
    topKAccuracy = []
    for session_id in dataset.get_session_ids():
        count = 0
        model = Model()
        for state, language, target_output in tqdm.tqdm(
                dataset.get_session_data(session_id)):
            # NOTE(review): the 65-instance cap is hard-coded; the unused `n`
            # parameter looks intended here (cf. topKCandidatesAccuracyBatched)
            # — confirm before changing.
            if count == 65:
                break
            k_candidate_success = topKCandidatesPlot(
                state, language, target_output, model)
            if not topKAccuracy:
                topKAccuracy = k_candidate_success
            else:
                topKAccuracy = [
                    a + b for a, b in zip(topKAccuracy, k_candidate_success)
                ]
            # Update model, as is done in evaluate() in evaluate.py.
            model.update(state, language, target_output)
            count += 1
        # Only the first session is processed (original control flow).
        break
def test_sessions(model, test_session_ids, name=''):
    """Evaluate a fresh copy of `model` on every test session.

    Each session gets its own pickled copy of the model so per-session
    updates cannot leak across sessions. Returns a dict of pooled
    per-instance stats plus a 'session_accuracies' list.
    """
    overall_stats = defaultdict(list)
    for session_id in test_session_ids:
        # Deep-copy via pickle: per-session adaptation must not mutate `model`.
        session_model = pickle.loads(pickle.dumps(model))
        examples = list(
            dataset.get_session_data(
                session_id, max_instances=FLAGS.limit_instances))
        session_stats = test_data(
            session_model,
            examples,
            FLAGS.update_model_on_each_session,
        )
        correct = session_stats['is_correct']
        acc = np.mean(correct) if len(correct) > 0 else 0.0
        for key, value in session_stats.items():
            overall_stats[key] += value
        overall_stats['session_accuracies'].append(acc)
    print('{} number of examples: {}'.format(name,
                                             len(overall_stats['is_correct'])))
    print('{} number of correct examples: {}'.format(
        name, np.sum(overall_stats['is_correct'])))
    print('{} overall accuracy: {:.4f}'.format(
        name, np.mean(overall_stats['is_correct'])))
    print('{} number of sessions: {}'.format(
        name, len(overall_stats['session_accuracies'])))
    print('{} mean session accuracy: {:.4f}'.format(
        name, np.mean(overall_stats['session_accuracies'])))
    print('{} std session accuracy: {:.4f}'.format(
        name, np.std(overall_stats['session_accuracies'])))
    return overall_stats
def evaluate_batch(data_size, test_size=500):
    """Batch-train on a prefix of each session, then test on its suffix.

    Trains on the first `data_size` instances (50 optimizer steps), tests
    on the last `test_size`, and prints per-session and average accuracy.
    """
    results = []
    for session_id in dataset.get_session_ids():
        model = Model()
        session_data = list(dataset.get_session_data(session_id))
        # Prefix and suffix must not overlap.
        assert len(session_data) > data_size + test_size
        for state, language, target_output in session_data[:data_size]:
            model.update(state, language, target_output, 0)
        for _ in range(50):
            model.optimizer_step()
        print(' training accuracy: %s%%' % (100 * model.training_accuracy()))
        test_slice = session_data[-test_size:]
        total_correct = sum(
            1 for state, language, target_output in test_slice
            if model.predict(state, language) == target_output)
        total_examples = len(test_slice)
        print(' test accuracy: %s%%' % (100 * total_correct / total_examples))
        results.append(total_correct / total_examples)
    print('average test accuracy: %s%%' % (100 * np.mean(results)))
def topKCandidatesAccuracyBatched(k, n):
    """Compute top-K accuracy over the first `n` instances of every session.

    An instance counts as accurate when topKCandidatesHelper returns a
    finite rank. The model is updated online after each instance
    (mirroring evaluate() in evaluate.py). Per-session accuracies are
    printed and appended to dataset_sessions_top_k_accuracies.txt.

    Fixes: `start_time` was used in the log line but never defined in this
    function (NameError at write time) — now initialized up front; sessions
    that yield zero instances are skipped instead of dividing by zero.
    """
    start_time = datetime.now()
    for session_id in dataset.get_session_ids():
        count = 0
        number_accurate = 0
        model = Model()
        for state, language, target_output in tqdm.tqdm(
                dataset.get_session_data(session_id)):
            if count == n:
                break
            k_candidate_success = topKCandidatesHelper(
                k, state, language, target_output, model)
            # Helper signals "target not in top K" with +inf.
            if k_candidate_success != float('inf'):
                number_accurate += 1
            # Update model, as is done in evaluate() in evaluate.py.
            model.update(state, language, target_output)
            count += 1
        if count == 0:
            # Nothing evaluated for this session; avoid division by zero.
            continue
        print("Top K accuracy: " + str(number_accurate / count))
        with open("dataset_sessions_top_k_accuracies.txt", 'a') as f:
            f.write(str(datetime.now() - start_time) + " " + str(session_id) +
                    " " + str(number_accurate / count) + " \n")
def __init__(self):
    """Build frozen token and feature indices over the whole dataset.

    Two passes over every session's language strings: the first counts raw
    tokens and indexes those seen more than FLAGS.unk_threshold times; the
    second indexes every observed feature with no frequency cutoff.
    """
    self.vocab = []
    self.vocab_id_map = {}
    self.vocab_index = Index()
    self.feature_index = Index()
    # Tokenizer that keeps the START/END markers as single tokens.
    special_cases = {
        Vocabulary.START: [{ORTH: Vocabulary.START}],
        Vocabulary.END: [{ORTH: Vocabulary.END}],
    }
    self.tokenizer = Tokenizer(English().vocab, rules=special_cases)

    # Pass 1: token frequencies, then index tokens above the UNK cutoff.
    self.token_count = Counter()
    for session_id in dataset.get_session_ids():
        for _, language, _ in dataset.get_session_data(session_id):
            self.token_count.update(self.raw_tokens(language, unk=False))
    for token, freq in self.token_count.most_common():
        if freq > FLAGS.unk_threshold:
            self.vocab_index.index(token)

    # Pass 2: index every feature observed anywhere in the data.
    feature_count = Counter()
    for session_id in dataset.get_session_ids():
        for _, language, _ in dataset.get_session_data(session_id):
            feature_count.update(self.raw_features(language))
    for feature, _freq in feature_count.most_common():
        self.feature_index.index(feature)

    # Freeze so later lookups cannot silently grow the indices.
    self.vocab_index.frozen = True
    self.feature_index.frozen = True
def evaluate():
    """Online evaluation over sessions, driven by FLAGS.

    With FLAGS.reset_model a fresh Model is created per session; otherwise
    one shared model is updated across all sessions. FLAGS.filter_session
    restricts evaluation to a single session. Per-instance correctness is
    optionally appended to FLAGS.correctness_log.
    """
    total_correct = 0
    total_examples = 0
    training_accuracies = []
    start_time = datetime.now()
    if not FLAGS.reset_model:
        # One shared model for every session.
        model = Model()
    for session_id in dataset.get_session_ids():
        if FLAGS.filter_session is not None and session_id != FLAGS.filter_session:
            continue
        if FLAGS.reset_model:
            model = Model()
        session_correct = 0
        session_examples = 0
        session_correct_list = []
        session_data = list(dataset.get_session_data(session_id))
        if not FLAGS.verbose:
            session_data = tqdm.tqdm(session_data, ncols=80, desc=session_id)
        for example_ix, (state, language, target_output) in enumerate(session_data):
            acc = session_correct / session_examples if session_examples > 0 else 0
            if FLAGS.verbose:
                print("{}: {} / {}\tacc: {:.4f}".format(
                    session_id, example_ix, len(session_data), acc))
            else:
                session_data.set_postfix({'acc': acc})
            # Predict before updating: accuracy measures online performance.
            predicted = model.predict(state, language)
            hit = 1 if predicted == target_output else 0
            session_correct += hit
            session_correct_list.append(hit)
            session_examples += 1
            model.update(state, language, target_output)
            training_accuracies.append(model.training_accuracy())
        if FLAGS.correctness_log is not None:
            with open(FLAGS.correctness_log, 'a') as f:
                f.write(' '.join(str(c) for c in session_correct_list) + '\n')
        print("this accuracy: {} {} {}".format(
            datetime.now() - start_time, session_id,
            session_correct / session_examples))
        total_correct += session_correct
        total_examples += session_examples
    print('overall accuracy: %s%%' % (100 * total_correct / total_examples))
    print('average training accuracy: %s%%' %
          (100 * np.mean(training_accuracies)))
def __init__(self):
    """Collect a whitespace-token vocabulary over every session's language."""
    self.vocab = []
    self.vocab_id_map = {}
    for session_id in dataset.get_session_ids():
        for _, language, _ in dataset.get_session_data(session_id):
            for token in language.split(' '):
                if token in self.vocab_id_map:
                    continue
                # First occurrence: assign the next sequential id.
                self.vocab_id_map[token] = len(self.vocab)
                self.vocab.append(token)
def evaluate():
    """Online evaluation with a fresh model per session.

    For every session, predicts then updates on each instance, logging
    per-instance correctness (optionally to FLAGS.correctness_log) and
    appending per-session accuracy to dataset_sessions_accuracies.txt.

    Fix: a session that yields zero instances previously raised
    ZeroDivisionError at the accuracy computation; it is now reported as
    0.0 (matching the guard used in test_sessions).
    """
    total_correct = 0
    total_examples = 0
    training_accuracies = []
    start_time = datetime.now()
    count = 0
    for session_id in dataset.get_session_ids():
        model = Model()
        session_correct = 0
        session_examples = 0
        session_correct_list = []
        session_data_count = 0
        for state, language, target_output in tqdm.tqdm(
                dataset.get_session_data(session_id)):
            print(str(count) + " : " + str(session_id) + " : " +
                  str(session_data_count))
            # Predict before updating: accuracy measures online performance.
            predicted = model.predict(state, language)
            if predicted == target_output:
                session_correct += 1
                session_correct_list.append(1)
            else:
                session_correct_list.append(0)
            session_examples += 1
            model.update(state, language, target_output)
            training_accuracies.append(model.training_accuracy())
            session_data_count += 1
        if FLAGS.correctness_log is not None:
            with open(FLAGS.correctness_log, 'a') as f:
                f.write(' '.join(str(c) for c in session_correct_list) + '\n')
        count += 1
        # Guard against empty sessions (was a ZeroDivisionError).
        session_acc = (session_correct / session_examples
                       if session_examples > 0 else 0.0)
        with open("dataset_sessions_accuracies.txt", 'a') as f:
            f.write(str(datetime.now() - start_time) + " " + str(session_id) +
                    " " + str(session_acc) + " \n")
        print(datetime.now() - start_time, session_id, session_acc)
        total_correct += session_correct
        total_examples += session_examples
    print('overall accuracy: %s%%' % (100 * total_correct / total_examples))
    print('average training accuracy: %s%%' %
          (100 * np.mean(training_accuracies)))
def train_unmixed(model, train_session_ids, val_session_ids, updates='multi'):
    """Train session-by-session, either directly ('multi') or via Reptile.

    'multi' updates the model in place on each session's data in turn;
    'reptile' trains a pickled copy per session and periodically
    interpolates the averaged session parameters back into the model
    (meta-batch size FLAGS.reptile_meta_batch_size, step size
    reptile_beta, optionally annealed to 0 over FLAGS.multi_epochs).
    Returns a snapshot of the model from the best validation epoch.

    Fix: the epoch-summary format string had three placeholders but was
    given four arguments, so the losses mean was silently dropped and the
    session-accuracy mean was printed under the 'train_loss' label. The
    format string now has a slot for each value.
    """
    rng = random.Random(FLAGS.seed)
    best_model = None
    best_acc = None
    train_session_ids = sorted(train_session_ids)
    assert updates in ['multi', 'reptile']
    for epoch in range(FLAGS.multi_epochs):
        rng.shuffle(train_session_ids)
        train_stats = defaultdict(list)
        if updates == 'reptile':
            if FLAGS.reptile_anneal_beta:
                # Linearly anneal beta from its initial value to 0.
                reptile_beta = np.linspace(FLAGS.reptile_beta, 0,
                                           FLAGS.multi_epochs)[epoch]
            else:
                reptile_beta = FLAGS.reptile_beta
            print(f'epoch {epoch}: reptile_beta {reptile_beta}')
        reptile_session_params = []
        for session_ix, session_id in enumerate(
                tqdm.tqdm(train_session_ids, ncols=80, desc=f'epoch {epoch}')):
            if updates == 'multi':
                # update the model
                session_model = model
            elif updates == 'reptile':
                # update a copy
                session_model = pickle.loads(pickle.dumps(model))
            train_session_data = list(
                dataset.get_session_data(session_id,
                                         max_instances=FLAGS.limit_instances))
            train_session_stats = test_data(session_model, train_session_data,
                                            update_model=True)
            # NOTE(review): named_parameters() is appended unconsumed; in
            # 'multi' mode it is never read. Verify average_parameters
            # handles the iterator it receives in 'reptile' mode.
            reptile_session_params.append(
                session_model.linear.named_parameters())
            session_acc = np.mean(train_session_stats['is_correct'])
            train_stats['session_accuracies'].append(session_acc)
            train_stats['is_correct'] += train_session_stats['is_correct']
            train_stats['losses'] += train_session_stats['losses']
            if updates == 'reptile' and (
                    (session_ix + 1) % FLAGS.reptile_meta_batch_size == 0
                    or session_ix == len(train_session_ids) - 1):
                averaged = average_parameters(reptile_session_params)
                # TODO: consider annealing the learning rate
                interpolated = interpolate_parameters(
                    model.linear.named_parameters(),
                    averaged,
                    reptile_beta,
                )
                update_parameters(model.linear, interpolated)
                reptile_session_params = []
        print('epoch {}\ttrain_overall_acc: {:.4f}\t'
              'train_avg_session_acc: {:.4f}\ttrain_loss: {:.4f}'.format(
                  epoch,
                  np.mean(train_stats['is_correct']),
                  np.mean(train_stats['session_accuracies']),
                  np.mean(train_stats['losses'])))
        val_stats = test_sessions(model, val_session_ids, name='val')
        # pprint.pprint(val_stats)
        val_overall_acc = np.mean(val_stats['is_correct'])
        val_avg_session_acc = np.mean(val_stats['session_accuracies'])
        print('epoch {}\tval_overall_acc: {:.4f}\tval_avg_session_acc: {:.4f}'.
              format(epoch, val_overall_acc, val_avg_session_acc))
        if best_acc is None or val_overall_acc > best_acc:
            print("new best in epoch {}".format(epoch))
            best_model = pickle.loads(pickle.dumps(model))
            best_acc = val_overall_acc
    return best_model